/**
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#include <aws/http/private/websocket_encoder.h>

/* TODO: encoder logging */
/* TODO: implement masking function in aws-c-common */
/* TODO: use nospec advance? */

/* Signature shared by all of the state handlers in this file.
 * Each handler receives the encoder and the buffer that output is appended to,
 * and returns AWS_OP_SUCCESS or an error code. */
typedef int(state_fn)(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf);

/* STATE_INIT: Entry state. Writes nothing to out_buf; it only verifies that a frame
 * is in progress and then moves on to emitting the first header byte.
 * Raises AWS_ERROR_INVALID_STATE if no frame is in progress. */
static int s_state_init(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    (void)out_buf; /* unused: this state produces no output */

    if (encoder->is_frame_in_progress) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_OPCODE_BYTE;
        return AWS_OP_SUCCESS;
    }

    /* Nothing to encode: no frame has been accepted by start_frame() */
    return aws_raise_error(AWS_ERROR_INVALID_STATE);
}

/* STATE_OPCODE_BYTE: Emits the 1st byte of the frame header.
 * Bit layout, most significant bit first: fin | rsv1 | rsv2 | rsv3 | 4-bit opcode.
 * If out_buf cannot take the byte, the state is left unchanged so the write is retried later. */
static int s_state_opcode_byte(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {

    AWS_ASSERT((encoder->frame.opcode & 0xF0) == 0); /* Should be impossible, the opcode was checked in start_frame() */

    /* Pack the four flag bits into the high nibble and the opcode into the low nibble */
    const uint8_t first_byte = (uint8_t)(
        encoder->frame.opcode | (encoder->frame.fin << 7) | (encoder->frame.rsv[0] << 6) |
        (encoder->frame.rsv[1] << 5) | (encoder->frame.rsv[2] << 4));

    /* No room in the buffer: stay in this state */
    if (!aws_byte_buf_write_u8(out_buf, first_byte)) {
        return AWS_OP_SUCCESS;
    }

    encoder->state = AWS_WEBSOCKET_ENCODER_STATE_LENGTH_BYTE;
    return AWS_OP_SUCCESS;
}

/* STATE_LENGTH_BYTE: Emits the 2nd byte of the frame header.
 * The top bit is the "masked" flag. The low 7 bits hold either the payload length itself
 * (when it is small enough) or a marker value announcing a 2-byte or 8-byte extended length.
 * If out_buf cannot take the byte, the state is left unchanged so the write is retried later. */
static int s_state_length_byte(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    const uint64_t payload_len = encoder->frame.payload_length;

    /* Decide what goes in the low 7 bits, and whether an extended-length field follows */
    uint8_t length_bits;
    bool needs_extended_length;

    if (payload_len >= AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MIN_VALUE) {
        needs_extended_length = true;
        if (payload_len <= AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MAX_VALUE) {
            length_bits = AWS_WEBSOCKET_7BIT_VALUE_FOR_2BYTE_EXTENDED_LENGTH;
        } else {
            AWS_ASSERT(encoder->frame.payload_length <= AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MAX_VALUE);
            length_bits = AWS_WEBSOCKET_7BIT_VALUE_FOR_8BYTE_EXTENDED_LENGTH;
        }
    } else {
        /* Small payload: the length fits directly in the 7 bits */
        needs_extended_length = false;
        length_bits = (uint8_t)payload_len;
    }

    /* Masking flag occupies the most significant bit */
    const uint8_t second_byte = (uint8_t)((uint8_t)(encoder->frame.masked << 7) | length_bits);

    /* No room in the buffer: stay in this state */
    if (!aws_byte_buf_write_u8(out_buf, second_byte)) {
        return AWS_OP_SUCCESS;
    }

    if (needs_extended_length) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_EXTENDED_LENGTH;
        encoder->state_bytes_processed = 0;
    } else {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_MASKING_KEY_CHECK;
    }

    return AWS_OP_SUCCESS;
}

/* STATE_EXTENDED_LENGTH: Output extended length (state skipped if not using extended length). */
static int s_state_extended_length(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    /* Fill tmp buffer with extended-length in network byte order */
    uint8_t network_bytes_array[8];
    struct aws_byte_buf network_bytes_buf =
        aws_byte_buf_from_empty_array(network_bytes_array, sizeof(network_bytes_array));
    if (encoder->frame.payload_length <= AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MAX_VALUE) {
        aws_byte_buf_write_be16(&network_bytes_buf, (uint16_t)encoder->frame.payload_length);
    } else {
        aws_byte_buf_write_be64(&network_bytes_buf, encoder->frame.payload_length);
    }

    /* Use cursor to iterate over tmp buffer */
    struct aws_byte_cursor network_bytes_cursor = aws_byte_cursor_from_buf(&network_bytes_buf);

    /* Advance cursor if some bytes already written */
    aws_byte_cursor_advance(&network_bytes_cursor, (size_t)encoder->state_bytes_processed);

    /* Shorten cursor if it won't all fit in out_buf */
    bool all_data_written = true;
    size_t space_available = out_buf->capacity - out_buf->len;
    if (network_bytes_cursor.len > space_available) {
        network_bytes_cursor.len = space_available;
        all_data_written = false;
    }

    aws_byte_buf_write_from_whole_cursor(out_buf, network_bytes_cursor);
    encoder->state_bytes_processed += network_bytes_cursor.len;

    /* If all bytes written, advance to next state */
    if (all_data_written) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_MASKING_KEY_CHECK;
    }

    return AWS_OP_SUCCESS;
}

/* MASKING_KEY_CHECK: Writes nothing to out_buf.
 * Routes a masked frame to STATE_MASKING_KEY (resetting the per-state byte counter first),
 * and an unmasked frame straight to the payload check. */
static int s_state_masking_key_check(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    (void)out_buf; /* unused: this state produces no output */

    if (!encoder->frame.masked) {
        /* No masking-key to emit, skip that state */
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_PAYLOAD_CHECK;
        return AWS_OP_SUCCESS;
    }

    encoder->state_bytes_processed = 0;
    encoder->state = AWS_WEBSOCKET_ENCODER_STATE_MASKING_KEY;
    return AWS_OP_SUCCESS;
}

/* MASKING_KEY: Output masking-key (state skipped if no masking key). */
static int s_state_masking_key(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    /* Prepare cursor to iterate over masking-key bytes */
    struct aws_byte_cursor cursor =
        aws_byte_cursor_from_array(encoder->frame.masking_key, sizeof(encoder->frame.masking_key));

    /* Advance cursor if some bytes already written (moves ptr forward but shortens len so end stays in place) */
    aws_byte_cursor_advance(&cursor, (size_t)encoder->state_bytes_processed);

    /* Shorten cursor if it won't all fit in out_buf */
    bool all_data_written = true;
    size_t space_available = out_buf->capacity - out_buf->len;
    if (cursor.len > space_available) {
        cursor.len = space_available;
        all_data_written = false;
    }

    aws_byte_buf_write_from_whole_cursor(out_buf, cursor);
    encoder->state_bytes_processed += cursor.len;

    /* If all bytes written, advance to next state */
    if (all_data_written) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_PAYLOAD_CHECK;
    }

    return AWS_OP_SUCCESS;
}

/* PAYLOAD_CHECK: Writes nothing to out_buf.
 * Routes a frame with a non-empty payload to STATE_PAYLOAD (resetting the per-state byte counter first),
 * and a frame with an empty payload straight to DONE. */
static int s_state_payload_check(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {
    (void)out_buf; /* unused: this state produces no output */

    if (encoder->frame.payload_length == 0) {
        /* Nothing left to emit for this frame */
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_DONE;
        return AWS_OP_SUCCESS;
    }

    encoder->state_bytes_processed = 0;
    encoder->state = AWS_WEBSOCKET_ENCODER_STATE_PAYLOAD;
    return AWS_OP_SUCCESS;
}

/* PAYLOAD: Emits payload bytes until the whole payload is out (state is skipped for empty payloads).
 * The bytes come from the user's stream_outgoing_payload callback, which appends directly to out_buf.
 * Afterwards the newly appended bytes are masked in place if the frame is masked.
 * Returns AWS_OP_ERR if the callback fails, and raises AWS_ERROR_HTTP_OUTGOING_STREAM_LENGTH_INCORRECT
 * if the callback has produced more bytes in total than frame.payload_length. */
static int s_state_payload(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {

    /* Nothing can be written into a full buffer, don't bother the user */
    if (out_buf->capacity <= out_buf->len) {
        return AWS_OP_SUCCESS;
    }

    /* Snapshot state so we can tell afterwards exactly what the callback did */
    const struct aws_byte_buf prev_buf = *out_buf;
    const uint64_t prev_bytes_processed = encoder->state_bytes_processed;

    /* Let the user append payload bytes to the buffer */
    if (encoder->stream_outgoing_payload(out_buf, encoder->user_data)) {
        return AWS_OP_ERR;
    }

    /* The callback may only append: same storage, same capacity, length never shrinks */
    AWS_FATAL_ASSERT(
        (out_buf->buffer == prev_buf.buffer) && (out_buf->capacity == prev_buf.capacity) &&
        (out_buf->len >= prev_buf.len));

    const size_t bytes_written = out_buf->len - prev_buf.len;

    /* Update running total of payload bytes, guarding against overflow */
    if (aws_add_u64_checked(encoder->state_bytes_processed, bytes_written, &encoder->state_bytes_processed)) {
        return aws_raise_error(AWS_ERROR_HTTP_OUTGOING_STREAM_LENGTH_INCORRECT);
    }

    /* RFC-6455 Section 5.3 Client-to-Server Masking:
     * each payload byte is XORed with masking_key[i % 4], where i is the byte's offset within the payload.
     * The offset carries over between calls, hence it starts at prev_bytes_processed.
     * Optimization idea: don't do this 1 byte at a time */
    if (encoder->frame.masked) {
        uint8_t *const fresh_bytes = out_buf->buffer + prev_buf.len;
        for (size_t i = 0; i < bytes_written; ++i) {
            fresh_bytes[i] ^= encoder->frame.masking_key[(prev_bytes_processed + i) % 4];
        }
    }

    /* User produced more bytes than the frame header promised */
    if (encoder->state_bytes_processed > encoder->frame.payload_length) {
        return aws_raise_error(AWS_ERROR_HTTP_OUTGOING_STREAM_LENGTH_INCORRECT);
    }

    /* Entire payload emitted, the frame is complete */
    if (encoder->state_bytes_processed == encoder->frame.payload_length) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_DONE;
    }

    return AWS_OP_SUCCESS;
}

/* Dispatch table of state handlers, indexed by encoder->state.
 * It is sized by AWS_WEBSOCKET_ENCODER_STATE_DONE, so DONE has no entry and must never be used as an index.
 * Because the table is indexed by the state value, the entries are expected to be listed in the same
 * order as the states are numbered (the enum is declared in the header -- keep the two in sync).
 * The pointers are const: the table is read-only after initialization. */
static state_fn *const s_state_functions[AWS_WEBSOCKET_ENCODER_STATE_DONE] = {
    s_state_init,
    s_state_opcode_byte,
    s_state_length_byte,
    s_state_extended_length,
    s_state_masking_key_check,
    s_state_masking_key,
    s_state_payload_check,
    s_state_payload,
};

int aws_websocket_encoder_process(struct aws_websocket_encoder *encoder, struct aws_byte_buf *out_buf) {

    /* Run state machine until frame is completely decoded, or the state stops changing.
     * Note that we don't necessarily stop looping when out_buf is full, because not all states need to output data */
    while (encoder->state != AWS_WEBSOCKET_ENCODER_STATE_DONE) {
        const enum aws_websocket_encoder_state prev_state = encoder->state;

        int err = s_state_functions[encoder->state](encoder, out_buf);
        if (err) {
            return AWS_OP_ERR;
        }

        if (prev_state == encoder->state) {
            /* dev-assert: Check that each state is doing as much work as it possibly can.
             * Except for the PAYLOAD state, where it's up to the user to fill the buffer. */
            AWS_ASSERT((out_buf->len == out_buf->capacity) || (encoder->state == AWS_WEBSOCKET_ENCODER_STATE_PAYLOAD));

            break;
        }
    }

    if (encoder->state == AWS_WEBSOCKET_ENCODER_STATE_DONE) {
        encoder->state = AWS_WEBSOCKET_ENCODER_STATE_INIT;
        encoder->is_frame_in_progress = false;
    }

    return AWS_OP_SUCCESS;
}

/* Validate `frame` and, if acceptable, store a copy of it as the frame to be encoded.
 * The encoder is left untouched unless every check passes.
 *
 * Raises AWS_ERROR_INVALID_STATE if a frame is already in progress, or if a data frame
 * breaks the fragmentation sequence (CONTINUATION when none is expected, or vice versa).
 * Raises AWS_ERROR_INVALID_ARGUMENT if the opcode does not fit in 4 bits, the payload length
 * exceeds the 8-byte extended-length maximum, or a control frame has FIN clear.
 *
 * The rules being enforced come from RFC-6455 Section 5.2. */
int aws_websocket_encoder_start_frame(struct aws_websocket_encoder *encoder, const struct aws_websocket_frame *frame) {
    /* Only one frame at a time */
    if (encoder->is_frame_in_progress) {
        return aws_raise_error(AWS_ERROR_INVALID_STATE);
    }

    /* Opcode is a 4-bit field */
    if ((frame->opcode & 0x0F) != frame->opcode) {
        return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
    }

    /* High bit of 8byte length must be clear */
    if (frame->payload_length > AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MAX_VALUE) {
        return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
    }

    /* Fragmentation rules:
     * - A data frame with FIN clear begins (or continues) a fragmented message, which must be followed by
     *   1+ CONTINUATION frames, where only the final CONTINUATION frame has FIN set.
     * - Control frames may appear in the middle of a fragmented message,
     *   but may not be fragmented themselves. */
    bool expect_continuation_next;
    if (!aws_websocket_is_data_frame(frame->opcode)) {
        /* Control frame: must not be fragmented, and does not disturb any fragmented message in flight */
        if (!frame->fin) {
            return aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
        }
        expect_continuation_next = encoder->expecting_continuation_data_frame;
    } else {
        /* Data frame: must be a CONTINUATION if, and only if, one is expected */
        const bool is_continuation = (frame->opcode == AWS_WEBSOCKET_OPCODE_CONTINUATION);
        if (is_continuation != encoder->expecting_continuation_data_frame) {
            return aws_raise_error(AWS_ERROR_INVALID_STATE);
        }
        expect_continuation_next = !frame->fin;
    }

    /* All checks passed, commit */
    encoder->expecting_continuation_data_frame = expect_continuation_next;
    encoder->is_frame_in_progress = true;
    encoder->frame = *frame;

    return AWS_OP_SUCCESS;
}

/* Returns true while a frame accepted by aws_websocket_encoder_start_frame()
 * has not yet been fully encoded by aws_websocket_encoder_process(). */
bool aws_websocket_encoder_is_frame_in_progress(const struct aws_websocket_encoder *encoder) {
    const bool in_progress = encoder->is_frame_in_progress;
    return in_progress;
}

/* Initialize an encoder: zero every field, then record the payload-streaming callback
 * and the user_data pointer that is passed to that callback. */
void aws_websocket_encoder_init(
    struct aws_websocket_encoder *encoder,
    aws_websocket_encoder_payload_fn *stream_outgoing_payload,
    void *user_data) {

    /* Start from a clean slate: all fields zeroed */
    AWS_ZERO_STRUCT(*encoder);

    encoder->stream_outgoing_payload = stream_outgoing_payload;
    encoder->user_data = user_data;
}

/* Compute the total number of bytes `frame` occupies once encoded:
 * 2 fixed header bytes + optional extended length (2 or 8 bytes) + optional masking-key (4 bytes) + payload.
 * This is an internal function, so asserts are sufficient error handling. */
uint64_t aws_websocket_frame_encoded_size(const struct aws_websocket_frame *frame) {
    AWS_ASSERT(frame);
    AWS_ASSERT(frame->payload_length <= AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MAX_VALUE);

    /* Every frame begins with the opcode byte and the length byte */
    const uint64_t fixed_header_bytes = 2;

    /* Masking-key is present only on masked frames */
    const uint64_t masking_key_bytes = frame->masked ? 4 : 0;

    /* Extended length field is present only when the length doesn't fit in the length byte */
    uint64_t extended_length_bytes = 0;
    if (frame->payload_length >= AWS_WEBSOCKET_8BYTE_EXTENDED_LENGTH_MIN_VALUE) {
        extended_length_bytes = 8;
    } else if (frame->payload_length >= AWS_WEBSOCKET_2BYTE_EXTENDED_LENGTH_MIN_VALUE) {
        extended_length_bytes = 2;
    }

    return fixed_header_bytes + masking_key_bytes + extended_length_bytes + frame->payload_length;
}
