/**
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#include <aws/mqtt/private/client_impl.h>

#include <aws/mqtt/private/packets.h>
#include <aws/mqtt/private/topic_tree.h>

#include <aws/io/logging.h>

#include <aws/common/clock.h>
#include <aws/common/math.h>
#include <aws/common/task_scheduler.h>

#include <inttypes.h>

#ifdef _MSC_VER
#    pragma warning(disable : 4204)
#endif

/*******************************************************************************
 * Packet State Machine
 ******************************************************************************/

/* Signature shared by the per-packet-type handlers in this file: each receives the connection and a cursor over
 * the packet bytes and returns an int status (the handlers below return AWS_OP_SUCCESS or AWS_OP_ERR). */
typedef int(packet_handler_fn)(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor);

/* Fallback handler for packet types this client has no handler for: logs the event and fails with
 * AWS_ERROR_MQTT_INVALID_PACKET_TYPE. The packet contents are never looked at. */
static int s_packet_handler_default(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {
    (void)message_cursor;
    (void)connection;

    AWS_LOGF_ERROR(AWS_LS_MQTT_CLIENT, "id=%p: Unhandled packet type received", (void *)connection);

    const int raised = aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
    return raised;
}

static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status);

/* (Re)initializes connection->ping_task and schedules it on the connection's channel to run
 * keep_alive_time_secs after the channel's current clock time. */
static void s_schedule_ping(struct aws_mqtt_client_connection *connection) {
    aws_channel_task_init(&connection->ping_task, s_on_time_to_ping, connection, "mqtt_ping");

    struct aws_channel *channel = connection->slot->channel;

    uint64_t current_ns = 0;
    aws_channel_current_clock_time(channel, &current_ns);
    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT, "id=%p: Scheduling PING. current timestamp is %" PRIu64, (void *)connection, current_ns);

    /* Keep-alive interval expressed in the same unit (nanoseconds) as the channel clock reading. */
    const uint64_t keep_alive_ns =
        aws_timestamp_convert(connection->keep_alive_time_secs, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL);
    const uint64_t next_ping_ns = current_ns + keep_alive_ns;

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: The next ping will be run at timestamp %" PRIu64,
        (void *)connection,
        next_ping_ns);

    aws_channel_schedule_task_future(channel, &connection->ping_task, next_ping_ns);
}

/* Channel task body for the keep-alive ping. When the task actually runs (status RUN_READY) it sends a ping
 * and re-arms itself through s_schedule_ping; for any other status it does nothing. */
static void s_on_time_to_ping(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {
    (void)channel_task;

    if (status != AWS_TASK_STATUS_RUN_READY) {
        return;
    }

    struct aws_mqtt_client_connection *connection = arg;
    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Sending PING", (void *)connection);
    aws_mqtt_client_connection_ping(connection);
    s_schedule_ping(connection);
}
/*
 * Handles a CONNACK packet.
 *
 * Flow, as implemented below:
 *  1. Decode the packet; a decode failure is logged and returns AWS_OP_ERR.
 *  2. Under the synced_data lock: if the state is DISCONNECTING or later, return AWS_OP_SUCCESS without doing
 *     anything else. Otherwise remember whether we were RECONNECTING and, if the broker accepted the connection,
 *     switch to CONNECTED and take over the list of requests queued while offline.
 *  3. Accepted: schedule the outgoing task of every queued request on the channel.
 *     Rejected: shut the channel down with AWS_ERROR_MQTT_PROTOCOL_ERROR and return AWS_OP_SUCCESS.
 *  4. Invoke on_resumed (reconnect, not the first connection) or on_connection_complete (otherwise), then start
 *     the keep-alive ping schedule.
 */
static int s_packet_handler_connack(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: CONNACK received", (void *)connection);

    struct aws_mqtt_packet_connack connack;
    if (aws_mqtt_packet_connack_decode(&message_cursor, &connack)) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT, "id=%p: error %d parsing CONNACK packet", (void *)connection, aws_last_error());

        return AWS_OP_ERR;
    }
    bool was_reconnecting;
    /* Local list that receives the offline-queued requests so they can be scheduled outside the lock */
    struct aws_linked_list requests;
    aws_linked_list_init(&requests);
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        /* User requested disconnect, don't do anything */
        if (connection->synced_data.state >= AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
            mqtt_connection_unlock_synced_data(connection);
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT, "id=%p: User has requested disconnect, dropping connection", (void *)connection);
            return AWS_OP_SUCCESS;
        }

        /* Must be sampled before the state is overwritten below */
        was_reconnecting = connection->synced_data.state == AWS_MQTT_CLIENT_STATE_RECONNECTING;
        if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: connection was accepted, switch state from %d to CONNECTED.",
                (void *)connection,
                (int)connection->synced_data.state);
            /* The state is only advanced inside this branch, i.e. when the broker ACCEPTED the connection */
            mqtt_connection_set_state(connection, AWS_MQTT_CLIENT_STATE_CONNECTED);
            aws_linked_list_swap_contents(&connection->synced_data.pending_requests_list, &requests);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */
    /* NOTE(review): the counter bump and the backoff reset below also run when the broker rejected the
     * connection - confirm that is intended. */
    connection->connection_count++;

    /* Reset the current timeout timer */
    connection->reconnect_timeouts.current = connection->reconnect_timeouts.min;

    if (connack.connect_return_code == AWS_MQTT_CONNECT_ACCEPTED) {
        /* If successfully connected, schedule all pending tasks */
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT, "id=%p: connection was accepted processing offline requests.", (void *)connection);

        if (!aws_linked_list_empty(&requests)) {

            struct aws_linked_list_node *current = aws_linked_list_front(&requests);
            const struct aws_linked_list_node *end = aws_linked_list_end(&requests);

            /* The emptiness check above guarantees at least one node, so a do/while is safe here */
            do {
                struct aws_mqtt_request *request = AWS_CONTAINER_OF(current, struct aws_mqtt_request, list_node);
                AWS_LOGF_TRACE(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: processing offline request %" PRIu16,
                    (void *)connection,
                    request->packet_id);
                aws_channel_schedule_task_now(connection->slot->channel, &request->outgoing_task);
                current = current->next;
            } while (current != end);
        }
    } else {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: invalid connect return code %d, disconnecting",
            (void *)connection,
            connack.connect_return_code);
        /* If error code returned, disconnect, on_completed will be invoked from shutdown process */
        aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);

        /* The packet itself was handled; the rejection is reported through the channel shutdown */
        return AWS_OP_SUCCESS;
    }

    /* It is possible for a connection to complete, and a hangup to occur before the
     * CONNECT/CONNACK cycle completes. In that case, we must deliver on_connection_complete
     * on the first successful CONNACK or user code will never think it's connected */
    if (was_reconnecting && connection->connection_count > 1) {

        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: connection is a resumed connection, invoking on_resumed callback",
            (void *)connection);

        MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_resumed, connack.connect_return_code, connack.session_present);
    } else {

        aws_create_reconnect_task(connection);

        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: connection is a new connection, invoking on_connection_complete callback",
            (void *)connection);
        MQTT_CLIENT_CALL_CALLBACK_ARGS(
            connection, on_connection_complete, AWS_OP_SUCCESS, connack.connect_return_code, connack.session_present);
    }

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: connection callback completed", (void *)connection);

    /* Start the keep-alive ping cycle for the freshly established connection */
    s_schedule_ping(connection);
    return AWS_OP_SUCCESS;
}

static int s_packet_handler_publish(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    /* TODO: need to handle the QoS 2 message to avoid processing the message a second time */
    struct aws_mqtt_packet_publish publish;
    if (aws_mqtt_packet_publish_decode(&message_cursor, &publish)) {
        return AWS_OP_ERR;
    }

    aws_mqtt_topic_tree_publish(&connection->thread_data.subscriptions, &publish);

    bool dup = aws_mqtt_packet_publish_get_dup(&publish);
    enum aws_mqtt_qos qos = aws_mqtt_packet_publish_get_qos(&publish);
    bool retain = aws_mqtt_packet_publish_get_retain(&publish);

    MQTT_CLIENT_CALL_CALLBACK_ARGS(connection, on_any_publish, &publish.topic_name, &publish.payload, dup, qos, retain);

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: publish received with msg id=%" PRIu16 " dup=%d qos=%d retain=%d payload-size=%zu topic=" PRInSTR,
        (void *)connection,
        publish.packet_identifier,
        dup,
        qos,
        retain,
        publish.payload.len,
        AWS_BYTE_CURSOR_PRI(publish.topic_name));
    struct aws_mqtt_packet_ack puback;
    AWS_ZERO_STRUCT(puback);

    /* Switch on QoS flags (bits 1 & 2) */
    switch (qos) {
        case AWS_MQTT_QOS_AT_MOST_ONCE:
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 0, not sending puback", (void *)connection);
            /* No more communication necessary */
            break;
        case AWS_MQTT_QOS_AT_LEAST_ONCE:
            AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 1, sending puback", (void *)connection);
            aws_mqtt_packet_puback_init(&puback, publish.packet_identifier);
            break;
        case AWS_MQTT_QOS_EXACTLY_ONCE:
            AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: received publish QOS is 2, sending pubrec", (void *)connection);
            aws_mqtt_packet_pubrec_init(&puback, publish.packet_identifier);
            break;
        default:
            /* Impossible to hit this branch. QoS value is checked when decoding */
            AWS_FATAL_ASSERT(0);
            break;
    }

    if (puback.packet_identifier) {
        struct aws_io_message *message = mqtt_get_message_for_packet(connection, &puback.fixed_header);
        if (!message) {
            return AWS_OP_ERR;
        }

        if (aws_mqtt_packet_ack_encode(&message->message_data, &puback)) {
            aws_mem_release(message->allocator, message);
            return AWS_OP_ERR;
        }

        if (aws_channel_slot_send_message(connection->slot, message, AWS_CHANNEL_DIR_WRITE)) {
            aws_mem_release(message->allocator, message);
            return AWS_OP_ERR;
        }
    }

    return AWS_OP_SUCCESS;
}

static int s_packet_handler_ack(struct aws_mqtt_client_connection *connection, struct aws_byte_cursor message_cursor) {
    struct aws_mqtt_packet_ack ack;
    if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
        return AWS_OP_ERR;
    }

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT, "id=%p: received ack for message id %" PRIu16, (void *)connection, ack.packet_identifier);

    mqtt_request_complete(connection, AWS_ERROR_SUCCESS, ack.packet_identifier);

    return AWS_OP_SUCCESS;
}

/*
 * Handles a SUBACK packet: decodes it, looks up the outstanding request with the same packet id and, when one
 * exists, copies each returned code into the qos field of the corresponding requested topic before completing
 * the request.
 *
 * Returns AWS_OP_ERR if the packet cannot be initialized/decoded or if the number of return codes does not match
 * the number of topics in the request; AWS_OP_SUCCESS otherwise (including when no matching request is found).
 */
static int s_packet_handler_suback(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    int outcome = AWS_OP_ERR;

    struct aws_mqtt_packet_suback suback;
    if (aws_mqtt_packet_suback_init(&suback, connection->allocator, 0 /* fake packet_id */)) {
        return AWS_OP_ERR;
    }

    if (aws_mqtt_packet_suback_decode(&message_cursor, &suback)) {
        goto clean_up;
    }

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: received suback for message id %" PRIu16,
        (void *)connection,
        suback.packet_identifier);

    struct aws_mqtt_request *matching_request = NULL;

    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        struct aws_hash_element *found = NULL;
        aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &suback.packet_identifier, &found);
        if (found != NULL) {
            matching_request = found->value;
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (matching_request != NULL) {
        struct subscribe_task_arg *subscribe_arg = matching_request->on_complete_ud;
        const size_t topic_count = aws_array_list_length(&subscribe_arg->topics);
        const size_t return_code_count = aws_array_list_length(&suback.return_codes);

        /* One return code is expected per requested topic filter */
        if (topic_count != return_code_count) {
            goto clean_up;
        }

        for (size_t idx = 0; idx < return_code_count; ++idx) {
            uint8_t granted = 0;
            struct subscribe_task_topic *requested = NULL;
            aws_array_list_get_at(&suback.return_codes, (void *)&granted, idx);
            aws_array_list_get_at(&subscribe_arg->topics, &requested, idx);
            requested->request.qos = granted;
        }
    }

    /* Reached both when a request was updated and when no corresponding request was found */
    mqtt_request_complete(connection, AWS_ERROR_SUCCESS, suback.packet_identifier);
    outcome = AWS_OP_SUCCESS;

clean_up:
    aws_mqtt_packet_suback_clean_up(&suback);
    return outcome;
}

/* Handles a PUBREC packet by replying with a PUBREL that carries the same packet id.
 * Returns AWS_OP_ERR if decoding fails or the reply cannot be acquired, encoded or sent. */
static int s_packet_handler_pubrec(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    struct aws_mqtt_packet_ack ack;
    if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
        return AWS_OP_ERR;
    }

    /* TODO: When sending PUBLISH with QoS 2, we should be storing the data until this packet is received, at which
     * point we may discard it. */

    /* Re-use the decoded struct for the outgoing PUBREL */
    const uint16_t packet_id = ack.packet_identifier;
    aws_mqtt_packet_pubrel_init(&ack, packet_id);

    struct aws_io_message *pubrel_message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
    if (!pubrel_message) {
        return AWS_OP_ERR;
    }

    /* If encoding or sending fails the message was not consumed, so give it back */
    if (aws_mqtt_packet_ack_encode(&pubrel_message->message_data, &ack) ||
        aws_channel_slot_send_message(connection->slot, pubrel_message, AWS_CHANNEL_DIR_WRITE)) {
        aws_mem_release(pubrel_message->allocator, pubrel_message);
        return AWS_OP_ERR;
    }

    return AWS_OP_SUCCESS;
}

/* Handles a PUBREL packet by replying with a PUBCOMP that carries the same packet id.
 * Returns AWS_OP_ERR if decoding fails or the reply cannot be acquired, encoded or sent. */
static int s_packet_handler_pubrel(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {

    struct aws_mqtt_packet_ack ack;
    if (aws_mqtt_packet_ack_decode(&message_cursor, &ack)) {
        return AWS_OP_ERR;
    }

    /* Re-use the decoded struct for the outgoing PUBCOMP */
    const uint16_t packet_id = ack.packet_identifier;
    aws_mqtt_packet_pubcomp_init(&ack, packet_id);

    struct aws_io_message *pubcomp_message = mqtt_get_message_for_packet(connection, &ack.fixed_header);
    if (!pubcomp_message) {
        return AWS_OP_ERR;
    }

    /* If encoding or sending fails the message was not consumed, so give it back */
    if (aws_mqtt_packet_ack_encode(&pubcomp_message->message_data, &ack) ||
        aws_channel_slot_send_message(connection->slot, pubcomp_message, AWS_CHANNEL_DIR_WRITE)) {
        aws_mem_release(pubcomp_message->allocator, pubcomp_message);
        return AWS_OP_ERR;
    }

    return AWS_OP_SUCCESS;
}

/* Handles a PINGRESP packet: logs it and clears the waiting_on_ping_response flag. Always succeeds; the packet
 * bytes are not inspected. */
static int s_packet_handler_pingresp(
    struct aws_mqtt_client_connection *connection,
    struct aws_byte_cursor message_cursor) {
    (void)message_cursor;

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: PINGRESP received", (void *)connection);

    /* The outstanding ping has been answered */
    connection->thread_data.waiting_on_ping_response = false;
    return AWS_OP_SUCCESS;
}

/* Dispatch table indexed by enum aws_mqtt_packet_type. Packet types the client never expects to receive
 * (CONNECT, SUBSCRIBE, UNSUBSCRIBE, PINGREQ, DISCONNECT) are routed to the failing default handler. */
static packet_handler_fn *s_packet_handlers[] = {
    [AWS_MQTT_PACKET_CONNECT] = s_packet_handler_default,
    [AWS_MQTT_PACKET_CONNACK] = s_packet_handler_connack,
    [AWS_MQTT_PACKET_PUBLISH] = s_packet_handler_publish,
    [AWS_MQTT_PACKET_PUBACK] = s_packet_handler_ack,
    [AWS_MQTT_PACKET_PUBREC] = s_packet_handler_pubrec,
    [AWS_MQTT_PACKET_PUBREL] = s_packet_handler_pubrel,
    [AWS_MQTT_PACKET_PUBCOMP] = s_packet_handler_ack,
    [AWS_MQTT_PACKET_SUBSCRIBE] = s_packet_handler_default,
    [AWS_MQTT_PACKET_SUBACK] = s_packet_handler_suback,
    [AWS_MQTT_PACKET_UNSUBSCRIBE] = s_packet_handler_default,
    [AWS_MQTT_PACKET_UNSUBACK] = s_packet_handler_ack,
    [AWS_MQTT_PACKET_PINGREQ] = s_packet_handler_default,
    [AWS_MQTT_PACKET_PINGRESP] = s_packet_handler_pingresp,
    [AWS_MQTT_PACKET_DISCONNECT] = s_packet_handler_default,
};

/*******************************************************************************
 * Channel Handler
 ******************************************************************************/

static int s_process_mqtt_packet(
    struct aws_mqtt_client_connection *connection,
    enum aws_mqtt_packet_type packet_type,
    struct aws_byte_cursor packet) {
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        /* [MQTT-3.2.0-1] The first packet sent from the Server to the Client MUST be a CONNACK Packet */
        if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_CONNECTING &&
            packet_type != AWS_MQTT_PACKET_CONNACK) {
            mqtt_connection_unlock_synced_data(connection);
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: First message received from the server was not a CONNACK. Terminating connection.",
                (void *)connection);
            aws_channel_shutdown(connection->slot->channel, AWS_ERROR_MQTT_PROTOCOL_ERROR);
            return aws_raise_error(AWS_ERROR_MQTT_PROTOCOL_ERROR);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (AWS_UNLIKELY(packet_type > AWS_MQTT_PACKET_DISCONNECT || packet_type < AWS_MQTT_PACKET_CONNECT)) {
        AWS_LOGF_ERROR(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Invalid packet type received %d. Terminating connection.",
            (void *)connection,
            packet_type);
        return aws_raise_error(AWS_ERROR_MQTT_INVALID_PACKET_TYPE);
    }

    /* Handle the packet */
    return s_packet_handlers[packet_type](connection, packet);
}

/**
 * Handles incoming messages from the server.
 *
 * A single io message may contain the tail of a packet started in an earlier message, any number of complete
 * packets, and the beginning of a packet that continues in a later message. Incomplete data is accumulated in
 * connection->thread_data.pending_packet until the whole packet is available.
 *
 * \param handler   Channel handler whose impl is the MQTT client connection.
 * \param slot      Slot whose read window is re-opened by the size of the message once it has been consumed.
 * \param message   Incoming message. It is released (and the read window incremented) only on the
 *                  AWS_OP_SUCCESS path; on every AWS_OP_ERR return it is left untouched.
 *                  NOTE(review): this presumes the caller keeps ownership of the message on failure - confirm
 *                  against the channel handler contract.
 *
 * \return AWS_OP_SUCCESS when the whole message was consumed, AWS_OP_ERR when the message is not non-empty
 *         application data, when buffering/decoding fails, or when a packet handler reports an error.
 */
static int s_process_read_message(
    struct aws_channel_handler *handler,
    struct aws_channel_slot *slot,
    struct aws_io_message *message) {

    struct aws_mqtt_client_connection *connection = handler->impl;

    if (message->message_type != AWS_IO_MESSAGE_APPLICATION_DATA || message->message_data.len < 1) {
        return AWS_OP_ERR;
    }

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: precessing read message of size %zu",
        (void *)connection,
        message->message_data.len);

    /* This cursor will be updated as we read through the message. */
    struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data);

    /* If there's pending packet left over from last time, attempt to complete it. */
    if (connection->thread_data.pending_packet.len) {
        int result = AWS_OP_SUCCESS;

        /* This determines how much to read from the message (min(expected_remaining, message.len)) */
        size_t to_read = connection->thread_data.pending_packet.capacity - connection->thread_data.pending_packet.len;
        /* This will be set to false if this message still won't complete the packet object. */
        bool packet_complete = true;
        if (to_read > message_cursor.len) {
            to_read = message_cursor.len;
            packet_complete = false;
        }

        /* Write the chunk to the buffer.
         * This will either complete the packet, or be the entirety of message if more data is required. */
        struct aws_byte_cursor chunk = aws_byte_cursor_advance(&message_cursor, to_read);
        AWS_ASSERT(chunk.ptr); /* Guaranteed to be in bounds */
        if (!aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, chunk)) {
            result = AWS_OP_ERR;
            goto handle_error;
        }

        /* If the packet is still incomplete, don't do anything with the data. */
        if (!packet_complete) {
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT,
                "id=%p: partial message is still incomplete, waiting on another read.",
                (void *)connection);

            goto cleanup;
        }

        /* Handle the completed pending packet */
        struct aws_byte_cursor packet_data = aws_byte_cursor_from_buf(&connection->thread_data.pending_packet);
        AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet re-assembled, dispatching.", (void *)connection);
        result = s_process_mqtt_packet(connection, aws_mqtt_get_packet_type(packet_data.ptr), packet_data);

    handle_error:
        /* Clean up the pending packet (reached on success as well as on failure) */
        aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
        AWS_ZERO_STRUCT(connection->thread_data.pending_packet);

        if (result) {
            return AWS_OP_ERR;
        }
    }

    while (message_cursor.len) {

        /* Temp byte cursor so we can decode the header without advancing message_cursor. */
        struct aws_byte_cursor header_decode = message_cursor;

        struct aws_mqtt_fixed_header packet_header;
        AWS_ZERO_STRUCT(packet_header);
        int result = aws_mqtt_fixed_header_decode(&header_decode, &packet_header);

        /* Calculate how much data was read. */
        const size_t fixed_header_size = message_cursor.len - header_decode.len;
        /* Total size of the packet: fixed header plus the remaining length it announces. */
        const size_t packet_size = fixed_header_size + packet_header.remaining_length;

        if (result) {
            if (aws_last_error() == AWS_ERROR_SHORT_BUFFER) {
                /* Message data too short, store data and come back later. */
                AWS_LOGF_TRACE(
                    AWS_LS_MQTT_CLIENT, "id=%p: message is incomplete, waiting on another read.", (void *)connection);
                /* NOTE(review): packet_size comes from whatever the decoder filled in before failing; if the
                 * header itself was truncated this may under-size the buffer and make the write below fail -
                 * confirm against aws_mqtt_fixed_header_decode. */
                if (aws_byte_buf_init(&connection->thread_data.pending_packet, connection->allocator, packet_size)) {

                    return AWS_OP_ERR;
                }

                /* Write the partial packet. */
                if (!aws_byte_buf_write_from_whole_cursor(&connection->thread_data.pending_packet, message_cursor)) {
                    aws_byte_buf_clean_up(&connection->thread_data.pending_packet);
                    return AWS_OP_ERR;
                }

                aws_reset_error();
                goto cleanup;
            } else {
                return AWS_OP_ERR;
            }
        }

        struct aws_byte_cursor packet_data = aws_byte_cursor_advance(&message_cursor, packet_size);
        if (packet_data.len != packet_size) {
            /* A failed advance leaves message_cursor where it was; without this check the loop would spin
             * forever on the same bytes. */
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: decoded packet length exceeds the data available in the message.",
                (void *)connection);
            return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
        }

        AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: full mqtt packet read, dispatching.", (void *)connection);
        /* Propagate handler failures, exactly like the re-assembled packet path above does. */
        if (s_process_mqtt_packet(connection, packet_header.packet_type, packet_data)) {
            return AWS_OP_ERR;
        }
    }

cleanup:
    /* Do cleanup */
    aws_channel_slot_increment_read_window(slot, message->message_data.len);
    aws_mem_release(message->allocator, message);

    return AWS_OP_SUCCESS;
}

/*
 * Channel shutdown callback. When the write direction is shut down cleanly (error_code is AWS_OP_SUCCESS and
 * resources do not have to be freed immediately) a DISCONNECT packet is sent first, on a best-effort basis:
 * failures are only logged. In every case shutdown completion is reported to the slot and its result returned.
 */
static int s_shutdown(
    struct aws_channel_handler *handler,
    struct aws_channel_slot *slot,
    enum aws_channel_direction dir,
    int error_code,
    bool free_scarce_resources_immediately) {

    struct aws_mqtt_client_connection *connection = handler->impl;

    const bool send_disconnect =
        dir == AWS_CHANNEL_DIR_WRITE && !free_scarce_resources_immediately && error_code == AWS_OP_SUCCESS;

    if (send_disconnect) {
        AWS_LOGF_INFO(
            AWS_LS_MQTT_CLIENT, "id=%p: sending disconnect message as part of graceful shutdown.", (void *)connection);

        /* On clean shutdown, send the disconnect message */
        struct aws_mqtt_packet_connection disconnect;
        aws_mqtt_packet_disconnect_init(&disconnect);

        struct aws_io_message *message = mqtt_get_message_for_packet(connection, &disconnect.fixed_header);
        if (message != NULL) {
            if (aws_mqtt_packet_connection_encode(&message->message_data, &disconnect)) {
                AWS_LOGF_DEBUG(
                    AWS_LS_MQTT_CLIENT, "id=%p: failed to encode courteous disconnect io message", (void *)connection);
                aws_mem_release(message->allocator, message);
            } else if (aws_channel_slot_send_message(slot, message, AWS_CHANNEL_DIR_WRITE)) {
                AWS_LOGF_DEBUG(
                    AWS_LS_MQTT_CLIENT, "id=%p: failed to send courteous disconnect io message", (void *)connection);
                aws_mem_release(message->allocator, message);
            }
        }
    }

    return aws_channel_slot_on_handler_shutdown_complete(slot, dir, error_code, free_scarce_resources_immediately);
}

/* Reports the handler's initial read window: the largest size_t value, regardless of the handler passed in. */
static size_t s_initial_window_size(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t initial_window = SIZE_MAX;
    return initial_window;
}

/* Handler destroy callback. Intentionally does nothing: the connection stored in handler->impl is only
 * fetched, never modified or released here. */
static void s_destroy(struct aws_channel_handler *handler) {
    struct aws_mqtt_client_connection *owner = handler->impl;

    (void)owner;
}

/* Reports the per-message overhead of this handler, which is always zero. */
static size_t s_message_overhead(struct aws_channel_handler *handler) {
    (void)handler;

    const size_t overhead = 0;
    return overhead;
}

struct aws_channel_handler_vtable *aws_mqtt_get_client_channel_vtable(void) {

    static struct aws_channel_handler_vtable s_vtable = {
        .process_read_message = &s_process_read_message,
        .process_write_message = NULL,
        .increment_read_window = NULL,
        .shutdown = &s_shutdown,
        .initial_window_size = &s_initial_window_size,
        .message_overhead = &s_message_overhead,
        .destroy = &s_destroy,
    };

    return &s_vtable;
}

/*******************************************************************************
 * Helpers
 ******************************************************************************/

/*
 * Acquires an application-data io message from the connection's channel pool for the packet described by
 * header. The requested size is the packet's remaining length plus 3 bytes.
 * NOTE(review): the 3 extra bytes presumably cover the fixed header - confirm it is large enough for every
 * remaining-length encoding.
 *
 * Returns whatever the pool returned; callers in this file treat NULL as failure.
 */
struct aws_io_message *mqtt_get_message_for_packet(
    struct aws_mqtt_client_connection *connection,
    struct aws_mqtt_fixed_header *header) {

    const size_t required_length = header->remaining_length + 3;

    struct aws_io_message *pooled_message = aws_channel_acquire_message_from_pool(
        connection->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, required_length);

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: Acquiring memory from pool of required_length %zu",
        (void *)connection,
        required_length);

    return pooled_message;
}

/*******************************************************************************
 * Requests
 ******************************************************************************/

/*
 * Channel task that sends a request (arg is the struct aws_mqtt_request).
 *
 * If the task was canceled:
 *  - a retryable request is appended to synced_data.pending_requests_list;
 *  - any other request has its on_complete callback (if set) invoked with AWS_ERROR_MQTT_NOT_CONNECTED, is
 *    removed from the outstanding requests table and is returned to the requests pool.
 *
 * Otherwise send_request is invoked and its result decides what happens next:
 *  - ERROR:    the last error is captured and handling continues as for COMPLETE;
 *  - COMPLETE: on_complete (if set) is invoked with the error code, then the request is removed from the
 *              outstanding requests table and returned to the pool;
 *  - ONGOING:  the request is appended to thread_data.ongoing_requests_list.
 */
static void s_request_outgoing_task(struct aws_channel_task *task, void *arg, enum aws_task_status status) {

    struct aws_mqtt_request *request = arg;
    struct aws_mqtt_client_connection *connection = request->connection;

    if (status == AWS_TASK_STATUS_CANCELED) {
        /* Connection lost before the request was ever sent; check whether the request needs to be retried */
        if (request->retryable) {
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
                "%" PRIu16 ". will be retried",
                (void *)task,
                request->packet_id);
            /* put it into the offline queue. */
            { /* BEGIN CRITICAL SECTION */
                mqtt_connection_lock_synced_data(connection);
                aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &request->list_node);
                mqtt_connection_unlock_synced_data(connection);
            } /* END CRITICAL SECTION */
        } else {
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "static: task id %p, was canceled due to the channel shutting down. Request for packet id "
                "%" PRIu16 ". will NOT be retried, will be cancelled",
                (void *)task,
                request->packet_id);
            /* Fire the callback (outside the lock) and clean up the memory, as the request will not be retried. */
            if (request->on_complete) {
                request->on_complete(
                    connection, request->packet_id, AWS_ERROR_MQTT_NOT_CONNECTED, request->on_complete_ud);
            }
            { /* BEGIN CRITICAL SECTION */
                mqtt_connection_lock_synced_data(connection);
                aws_hash_table_remove(
                    &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
                /* request must not be touched after it has been handed back to the pool */
                aws_memory_pool_release(&connection->synced_data.requests_pool, request);
                mqtt_connection_unlock_synced_data(connection);
            } /* END CRITICAL SECTION */
        }
        return;
    }

    /* Send the request; the second argument is true only for the first attempt of this request */
    enum aws_mqtt_client_request_state state =
        request->send_request(request->packet_id, !request->initiated, request->send_request_ud);
    request->initiated = true;
    int error_code = AWS_ERROR_SUCCESS;
    switch (state) {
        case AWS_MQTT_CLIENT_REQUEST_ERROR:
            /* Capture the error right away, then share the completion path below */
            error_code = aws_last_error();
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: sending request %" PRIu16 " failed with error %d.",
                (void *)request->connection,
                request->packet_id,
                error_code);
            /* fall-thru */

        case AWS_MQTT_CLIENT_REQUEST_COMPLETE:
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT,
                "id=%p: sending request %" PRIu16 " complete, invoking on_complete callback.",
                (void *)request->connection,
                request->packet_id);
            /* If the send_request function reports the request is complete,
             * remove from the hash table and call the callback. */
            if (request->on_complete) {
                request->on_complete(connection, request->packet_id, error_code, request->on_complete_ud);
            }
            { /* BEGIN CRITICAL SECTION */
                mqtt_connection_lock_synced_data(connection);
                aws_hash_table_remove(
                    &connection->synced_data.outstanding_requests_table, &request->packet_id, NULL, NULL);
                /* request must not be touched after it has been handed back to the pool */
                aws_memory_pool_release(&connection->synced_data.requests_pool, request);
                mqtt_connection_unlock_synced_data(connection);
            } /* END CRITICAL SECTION */
            break;

        case AWS_MQTT_CLIENT_REQUEST_ONGOING:
            AWS_LOGF_TRACE(
                AWS_LS_MQTT_CLIENT,
                "id=%p: request %" PRIu16 " sent, but waiting on an acknowledgement from peer.",
                (void *)request->connection,
                request->packet_id);
            /* Put the request into the ongoing list (part of thread_data; no lock is taken here) */
            aws_linked_list_push_back(&connection->thread_data.ongoing_requests_list, &request->list_node);
            break;
    }
}

/**
 * Creates a new outstanding request, assigns it a free packet ID and either parks it on the pending list
 * (connection not in the CONNECTED state) or schedules its outgoing task on the channel.
 *
 * @param connection      Connection the request belongs to. Must not be NULL.
 * @param send_request    Callback stored on the request as its send function. Must not be NULL.
 * @param send_request_ud User data stored alongside send_request.
 * @param on_complete     Optional callback stored on the request as its completion function.
 * @param on_complete_ud  User data stored alongside on_complete.
 * @param noRetry         When true the request is marked non-retryable, and the call fails with
 *                        AWS_ERROR_MQTT_NOT_CONNECTED unless the connection is currently CONNECTED.
 *
 * @return The packet ID assigned to the request, or 0 on failure. Failures raised here are
 *         AWS_ERROR_MQTT_CONNECTION_DISCONNECTING, AWS_ERROR_MQTT_NOT_CONNECTED and AWS_ERROR_MQTT_QUEUE_FULL;
 *         0 is also returned when the request cannot be acquired from the pool or stored in the table.
 */
uint16_t mqtt_create_request(
    struct aws_mqtt_client_connection *connection,
    aws_mqtt_send_request_fn *send_request,
    void *send_request_ud,
    aws_mqtt_op_complete_fn *on_complete,
    void *on_complete_ud,
    bool noRetry) {

    AWS_ASSERT(connection);
    AWS_ASSERT(send_request);
    struct aws_mqtt_request *next_request = NULL;
    bool should_schedule_task = false;
    struct aws_channel *channel = NULL;
    /* Local copy of the assigned ID, so nothing below has to read it back out of the request after the lock is
     * released (the request may be completed and handed back to the pool by another thread by then). */
    uint16_t packet_id = 0;
    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING) {
            mqtt_connection_unlock_synced_data(connection);
            /* User requested disconnecting, ensure no new requests are made until the channel finished shutting
             * down. */
            AWS_LOGF_ERROR(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Disconnect requested, stop creating any new request until disconnect process finishes.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_CONNECTION_DISCONNECTING);
            return 0;
        }

        if (noRetry && connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            mqtt_connection_unlock_synced_data(connection);
            /* Not offline queueing QoS 0 publish or PINGREQ. Fail the call. */
            AWS_LOGF_DEBUG(
                AWS_LS_MQTT_CLIENT,
                "id=%p: Not currently connected. No offline queueing for QoS 0 publish or pingreq.",
                (void *)connection);
            aws_raise_error(AWS_ERROR_MQTT_NOT_CONNECTED);
            return 0;
        }
        /**
         * Find a free packet ID.
         * QoS 0 PUBLISH packets don't actually need an ID on the wire,
         * but we assign them internally anyway just so everything has a unique ID.
         *
         * Yes, this is an O(N) search.
         * We remember the last ID we assigned, so it's O(1) in the common case.
         * But it's theoretically possible to reach O(N) where N is just above 64000
         * if the user is letting a ton of un-ack'd messages queue up
         */
        uint16_t search_start = connection->synced_data.packet_id;
        struct aws_hash_element *elem = NULL;
        while (true) {
            /* Increment ID, watch out for overflow, ID cannot be 0 */
            if (connection->synced_data.packet_id == UINT16_MAX) {
                connection->synced_data.packet_id = 1;
            } else {
                connection->synced_data.packet_id++;
            }

            /* Is there already an outstanding request using this ID? */
            aws_hash_table_find(
                &connection->synced_data.outstanding_requests_table, &connection->synced_data.packet_id, &elem);

            if (elem == NULL) {
                /* Found a free ID! Break out of loop */
                break;
            } else if (connection->synced_data.packet_id == search_start) {
                /* Every ID is taken */
                mqtt_connection_unlock_synced_data(connection);
                AWS_LOGF_ERROR(
                    AWS_LS_MQTT_CLIENT,
                    "id=%p: Queue is full. No more packet IDs are available at this time.",
                    (void *)connection);
                aws_raise_error(AWS_ERROR_MQTT_QUEUE_FULL);
                return 0;
            }
        }

        next_request = aws_memory_pool_acquire(&connection->synced_data.requests_pool);
        if (!next_request) {
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        memset(next_request, 0, sizeof(struct aws_mqtt_request));

        packet_id = connection->synced_data.packet_id;
        next_request->packet_id = packet_id;

        /* Store the request by packet_id. The key must point at storage inside the request itself. */
        if (aws_hash_table_put(
                &connection->synced_data.outstanding_requests_table, &next_request->packet_id, next_request, NULL)) {
            /* failed to put the next request into the table */
            aws_memory_pool_release(&connection->synced_data.requests_pool, next_request);
            mqtt_connection_unlock_synced_data(connection);
            return 0;
        }
        next_request->allocator = connection->allocator;
        next_request->connection = connection;
        next_request->initiated = false;
        next_request->retryable = !noRetry;
        next_request->send_request = send_request;
        next_request->send_request_ud = send_request_ud;
        next_request->on_complete = on_complete;
        next_request->on_complete_ud = on_complete_ud;
        aws_channel_task_init(
            &next_request->outgoing_task, s_request_outgoing_task, next_request, "mqtt_outgoing_request_task");
        if (connection->synced_data.state != AWS_MQTT_CLIENT_STATE_CONNECTED) {
            aws_linked_list_push_back(&connection->synced_data.pending_requests_list, &next_request->list_node);
        } else {
            AWS_ASSERT(connection->slot);
            AWS_ASSERT(connection->slot->channel);
            should_schedule_task = true;
            channel = connection->slot->channel;
            /* keep the channel alive until the task is scheduled */
            aws_channel_acquire_hold(channel);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */
    if (should_schedule_task) {
        AWS_LOGF_TRACE(
            AWS_LS_MQTT_CLIENT,
            "id=%p: Currently not in the event-loop thread, scheduling a task to send message id %" PRIu16 ".",
            (void *)connection,
            packet_id);
        aws_channel_schedule_task_now(channel, &next_request->outgoing_task);
        /* release the refcount we hold with the protection of lock */
        aws_channel_release_hold(channel);
    }

    /* Do not dereference next_request here: it may already have been completed and released. */
    return packet_id;
}

void mqtt_request_complete(struct aws_mqtt_client_connection *connection, int error_code, uint16_t packet_id) {

    AWS_LOGF_TRACE(
        AWS_LS_MQTT_CLIENT,
        "id=%p: message id %" PRIu16 " completed with error code %d, removing from outstanding requests list.",
        (void *)connection,
        packet_id,
        error_code);

    bool found_request = false;
    aws_mqtt_op_complete_fn *on_complete = NULL;
    void *on_complete_ud = NULL;

    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);
        struct aws_hash_element *elem = NULL;
        aws_hash_table_find(&connection->synced_data.outstanding_requests_table, &packet_id, &elem);
        if (elem != NULL) {
            found_request = true;

            struct aws_mqtt_request *request = elem->value;
            on_complete = request->on_complete;
            on_complete_ud = request->on_complete_ud;

            /* clean up request resources */
            aws_hash_table_remove_element(&connection->synced_data.outstanding_requests_table, elem);
            /* remove the request from the list, which is thread_data.ongoing_requests_list */
            aws_linked_list_remove(&request->list_node);
            aws_memory_pool_release(&connection->synced_data.requests_pool, request);
        }
        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (!found_request) {
        AWS_LOGF_DEBUG(
            AWS_LS_MQTT_CLIENT,
            "id=%p: received completion for message id %" PRIu16
            " but no outstanding request exists.  Assuming this is an ack of a resend when the first request has "
            "already completed.",
            (void *)connection,
            packet_id);
        return;
    }

    /* Invoke the complete callback. */
    if (on_complete) {
        on_complete(connection, packet_id, error_code, on_complete_ud);
    }
}

/* Heap-allocated payload for the "mqtt_disconnect" channel task: created in mqtt_disconnect_impl() and released
 * by s_mqtt_disconnect_task() once it has run. */
struct mqtt_shutdown_task {
    /* Error code handed to aws_channel_shutdown() when the task runs. */
    int error_code;
    /* The scheduled channel task; the enclosing struct is recovered from it with AWS_CONTAINER_OF. */
    struct aws_channel_task task;
};

/**
 * Channel task that performs the disconnect requested through mqtt_disconnect_impl().
 *
 * While holding the synced-data lock it detaches any reconnect task from a DISCONNECTING connection (freeing it
 * when its timestamp shows it was never scheduled), then shuts the channel down with the stored error code and
 * releases the shutdown-task allocation.
 */
static void s_mqtt_disconnect_task(struct aws_channel_task *channel_task, void *arg, enum aws_task_status status) {

    (void)status;

    struct aws_mqtt_client_connection *connection = arg;
    struct mqtt_shutdown_task *shutdown = AWS_CONTAINER_OF(channel_task, struct mqtt_shutdown_task, task);

    AWS_LOGF_TRACE(AWS_LS_MQTT_CLIENT, "id=%p: Doing disconnect", (void *)connection);

    { /* BEGIN CRITICAL SECTION */
        mqtt_connection_lock_synced_data(connection);

        /* Only a user-requested disconnect cancels an outstanding reconnect task. */
        if (connection->synced_data.state == AWS_MQTT_CLIENT_STATE_DISCONNECTING &&
            connection->reconnect_task != NULL) {

            /* Sever the back-pointer to this connection held by the reconnect task. */
            aws_atomic_store_ptr(&connection->reconnect_task->connection_ptr, NULL);

            /* A zero timestamp means the task isn't scheduled, so it is freed here. */
            if (connection->reconnect_task->task.timestamp == 0) {
                aws_mem_release(connection->reconnect_task->allocator, connection->reconnect_task);
            }
            connection->reconnect_task = NULL;
        }

        mqtt_connection_unlock_synced_data(connection);
    } /* END CRITICAL SECTION */

    if (connection->slot != NULL && connection->slot->channel != NULL) {
        aws_channel_shutdown(connection->slot->channel, shutdown->error_code);
    }

    aws_mem_release(connection->allocator, shutdown);
}

void mqtt_disconnect_impl(struct aws_mqtt_client_connection *connection, int error_code) {
    if (connection->slot) {
        struct mqtt_shutdown_task *shutdown_task =
            aws_mem_calloc(connection->allocator, 1, sizeof(struct mqtt_shutdown_task));
        shutdown_task->error_code = error_code;
        aws_channel_task_init(&shutdown_task->task, s_mqtt_disconnect_task, connection, "mqtt_disconnect");
        aws_channel_schedule_task_now(connection->slot->channel, &shutdown_task->task);
    }
}
