#define MS_CLASS "RTC::RTP::Codecs::VP9"
// #define MS_LOG_DEV_LEVEL 3

#include "RTC/RTP/Codecs/VP9.hpp"
#include "Logger.hpp"

namespace RTC
{
	namespace RTP
	{
		namespace Codecs
		{
			/* Class methods. */

			/**
			 * Parses the VP9 payload descriptor found at the start of an RTP payload.
			 *
			 * Reads the mandatory first octet (I|P|L|F|B|E|V flags), then, when the
			 * corresponding flags are set, the 7- or 15-bit picture ID and the layer
			 * indices octet followed by TL0PICIDX (which is only stored in
			 * non-flexible mode, i.e. F == 0). The descriptor is flagged as a key
			 * frame when P == 0, B == 1 and the spatial layer index is 0.
			 *
			 * @param data  Pointer to the RTP payload.
			 * @param len   Length of the RTP payload in bytes.
			 * @return A heap-allocated descriptor owned by the caller, or nullptr if
			 *         the payload is empty or too short for the fields it announces.
			 */
			VP9::PayloadDescriptor* VP9::Parse(const uint8_t* data, size_t len)
			{
				MS_TRACE();

				if (len == 0)
				{
					MS_WARN_DEV("ignoring empty payload");

					return nullptr;
				}

				auto descriptor = std::make_unique<PayloadDescriptor>();

				// Mandatory first octet: |I|P|L|F|B|E|V|Z|.
				const uint8_t flags = data[0];

				descriptor->i = (flags >> 7) & 0x01;
				descriptor->p = (flags >> 6) & 0x01;
				descriptor->l = (flags >> 5) & 0x01;
				descriptor->f = (flags >> 4) & 0x01;
				descriptor->b = (flags >> 3) & 0x01;
				descriptor->e = (flags >> 2) & 0x01;
				descriptor->v = (flags >> 1) & 0x01;

				// Index of the octet currently being read.
				size_t pos{ 0 };

				if (descriptor->i)
				{
					// Picture ID octet: |M| PICTURE ID |.
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (1)");

						return nullptr;
					}

					const uint8_t pidOctet = data[pos];
					const bool isExtended  = ((pidOctet >> 7) & 0x01) != 0;

					if (!isExtended)
					{
						descriptor->pictureId           = pidOctet & 0x7F;
						descriptor->hasOneBytePictureId = true;
					}
					else
					{
						// M bit set: a second octet carries the low 8 bits of a 15-bit ID.
						++pos;

						if (pos >= len)
						{
							MS_WARN_DEV("ignoring invalid payload (2)");

							return nullptr;
						}

						descriptor->pictureId            = ((pidOctet & 0x7F) << 8) | data[pos];
						descriptor->hasTwoBytesPictureId = true;
					}

					descriptor->hasPictureId = true;
				}

				if (descriptor->l)
				{
					// Layer indices octet: |  T  |U|  S  |D|.
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (3)");

						return nullptr;
					}

					const uint8_t layersOctet = data[pos];

					descriptor->tlIndex              = (layersOctet >> 5) & 0x07;
					descriptor->switchingUpPoint     = (layersOctet >> 4) & 0x01;
					descriptor->slIndex              = (layersOctet >> 1) & 0x07;
					descriptor->interLayerDependency = layersOctet & 0x01;
					descriptor->hasTlIndex           = true;
					descriptor->hasSlIndex           = true;

					// One more octet is required after the layer indices (it holds
					// TL0PICIDX in non-flexible mode).
					++pos;

					if (pos >= len)
					{
						MS_WARN_DEV("ignoring invalid payload (4)");

						return nullptr;
					}

					if (!descriptor->f)
					{
						descriptor->tl0PictureIndex    = data[pos];
						descriptor->hasTl0PictureIndex = true;
					}
				}

				// Not inter-predicted, start of a frame, base spatial layer.
				const bool startsKeyFrame = !descriptor->p && descriptor->b && descriptor->slIndex == 0;

				if (startsKeyFrame)
				{
					descriptor->isKeyFrame = true;
				}

				return descriptor.release();
			}

			/**
			 * Parses the VP9 payload descriptor of an RTP packet and, if it is valid,
			 * attaches a new PayloadDescriptorHandler wrapping it to the packet.
			 * Packets whose descriptor cannot be parsed are left untouched.
			 *
			 * @param packet  RTP packet whose payload starts with a VP9 payload
			 *                descriptor.
			 */
			void VP9::ProcessRtpPacket(RTP::Packet* packet)
			{
				MS_TRACE();

				auto* payload      = packet->GetPayload();
				auto payloadLength = packet->GetPayloadLength();

				auto* descriptor = VP9::Parse(payload, payloadLength);

				if (descriptor != nullptr)
				{
					if (descriptor->isKeyFrame)
					{
						MS_DEBUG_DEV(
						  "key frame [spatialLayer:%" PRIu8 ", temporalLayer:%" PRIu8 "]",
						  packet->GetSpatialLayer(),
						  packet->GetTemporalLayer());
					}

					// The handler takes ownership of the descriptor; presumably the packet
					// takes ownership of the handler (see Packet::SetPayloadDescriptorHandler).
					packet->SetPayloadDescriptorHandler(new PayloadDescriptorHandler(descriptor));
				}
			}

			/* Instance methods. */

			/**
			 * Logs every parsed field and presence flag of this payload descriptor via
			 * MS_DUMP_CLEAN, enclosed in <VP9::PayloadDescriptor> tags.
			 *
			 * @param indentation  Indentation level forwarded to MS_DUMP_CLEAN.
			 */
			void VP9::PayloadDescriptor::Dump(int indentation) const
			{
				MS_TRACE();

				// Renders a boolean flag the way the dump format expects.
				const auto flagText = [](bool value) { return value ? "true" : "false"; };

				MS_DUMP_CLEAN(indentation, "<VP9::PayloadDescriptor>");

				// Raw bits of the mandatory first octet.
				MS_DUMP_CLEAN(
				  indentation,
				  "  i:%" PRIu8 "|p:%" PRIu8 "|l:%" PRIu8 "|f:%" PRIu8 "|b:%" PRIu8 "|e:%" PRIu8 "|v:%" PRIu8,
				  this->i,
				  this->p,
				  this->l,
				  this->f,
				  this->b,
				  this->e,
				  this->v);

				// Parsed values.
				MS_DUMP_CLEAN(indentation, "  pictureId: %" PRIu16, this->pictureId);
				MS_DUMP_CLEAN(indentation, "  slIndex: %" PRIu8, this->slIndex);
				MS_DUMP_CLEAN(indentation, "  tlIndex: %" PRIu8, this->tlIndex);
				MS_DUMP_CLEAN(indentation, "  tl0PictureIndex: %" PRIu8, this->tl0PictureIndex);
				MS_DUMP_CLEAN(indentation, "  interLayerDependency: %" PRIu8, this->interLayerDependency);
				MS_DUMP_CLEAN(indentation, "  switchingUpPoint: %" PRIu8, this->switchingUpPoint);

				// Derived and presence flags.
				MS_DUMP_CLEAN(indentation, "  isKeyFrame: %s", flagText(this->isKeyFrame));
				MS_DUMP_CLEAN(indentation, "  hasPictureId: %s", flagText(this->hasPictureId));
				MS_DUMP_CLEAN(indentation, "  hasOneBytePictureId: %s", flagText(this->hasOneBytePictureId));
				MS_DUMP_CLEAN(indentation, "  hasTwoBytesPictureId: %s", flagText(this->hasTwoBytesPictureId));
				MS_DUMP_CLEAN(indentation, "  hasTl0PictureIndex: %s", flagText(this->hasTl0PictureIndex));
				MS_DUMP_CLEAN(indentation, "  hasSlIndex: %s", flagText(this->hasSlIndex));
				MS_DUMP_CLEAN(indentation, "  hasTlIndex: %s", flagText(this->hasTlIndex));

				MS_DUMP_CLEAN(indentation, "</VP9::PayloadDescriptor>");
			}

			/**
			 * Creates a handler wrapping the given VP9 payload descriptor.
			 *
			 * The descriptor is stored through reset() on the handler's
			 * payloadDescriptor member, so the handler takes ownership of it
			 * (presumably a std::unique_ptr; the member is declared outside this
			 * file — confirm in the header).
			 *
			 * @param payloadDescriptor  Heap-allocated descriptor, such as the one
			 *                           returned by VP9::Parse().
			 */
			VP9::PayloadDescriptorHandler::PayloadDescriptorHandler(VP9::PayloadDescriptor* payloadDescriptor)
			{
				MS_TRACE();

				this->payloadDescriptor.reset(payloadDescriptor);
			}

			/**
			 * Decides whether the packet carrying this payload descriptor must be
			 * forwarded for the given VP9 encoding context, and updates the context's
			 * current spatial/temporal layers and pictureId tracking accordingly.
			 *
			 * Layer switches (towards the context's target layers) are only evaluated
			 * for packets that are not "old", i.e. whose pictureId is not lower than
			 * the highest pictureId seen so far. Old packets are filtered against the
			 * layers returned by GetSpatialLayerForPictureId() /
			 * GetTemporalLayerForPictureId() for their pictureId.
			 *
			 * @param encodingContext  Must point to a Codecs::VP9::EncodingContext whose
			 *                         target spatial and temporal layers are >= 0.
			 * @param packet           Unused.
			 * @param marker           Set to true when this packet is the last packet
			 *                         (E bit) of the highest spatial layer forwarded for
			 *                         its picture; left untouched otherwise.
			 * @return true if the packet must be forwarded, false if it must be dropped.
			 */
			bool VP9::PayloadDescriptorHandler::Process(
			  Codecs::EncodingContext* encodingContext, RTP::Packet* /*packet*/, bool& marker)
			{
				MS_TRACE();

				auto* context = static_cast<Codecs::VP9::EncodingContext*>(encodingContext);

				MS_ASSERT(context->GetTargetSpatialLayer() >= 0, "target spatial layer cannot be -1");
				MS_ASSERT(context->GetTargetTemporalLayer() >= 0, "target temporal layer cannot be -1");

				auto packetSpatialLayer  = GetSpatialLayer();
				auto packetTemporalLayer = GetTemporalLayer();
				// Candidate current layers. They are only committed to the context if the
				// packet is forwarded. They may be -1 when no layer is active yet.
				auto tmpSpatialLayer     = context->GetCurrentSpatialLayer();
				auto tmpTemporalLayer    = context->GetCurrentTemporalLayer();

				// If packet spatial or temporal layer is higher than maximum announced
				// one, drop the packet.
				if (packetSpatialLayer >= context->GetSpatialLayers() || packetTemporalLayer >= context->GetTemporalLayers())
				{
					MS_WARN_TAG(
					  rtp, "too high packet layers %" PRIu8 ":%" PRIu8, packetSpatialLayer, packetTemporalLayer);

					return false;
				}

				// Check whether pictureId sync is required.
				if (context->syncRequired && this->payloadDescriptor->hasPictureId)
				{
					context->pictureIdManager.Sync(this->payloadDescriptor->pictureId - 1);

					context->syncRequired = false;
				}

				// A packet is "old" if its (15-bit) pictureId is lower than the highest
				// pictureId already fed into the pictureId manager.
				const bool isOldPacket =
				  (this->payloadDescriptor->hasPictureId &&
				   RTC::SeqManager<uint16_t, 15>::IsSeqLowerThan(
				     this->payloadDescriptor->pictureId, context->pictureIdManager.GetMaxInput()));

				// Layer switches are only evaluated on new packets.
				if (!isOldPacket)
				{
					// Upgrade current spatial layer if needed.
					if (context->GetTargetSpatialLayer() > context->GetCurrentSpatialLayer())
					{
						// Upgrading always requires a keyframe.
						if (this->payloadDescriptor->isKeyFrame)
						{
							MS_DEBUG_DEV(
							  "upgrading tmpSpatialLayer from %" PRIu16 " to %" PRIu16 " (packet:%" PRIu8
							  ":%" PRIu8 ")",
							  context->GetCurrentSpatialLayer(),
							  context->GetTargetSpatialLayer(),
							  packetSpatialLayer,
							  packetTemporalLayer);

							tmpSpatialLayer  = context->GetTargetSpatialLayer();
							tmpTemporalLayer = 0; // Just in case.
						}
					}
					// Downgrade current spatial layer if needed.
					else if (context->GetTargetSpatialLayer() < context->GetCurrentSpatialLayer())
					{
						// In K-SVC we must wait for a keyframe.
						if (context->IsKSvc())
						{
							if (this->payloadDescriptor->isKeyFrame)
							{
								MS_DEBUG_DEV(
								  "downgrading tmpSpatialLayer from %" PRIu16 " to %" PRIu16 " (packet:%" PRIu8
								  ":%" PRIu8 ") after keyframe (K-SVC)",
								  context->GetCurrentSpatialLayer(),
								  context->GetTargetSpatialLayer(),
								  packetSpatialLayer,
								  packetTemporalLayer);

								tmpSpatialLayer  = context->GetTargetSpatialLayer();
								tmpTemporalLayer = 0; // Just in case.
							}
						}
						// In full SVC we do not need a keyframe.
						else
						{
							// Switch at the end (E bit) of the target layer's frame.
							if (packetSpatialLayer == context->GetTargetSpatialLayer() && this->payloadDescriptor->e)
							{
								MS_DEBUG_DEV(
								  "downgrading tmpSpatialLayer from %" PRIu16 " to %" PRIu16 " (packet:%" PRIu8
								  ":%" PRIu8 ") without keyframe (full SVC)",
								  context->GetCurrentSpatialLayer(),
								  context->GetTargetSpatialLayer(),
								  packetSpatialLayer,
								  packetTemporalLayer);

								tmpSpatialLayer  = context->GetTargetSpatialLayer();
								tmpTemporalLayer = 0; // Just in case.
							}
						}
					}
				}

				// Spatial layer in effect for this packet's picture. For old packets, ask
				// the context. Otherwise it is the (possibly just switched) candidate.
				// NOTE: do not force an unsigned type here. tmpSpatialLayer may be -1
				// (no layer selected yet). Storing it in uint16_t turns it into 65535,
				// which would let every packet pass the filter below.
				const auto spatialLayerForPictureId =
				  isOldPacket ? context->GetSpatialLayerForPictureId(this->payloadDescriptor->pictureId)
				              : tmpSpatialLayer;

				// Filter spatial layers that are either
				// * higher than current one
				// * different than the current one when KSVC is enabled and this is not a keyframe
				// (interframe p bit = 1)
				if (
				  packetSpatialLayer > spatialLayerForPictureId ||
				  (context->IsKSvc() && this->payloadDescriptor->p &&
				   packetSpatialLayer != spatialLayerForPictureId))
				{
					return false;
				}

				// Check and handle temporal layer (unless old packet).
				if (!isOldPacket)
				{
					// Upgrade current temporal layer if needed.
					if (context->GetTargetTemporalLayer() > context->GetCurrentTemporalLayer())
					{
						// Upgrade on the start (B bit) of a higher temporal layer frame, which
						// must also be a switching-up point unless no temporal layer is active.
						if (
						  packetTemporalLayer >= context->GetCurrentTemporalLayer() + 1 &&
						  (context->GetCurrentTemporalLayer() == -1 || this->payloadDescriptor->switchingUpPoint) &&
						  this->payloadDescriptor->b)
						{
							MS_DEBUG_DEV(
							  "upgrading tmpTemporalLayer from %" PRIu16 " to %" PRIu8 " (packet:%" PRIu8
							  ":%" PRIu8 ")",
							  context->GetCurrentTemporalLayer(),
							  packetTemporalLayer,
							  packetSpatialLayer,
							  packetTemporalLayer);

							tmpTemporalLayer = packetTemporalLayer;
						}
					}
					// Downgrade current temporal layer if needed.
					else if (context->GetTargetTemporalLayer() < context->GetCurrentTemporalLayer())
					{
						// Switch at the end (E bit) of the target temporal layer's frame.
						if (packetTemporalLayer == context->GetTargetTemporalLayer() && this->payloadDescriptor->e)
						{
							MS_DEBUG_DEV(
							  "downgrading tmpTemporalLayer from %" PRIu16 " to %" PRIu16 " (packet:%" PRIu8
							  ":%" PRIu8 ")",
							  context->GetCurrentTemporalLayer(),
							  context->GetTargetTemporalLayer(),
							  packetSpatialLayer,
							  packetTemporalLayer);

							tmpTemporalLayer = context->GetTargetTemporalLayer();
						}
					}
				}

				// Temporal layer in effect for this packet's picture. As above, keep the
				// type signed so a -1 "no layer" value still drops the packet.
				const auto temporalLayerForPictureId =
				  isOldPacket ? context->GetTemporalLayerForPictureId(this->payloadDescriptor->pictureId)
				              : tmpTemporalLayer;

				// Filter temporal layers higher than current one.
				if (packetTemporalLayer > temporalLayerForPictureId)
				{
					return false;
				}

				// Set marker bit if needed: the E bit on the highest spatial layer being
				// forwarded for this picture ends the forwarded frame. Compare against the
				// layer used by the spatial filter above so that old packets get the
				// marker on the layer that was actually forwarded for their picture.
				// This is identical to tmpSpatialLayer for non-old packets.
				if (packetSpatialLayer == spatialLayerForPictureId && this->payloadDescriptor->e)
				{
					marker = true;
				}

				// Update the pictureId manager.
				if (this->payloadDescriptor->hasPictureId)
				{
					uint16_t pictureId;

					context->pictureIdManager.Input(this->payloadDescriptor->pictureId, pictureId);
				}

				// Update current spatial layer if needed.
				if (tmpSpatialLayer != context->GetCurrentSpatialLayer())
				{
					context->SetCurrentSpatialLayer(tmpSpatialLayer, this->payloadDescriptor->pictureId);
				}

				// Update current temporal layer if needed.
				if (tmpTemporalLayer != context->GetCurrentTemporalLayer())
				{
					context->SetCurrentTemporalLayer(tmpTemporalLayer, this->payloadDescriptor->pictureId);
				}

				return true;
			}
		} // namespace Codecs
	} // namespace RTP
} // namespace RTC
