#define MS_CLASS "RTC::RTP::Codecs::VP8"
// #define MS_LOG_DEV_LEVEL 3

#include "RTC/RTP/Codecs/VP8.hpp"
#include "Logger.hpp"
#include <cstring> // std::memcpy()

namespace RTC
{
	namespace RTP
	{
		namespace Codecs
		{
			/* Class methods. */

			// Parses the VP8 payload descriptor located at the start of `data`.
			//
			// - data: first byte of the RTP payload.
			// - len: number of readable bytes in `data`.
			//
			// Returns a heap allocated PayloadDescriptor that the caller must take care
			// of, or nullptr if the payload is empty or ends before a field announced by
			// its flags.
			VP8::PayloadDescriptor* VP8::Parse(const uint8_t* data, size_t len)
			{
				MS_TRACE();

				if (len == 0)
				{
					MS_WARN_DEV("ignoring empty payload");

					return nullptr;
				}

				std::unique_ptr<PayloadDescriptor> descriptor(new PayloadDescriptor());

				// Position of the octet being decoded.
				size_t pos{ 0 };
				uint8_t octet = data[pos];

				// Mandatory first octet.
				descriptor->extended       = (octet >> 7) & 0x01;
				descriptor->nonReference   = (octet >> 5) & 0x01;
				descriptor->start          = (octet >> 4) & 0x01;
				descriptor->partitionIndex = octet & 0x07;

				// Optional extension octet carrying the I, L, T and K flags.
				if (descriptor->extended)
				{
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (2)");

						return nullptr;
					}

					octet = data[pos];

					descriptor->i = (octet >> 7) & 0x01;
					descriptor->l = (octet >> 6) & 0x01;
					descriptor->t = (octet >> 5) & 0x01;
					descriptor->k = (octet >> 4) & 0x01;
				}

				// PictureID: 7 bits in one octet, or 15 bits in two octets when the most
				// significant bit of the first octet is set.
				if (descriptor->i)
				{
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (3)");

						return nullptr;
					}

					octet = data[pos];

					const bool isTwoBytes = ((octet >> 7) & 0x01) != 0;

					if (!isTwoBytes)
					{
						descriptor->hasOneBytePictureId = true;
						descriptor->pictureId           = octet & 0x7F;
					}
					else
					{
						++pos;

						if (pos >= len)
						{
							MS_WARN_DEV("ignoring invalid payload (4)");

							return nullptr;
						}

						// High 7 bits come from the first octet, low 8 bits from the second.
						descriptor->hasTwoBytesPictureId = true;
						descriptor->pictureId            = ((octet & 0x7F) << 8) | data[pos];
					}

					descriptor->hasPictureId = true;
				}

				// TL0PICIDX octet.
				if (descriptor->l)
				{
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (5)");

						return nullptr;
					}

					descriptor->hasTl0PictureIndex = true;
					descriptor->tl0PictureIndex    = data[pos];
				}

				// A single octet shared by the temporal layer index, the Y bit and the
				// key index; present if either T or K is set.
				if (descriptor->t || descriptor->k)
				{
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (6)");

						return nullptr;
					}

					octet = data[pos];

					descriptor->hasTlIndex = true;
					descriptor->tlIndex    = (octet >> 6) & 0x03;
					descriptor->y          = (octet >> 5) & 0x01;
					descriptor->keyIndex   = octet & 0x1F;
				}

				// Move to the first octet after the descriptor. The lowest bit of that
				// octet is the inverse keyframe bit, and it is only inspected on the
				// packet that starts partition 0.
				++pos;

				if (
				  pos < len && descriptor->start && descriptor->partitionIndex == 0 &&
				  (data[pos] & 0x01) == 0)
				{
					descriptor->isKeyFrame = true;
				}

				return descriptor.release();
			}

			// Parses the VP8 payload descriptor of the given packet and, if it could be
			// parsed, attaches a PayloadDescriptorHandler for it to the packet. A packet
			// carrying a one byte pictureId has its payload modified so that the
			// pictureId always takes two bytes.
			void VP8::ProcessRtpPacket(RTP::Packet* packet)
			{
				MS_TRACE();

				auto* payload         = packet->GetPayload();
				const auto payloadLen = packet->GetPayloadLength();

				PayloadDescriptor* descriptor = VP8::Parse(payload, payloadLen);

				// Empty or malformed payload: leave the packet untouched.
				if (descriptor == nullptr)
				{
					return;
				}

				// The handler stores the descriptor in its payloadDescriptor member.
				packet->SetPayloadDescriptorHandler(new PayloadDescriptorHandler(descriptor));

				// Nothing else to do unless the pictureId is one byte long.
				if (!descriptor->hasOneBytePictureId)
				{
					return;
				}

				// Shift the RTP payload one byte from the beginning of the pictureId field.
				packet->ShiftPayload(2, 1);

				// Set the two byte pictureId marker bit.
				payload[2] = 0x80;

				// Keep the descriptor flags in line with the modified payload.
				descriptor->hasOneBytePictureId  = false;
				descriptor->hasTwoBytesPictureId = true;
			}

			/* Instance methods. */

			// Dumps every field of the descriptor, one per line, through MS_DUMP_CLEAN
			// using the given indentation.
			void VP8::PayloadDescriptor::Dump(int indentation) const
			{
				MS_TRACE();

				// Textual form used for the boolean fields.
				const auto boolText = [](bool flag) { return flag ? "true" : "false"; };

				MS_DUMP_CLEAN(indentation, "<VP8::PayloadDescriptor>");

				// Flags of the extension octet.
				MS_DUMP_CLEAN(
				  indentation,
				  "  i:%" PRIu8 "|l:%" PRIu8 "|t:%" PRIu8 "|k:%" PRIu8,
				  this->i,
				  this->l,
				  this->t,
				  this->k);

				// Numeric fields.
				MS_DUMP_CLEAN(indentation, "  extended: %" PRIu8, this->extended);
				MS_DUMP_CLEAN(indentation, "  nonReference: %" PRIu8, this->nonReference);
				MS_DUMP_CLEAN(indentation, "  start: %" PRIu8, this->start);
				MS_DUMP_CLEAN(indentation, "  partitionIndex: %" PRIu8, this->partitionIndex);
				MS_DUMP_CLEAN(indentation, "  pictureId: %" PRIu16, this->pictureId);
				MS_DUMP_CLEAN(indentation, "  tl0PictureIndex: %" PRIu8, this->tl0PictureIndex);
				MS_DUMP_CLEAN(indentation, "  tlIndex: %" PRIu8, this->tlIndex);
				MS_DUMP_CLEAN(indentation, "  y: %" PRIu8, this->y);
				MS_DUMP_CLEAN(indentation, "  keyIndex: %" PRIu8, this->keyIndex);

				// Boolean fields.
				MS_DUMP_CLEAN(indentation, "  isKeyFrame: %s", boolText(this->isKeyFrame));
				MS_DUMP_CLEAN(indentation, "  hasPictureId: %s", boolText(this->hasPictureId));
				MS_DUMP_CLEAN(indentation, "  hasOneBytePictureId: %s", boolText(this->hasOneBytePictureId));
				MS_DUMP_CLEAN(
				  indentation, "  hasTwoBytesPictureId: %s", boolText(this->hasTwoBytesPictureId));
				MS_DUMP_CLEAN(indentation, "  hasTl0PictureIndex: %s", boolText(this->hasTl0PictureIndex));
				MS_DUMP_CLEAN(indentation, "  hasTlIndex: %s", boolText(this->hasTlIndex));

				MS_DUMP_CLEAN(indentation, "</VP8::PayloadDescriptor>");
			}

			// Writes the given pictureId and tl0PictureIndex into the VP8 payload
			// descriptor stored at `data`, using the layout recorded in this descriptor.
			//
			// - data: first byte of the RTP payload. Must not be nullptr.
			// - pictureId: value for the PictureID field. Written only if the I flag is
			//   set, using as many bytes as this descriptor says the field has.
			// - tl0PictureIndex: value for the TL0PICIDX field. Written only if the L
			//   flag is set.
			//
			// Does nothing if the descriptor is not extended.
			void VP8::PayloadDescriptor::Encode(uint8_t* data, uint16_t pictureId, uint8_t tl0PictureIndex) const
			{
				MS_TRACE();

				if (data == nullptr)
				{
					MS_ABORT("data is nullptr");
				}

				// Nothing to do.
				if (!this->extended)
				{
					return;
				}

				// Skip the mandatory octet and the extension octet.
				data += 2;

				if (this->i)
				{
					if (this->hasTwoBytesPictureId)
					{
						uint16_t netPictureId = htons(pictureId);

						std::memcpy(data, &netPictureId, 2);
						// Set the bit that announces a two byte pictureId.
						data[0] |= 0x80;
						data += 2;
					}
					else if (this->hasOneBytePictureId)
					{
						if (pictureId > 127)
						{
							MS_DEBUG_TAG(rtp, "casting pictureId value to one byte");
						}

						// Keep only the 7 bits of the field. The most significant bit of
						// this octet announces a two byte pictureId, so it must remain
						// cleared here; otherwise a value above 127 would make the
						// descriptor claim one more byte than it really has.
						*data = static_cast<uint8_t>(pictureId & 0x7F);
						data++;
					}
				}

				if (this->l)
				{
					*data = tl0PictureIndex;
				}
			}

			// Encodes into `data` using the encoder stored in this descriptor. Does
			// nothing if no encoder has been stored.
			void VP8::PayloadDescriptor::Encode(uint8_t* data) const
			{
				MS_TRACE();

				if (this->encoder.has_value())
				{
					this->encoder->Encode(data, this);
				}
			}

			// Writes this descriptor's own pictureId and tl0PictureIndex back into
			// `data`. Does nothing unless the descriptor has both fields.
			void VP8::PayloadDescriptor::Restore(uint8_t* data) const
			{
				MS_TRACE();

				if (!this->hasPictureId || !this->hasTl0PictureIndex)
				{
					return;
				}

				this->Encode(data, this->pictureId, this->tl0PictureIndex);
			}

			// Encodes into `data`, through the given descriptor, the pictureId and
			// tl0PictureIndex kept in this encoder's encodingData.
			void VP8::PayloadDescriptor::Encoder::Encode(
			  uint8_t* data, const PayloadDescriptor* payloadDescriptor) const
			{
				const auto& stored = this->encodingData;

				payloadDescriptor->Encode(data, stored.pictureId, stored.tl0PictureIndex);
			}

			// Builds a handler around the given descriptor, which is stored in the
			// payloadDescriptor member.
			VP8::PayloadDescriptorHandler::PayloadDescriptorHandler(VP8::PayloadDescriptor* payloadDescriptor)
			  : payloadDescriptor(payloadDescriptor)
			{
				MS_TRACE();
			}

			// Decides whether the packet described by this handler is sent for the given
			// encoding context and, if so, rewrites its pictureId and tl0PictureIndex.
			//
			// - encodingContext: statically cast to Codecs::VP8::EncodingContext, so the
			//   caller must pass a VP8 context.
			// - packet: its payload is rewritten in place when the descriptor carries
			//   both pictureId and tl0PictureIndex.
			// - third parameter (marker, unnamed): not used here.
			//
			// Returns false when the packet must not be sent, true otherwise.
			bool VP8::PayloadDescriptorHandler::Process(
			  Codecs::EncodingContext* encodingContext, RTP::Packet* packet, bool& /*marker*/)
			{
				MS_TRACE();

				auto* context = static_cast<Codecs::VP8::EncodingContext*>(encodingContext);

				MS_ASSERT(context->GetTargetTemporalLayer() >= 0, "target temporal layer cannot be -1");

				// Check if the payload should contain temporal layer info.
				// This only logs a warning; processing continues either way.
				if (context->GetTemporalLayers() > 1 && !this->payloadDescriptor->hasTlIndex)
				{
					MS_WARN_DEV("stream is supposed to have >1 temporal layers but does not have TlIndex field");
				}

				// Check whether pictureId and tl0PictureIndex sync is required.
				// The sync stays pending until a packet carrying both fields arrives.
				if (
				  context->syncRequired && this->payloadDescriptor->hasPictureId &&
				  this->payloadDescriptor->hasTl0PictureIndex)
				{
					// NOTE(review): both managers are synced with the incoming value minus
					// one, presumably so that this packet's own values continue the output
					// sequence — confirm against the managers' Sync() implementation.
					context->pictureIdManager.Sync(this->payloadDescriptor->pictureId - 1);
					context->tl0PictureIndexManager.Sync(this->payloadDescriptor->tl0PictureIndex - 1);

					context->syncRequired = false;
				}

				// Incremental pictureId. Check the temporal layer.
				// The check is skipped when the pictureId is lower than the highest one
				// given to the pictureId manager so far, or when any of the three fields
				// is missing.
				if (
				  this->payloadDescriptor->hasPictureId && this->payloadDescriptor->hasTlIndex &&
				  this->payloadDescriptor->hasTl0PictureIndex &&
				  !RTC::SeqManager<uint16_t, 15>::IsSeqLowerThan(
				    this->payloadDescriptor->pictureId, context->pictureIdManager.GetMaxInput()))
				{
					// Drop if:
					// - Temporal layer is higher than target.
					// - Temporal layer is higher than current and sync flag is not set.
					if (
					  this->payloadDescriptor->tlIndex > context->GetTargetTemporalLayer() ||
					  (this->payloadDescriptor->tlIndex > context->GetCurrentTemporalLayer() &&
					   !this->payloadDescriptor->y))
					{
						// Tell the manager that this pictureId is dropped.
						context->pictureIdManager.Drop(this->payloadDescriptor->pictureId);

						// The tl0PictureIndex is only reported as dropped for temporal layer 0.
						if (this->payloadDescriptor->tlIndex == 0)
						{
							context->tl0PictureIndexManager.Drop(this->payloadDescriptor->tl0PictureIndex);
						}

						return false;
					}
				}

				// Update pictureId and tl0PictureIndex values.
				// These are only passed to the Input() calls below (presumably as output
				// arguments — confirm) and only read at the end of this method, when the
				// descriptor has both fields and therefore both Input() calls have run.
				uint16_t pictureId;
				uint8_t tl0PictureIndex;

				// Do not send a dropped pictureId.
				if (
				  this->payloadDescriptor->hasPictureId &&
				  !context->pictureIdManager.Input(this->payloadDescriptor->pictureId, pictureId))
				{
					return false;
				}

				// Do not send a dropped tl0PictureIndex.
				if (
				  this->payloadDescriptor->hasTl0PictureIndex &&
				  !context->tl0PictureIndexManager.Input(
				    this->payloadDescriptor->tl0PictureIndex, tl0PictureIndex))
				{
					return false;
				}

				// Update/fix current temporal layer.
				// The current layer becomes the target one once a packet of the target
				// layer is seen, and is reset to 0 when the packet has no TlIndex field.
				if (this->payloadDescriptor->hasTlIndex && this->payloadDescriptor->tlIndex == context->GetTargetTemporalLayer())
				{
					context->SetCurrentTemporalLayer(this->payloadDescriptor->tlIndex);
				}
				else if (!this->payloadDescriptor->hasTlIndex)
				{
					context->SetCurrentTemporalLayer(0);
				}

				// Never leave the current temporal layer above the target one.
				if (context->GetCurrentTemporalLayer() > context->GetTargetTemporalLayer())
				{
					context->SetCurrentTemporalLayer(context->GetTargetTemporalLayer());
				}

				// Do not send tlIndex higher than current one.
				if (this->payloadDescriptor->tlIndex > context->GetCurrentTemporalLayer())
				{
					return false;
				}

				if (this->payloadDescriptor->hasPictureId && this->payloadDescriptor->hasTl0PictureIndex)
				{
					// Store the encoding data for retransmissions.
					this->payloadDescriptor->CreateEncoder({ pictureId, tl0PictureIndex });
					// Rewrite the packet payload with the values just stored.
					this->payloadDescriptor->Encode(packet->GetPayload());
				}

				return true;
			};

			// Encodes the payload of the given packet with the given encoder, which is
			// statically cast to VP8::PayloadDescriptor::Encoder, so the caller must
			// pass a VP8 encoder.
			void VP8::PayloadDescriptorHandler::Encode(
			  RTP::Packet* packet, Codecs::PayloadDescriptor::Encoder* encoder)
			{
				MS_TRACE();

				auto* payload    = packet->GetPayload();
				auto* descriptor = this->payloadDescriptor.get();

				static_cast<VP8::PayloadDescriptor::Encoder*>(encoder)->Encode(payload, descriptor);
			}

			// Restores the descriptor's own pictureId and tl0PictureIndex in the payload
			// of the given packet. Does nothing unless the descriptor has both fields.
			void VP8::PayloadDescriptorHandler::Restore(RTP::Packet* packet)
			{
				MS_TRACE();

				const auto* descriptor = this->payloadDescriptor.get();

				if (!descriptor->hasPictureId || !descriptor->hasTl0PictureIndex)
				{
					return;
				}

				descriptor->Restore(packet->GetPayload());
			}
		} // namespace Codecs
	} // namespace RTP
} // namespace RTC
