/**
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#include <aws/http/private/websocket_decoder.h>

/* TODO: decoder logging */

/* Signature shared by every decoder state handler (see s_state_functions).
 * A handler may consume bytes from the front of `data` (advancing the cursor) and may move
 * decoder->state to a new state. Returns AWS_OP_SUCCESS, or AWS_OP_ERR on failure. */
typedef int(state_fn)(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data);

/* STATE_INIT: Prepare to decode a new frame by clearing the frame-in-progress. Consumes no data.
 * Note: decoder->expecting_continuation_data_frame lives outside current_frame, so it is NOT
 * cleared here and fragmentation tracking carries over from one frame to the next. */
static int s_state_init(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    (void)data; /* this state never reads input */

    decoder->state = AWS_WEBSOCKET_DECODER_STATE_OPCODE_BYTE;
    AWS_ZERO_STRUCT(decoder->current_frame);

    return AWS_OP_SUCCESS;
}

/* STATE_OPCODE_BYTE: Decode the first byte of a frame header.
 *
 * That byte packs the FIN flag, three RSV flags, and the 4bit opcode. An unknown opcode, or a
 * violation of the fragmentation rules, raises AWS_ERROR_HTTP_PROTOCOL_ERROR. For data frames this
 * also records whether the next data frame must be a CONTINUATION frame. */
static int s_state_opcode_byte(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    /* Stay in this state until at least one byte is available. */
    if (data->len == 0) {
        return AWS_OP_SUCCESS;
    }

    const uint8_t header_byte = *data->ptr;
    aws_byte_cursor_advance(data, 1);

    /* High nibble: FIN, RSV1, RSV2, RSV3 flags (most significant bit first). */
    decoder->current_frame.fin = header_byte & 0x80;
    decoder->current_frame.rsv[0] = header_byte & 0x40;
    decoder->current_frame.rsv[1] = header_byte & 0x20;
    decoder->current_frame.rsv[2] = header_byte & 0x10;

    /* Low nibble: opcode. */
    decoder->current_frame.opcode = header_byte & 0x0F;

    /* RFC-6455 Section 5.2 - Opcode
     * If an unknown opcode is received, the receiving endpoint MUST _Fail the WebSocket Connection_. */
    switch (decoder->current_frame.opcode) {
        case AWS_WEBSOCKET_OPCODE_CONTINUATION:
        case AWS_WEBSOCKET_OPCODE_TEXT:
        case AWS_WEBSOCKET_OPCODE_BINARY:
        case AWS_WEBSOCKET_OPCODE_CLOSE:
        case AWS_WEBSOCKET_OPCODE_PING:
        case AWS_WEBSOCKET_OPCODE_PONG:
            break;
        default:
            return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
    }

    /* RFC-6455 Section 5.2 Fragmentation
     * A data frame with FIN clear starts (or continues) a fragmented message that must be finished
     * by CONTINUATION frames, the last of which has FIN set. Control frames may appear between
     * fragments (they leave the continuation expectation untouched) but may not be fragmented. */
    if (!aws_websocket_is_data_frame(decoder->current_frame.opcode)) {
        /* Control frames themselves MUST NOT be fragmented. */
        if (!decoder->current_frame.fin) {
            return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
        }
    } else {
        /* A CONTINUATION frame is legal exactly when a fragmented message is in progress,
         * and a TEXT/BINARY frame is legal exactly when one is not. */
        const bool is_continuation = decoder->current_frame.opcode == AWS_WEBSOCKET_OPCODE_CONTINUATION;
        if (is_continuation != decoder->expecting_continuation_data_frame) {
            return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
        }

        /* FIN clear means more CONTINUATION frames must follow. */
        decoder->expecting_continuation_data_frame = !decoder->current_frame.fin;
    }

    decoder->state = AWS_WEBSOCKET_DECODER_STATE_LENGTH_BYTE;
    return AWS_OP_SUCCESS;
}

/* STATE_LENGTH_BYTE: Decode byte containing length, determine if we need to decode extended length. */
static int s_state_length_byte(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    if (data->len == 0) {
        return AWS_OP_SUCCESS;
    }

    uint8_t byte = data->ptr[0];
    aws_byte_cursor_advance(data, 1);

    /* first bit is a bool */
    decoder->current_frame.masked = byte & 0x80;

    /* remaining 7 bits are payload length */
    decoder->current_frame.payload_length = byte & 0x7F;

    if (decoder->current_frame.payload_length >= AWS_WEBSOCKET_7BIT_VALUE_FOR_2BYTE_EXTENDED_LENGTH) {
        /* If 7bit payload length has a high value, then the next few bytes contain the real payload length */
        decoder->state_bytes_processed = 0;
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_EXTENDED_LENGTH;
    } else {
        /* If 7bit payload length has low value, that's the actual payload size, jump past EXTENDED_LENGTH state */
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_MASKING_KEY_CHECK;
    }

    return AWS_OP_SUCCESS;
}

/* STATE_EXTENDED_LENGTH: Decode extended length (state skipped if no extended length).
 *
 * Entered from STATE_LENGTH_BYTE (which zeroes state_bytes_processed) when the 7bit payload length
 * held an extended-length marker. Until every extended-length byte has arrived,
 * current_frame.payload_length still holds that 7bit marker, and it selects whether 2 or 8
 * big-endian length bytes follow. The bytes are accumulated in state_cache across calls
 * (state_bytes_processed counts how many have been copied so far), so the extended length may be
 * split across several data buffers.
 *
 * Once all bytes are present, payload_length is overwritten with the decoded value and checked
 * against the min/max bounds for its encoding. An out-of-range value raises
 * AWS_ERROR_HTTP_PROTOCOL_ERROR. If more bytes are still needed, the function returns
 * AWS_OP_SUCCESS without changing state. */
static int s_state_extended_length(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    if (data->len == 0) {
        return AWS_OP_SUCCESS;
    }

    /* The 7bit payload value loaded during the previous state indicated that
     * actual payload length is encoded across the next 2 or 8 bytes. */
    uint8_t total_bytes_extended_length;
    uint64_t min_acceptable_value;
    uint64_t max_acceptable_value;
    /* NOTE(review): the MIN bounds presumably reject non-minimal encodings (RFC-6455 Section 5.2
     * requires the minimal number of bytes), and the 8-byte MAX presumably enforces a clear most
     * significant bit. The constant values are defined elsewhere; confirm. */
    if (decoder->current_frame.payload_length == AWS_WEBSOCKET_7BIT_VALUE_FOR_2BYTE_EXTENDED_LENGTH) {
        total_bytes_extended_length = 2;
        min_acceptable_value = AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MIN_VALUE;
        max_acceptable_value = AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MAX_VALUE;
    } else {
        AWS_ASSERT(decoder->current_frame.payload_length == AWS_WEBSOCKET_7BIT_VALUE_FOR_8BYTE_EXTENDED_LENGTH);

        total_bytes_extended_length = 8;
        min_acceptable_value = AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MIN_VALUE;
        max_acceptable_value = AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MAX_VALUE;
    }

    /* Copy bytes of extended-length to state_cache, we'll process them later.*/
    AWS_ASSERT(total_bytes_extended_length > decoder->state_bytes_processed);

    /* Take only as many bytes as the extended length still needs; the rest of `data` belongs to later states. */
    size_t remaining_bytes = (size_t)(total_bytes_extended_length - decoder->state_bytes_processed);
    size_t bytes_to_consume = remaining_bytes <= data->len ? remaining_bytes : data->len;

    AWS_ASSERT(bytes_to_consume + decoder->state_bytes_processed <= sizeof(decoder->state_cache));

    memcpy(decoder->state_cache + decoder->state_bytes_processed, data->ptr, bytes_to_consume);

    aws_byte_cursor_advance(data, bytes_to_consume);
    decoder->state_bytes_processed += bytes_to_consume;

    /* Return, still waiting on more bytes */
    if (decoder->state_bytes_processed < total_bytes_extended_length) {
        return AWS_OP_SUCCESS;
    }

    /* All bytes have been copied into state_cache, now read them together as one number,
     * transforming from network byte order (big endian) to native endianness.
     * cache_cursor spans exactly total_bytes_extended_length bytes, so the reads below are not
     * expected to fail; the error branches are defensive. */
    struct aws_byte_cursor cache_cursor = aws_byte_cursor_from_array(decoder->state_cache, total_bytes_extended_length);
    if (total_bytes_extended_length == 2) {
        uint16_t val;
        if (!aws_byte_cursor_read_be16(&cache_cursor, &val)) {
            return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
        }

        decoder->current_frame.payload_length = val;
    } else {
        if (!aws_byte_cursor_read_be64(&cache_cursor, &decoder->current_frame.payload_length)) {
            return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
        }
    }

    /* payload_length now holds the real length (no longer the 7bit marker); validate its range. */
    if (decoder->current_frame.payload_length < min_acceptable_value ||
        decoder->current_frame.payload_length > max_acceptable_value) {

        return aws_raise_error(AWS_ERROR_HTTP_PROTOCOL_ERROR);
    }

    decoder->state = AWS_WEBSOCKET_DECODER_STATE_MASKING_KEY_CHECK;
    return AWS_OP_SUCCESS;
}

/* MASKING_KEY_CHECK: Route to MASKING_KEY when the MASK bit was set, otherwise go straight to
 * PAYLOAD_CHECK. Consumes no data. */
static int s_state_masking_key_check(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    (void)data;

    if (!decoder->current_frame.masked) {
        /* No masking key is present in this frame header. */
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_PAYLOAD_CHECK;
        return AWS_OP_SUCCESS;
    }

    /* A masking key follows; start counting its bytes from zero. */
    decoder->state_bytes_processed = 0;
    decoder->state = AWS_WEBSOCKET_DECODER_STATE_MASKING_KEY;
    return AWS_OP_SUCCESS;
}

/* MASKING_KEY: Collect the 4-byte masking-key, possibly across several calls
 * (state skipped if no masking key). state_bytes_processed counts the key bytes already copied
 * into current_frame.masking_key. */
static int s_state_masking_key(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    if (data->len == 0) {
        return AWS_OP_SUCCESS;
    }

    AWS_ASSERT(decoder->state_bytes_processed < 4);
    const size_t key_bytes_needed = 4 - (size_t)decoder->state_bytes_processed;
    const size_t chunk_len = (data->len < key_bytes_needed) ? data->len : key_bytes_needed;

    memcpy(decoder->current_frame.masking_key + decoder->state_bytes_processed, data->ptr, chunk_len);
    aws_byte_cursor_advance(data, chunk_len);
    decoder->state_bytes_processed += chunk_len;

    /* Advance only once the whole key is present; otherwise wait here for more data. */
    if (decoder->state_bytes_processed == 4) {
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_PAYLOAD_CHECK;
    }

    return AWS_OP_SUCCESS;
}

/* PAYLOAD_CHECK: Report the decoded frame header via on_frame(), then decide whether a payload
 * must be decoded. Consumes no data. Returns AWS_OP_ERR if on_frame() reports failure. */
static int s_state_payload_check(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    (void)data;

    /* The user sees the header before any payload bytes are delivered. */
    if (decoder->on_frame(&decoder->current_frame, decoder->user_data)) {
        return AWS_OP_ERR;
    }

    if (decoder->current_frame.payload_length == 0) {
        /* Empty payload: the frame is already complete. */
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_DONE;
        return AWS_OP_SUCCESS;
    }

    decoder->state_bytes_processed = 0;
    decoder->state = AWS_WEBSOCKET_DECODER_STATE_PAYLOAD;
    return AWS_OP_SUCCESS;
}

/* PAYLOAD: Decode payload until we're done (state skipped if no payload). */
static int s_state_payload(struct aws_websocket_decoder *decoder, struct aws_byte_cursor *data) {
    if (data->len == 0) {
        return AWS_OP_SUCCESS;
    }

    AWS_ASSERT(decoder->current_frame.payload_length > decoder->state_bytes_processed);
    uint64_t bytes_remaining = decoder->current_frame.payload_length - decoder->state_bytes_processed;
    size_t bytes_to_consume = bytes_remaining < data->len ? (size_t)bytes_remaining : data->len;

    struct aws_byte_cursor payload = aws_byte_cursor_advance(data, bytes_to_consume);

    /* Unmask data, if necessary.
     * RFC-6455 Section 5.3 Client-to-Server Masking
     * Each byte of payload is XOR against a byte of the masking-key */
    if (decoder->current_frame.masked) {
        uint64_t mask_index = decoder->state_bytes_processed;

        /* Optimization idea: don't do this 1 byte at a time */
        uint8_t *current_byte = payload.ptr;
        uint8_t *end_byte = payload.ptr + payload.len;
        while (current_byte != end_byte) {
            *current_byte++ ^= decoder->current_frame.masking_key[mask_index++ % 4];
        }
    }

    /* TODO: validate utf-8 */
    /* TODO: validate payload of CLOSE frame */

    /* Invoke on_payload() callback to inform user of payload data */
    int err = decoder->on_payload(payload, decoder->user_data);
    if (err) {
        return AWS_OP_ERR;
    }

    decoder->state_bytes_processed += payload.len;
    AWS_ASSERT(decoder->state_bytes_processed <= decoder->current_frame.payload_length);

    /* If all data consumed, proceed to next state. */
    if (decoder->state_bytes_processed == decoder->current_frame.payload_length) {
        decoder->state++;
    }

    return AWS_OP_SUCCESS;
}

static state_fn *s_state_functions[AWS_WEBSOCKET_DECODER_STATE_DONE] = {
    s_state_init,
    s_state_opcode_byte,
    s_state_length_byte,
    s_state_extended_length,
    s_state_masking_key_check,
    s_state_masking_key,
    s_state_payload_check,
    s_state_payload,
};

/* Decode as much of one frame as possible from `data`.
 *
 * Runs the state machine until the current frame is fully decoded or no further progress can be
 * made. Bytes are consumed from the front of `data`. If a frame completes before `data` is
 * exhausted, the unconsumed bytes stay in `data` for the next call.
 *
 * *frame_complete is set to true when this call finished a whole frame (the decoder is then reset to
 * start the next frame). Otherwise it is set to false, including when an error is returned, so the
 * out-param is always defined.
 *
 * Returns AWS_OP_SUCCESS, or AWS_OP_ERR if a state handler or user callback failed. After an error
 * the decoder is left mid-frame. NOTE(review): callers presumably treat the connection as failed
 * rather than keep feeding data; confirm. */
int aws_websocket_decoder_process(
    struct aws_websocket_decoder *decoder,
    struct aws_byte_cursor *data,
    bool *frame_complete) {

    /* Define the out-param up front so it is valid on every return path, including errors. */
    *frame_complete = false;

    /* Run state machine until frame is completely decoded, or the state stops changing.
     * Note that we don't stop looping when data->len reaches zero, because some states consume no data. */
    while (decoder->state != AWS_WEBSOCKET_DECODER_STATE_DONE) {
        enum aws_websocket_decoder_state prev_state = decoder->state;

        int err = s_state_functions[decoder->state](decoder, data);
        if (err) {
            return AWS_OP_ERR;
        }

        if (decoder->state == prev_state) {
            AWS_ASSERT(data->len == 0); /* If no more work to do, all possible data should have been consumed */
            break;
        }
    }

    if (decoder->state == AWS_WEBSOCKET_DECODER_STATE_DONE) {
        /* Frame finished: reset so the next call begins decoding a fresh frame. */
        decoder->state = AWS_WEBSOCKET_DECODER_STATE_INIT;
        *frame_complete = true;
    }

    return AWS_OP_SUCCESS;
}

/* Initialize `decoder` for use: zero every field (including the state-machine state and the
 * frame-in-progress), then install the user's callbacks and the user_data passed back to them. */
void aws_websocket_decoder_init(
    struct aws_websocket_decoder *decoder,
    aws_websocket_decoder_frame_fn *on_frame,
    aws_websocket_decoder_payload_fn *on_payload,
    void *user_data) {

    AWS_ZERO_STRUCT(*decoder);

    /* Callbacks invoked for each frame header and for each chunk of payload. */
    decoder->on_frame = on_frame;
    decoder->on_payload = on_payload;
    decoder->user_data = user_data;
}
