from __future__ import annotations
import concurrent.futures
import os
import traceback
from typing import Any, Dict, List, Tuple

import numpy as np
import scrypted_sdk
from scrypted_sdk.other import SettingValue
from scrypted_sdk.types import (
    MediaObject,
    ObjectDetection,
    ObjectDetectionGeneratorSession,
    ObjectDetectionModel,
    ObjectDetectionSession,
    ObjectsDetected,
    Setting,
)

# Pick a TensorFlow Lite backend at import time, in order of preference:
#   1. tflite_runtime    -- runtime-only package (preferred for deployment)
#   2. tensorflow        -- full package, interpreter reached via tf.lite
# Resulting flags:
#   tf_available     -> True when either backend imported
#   tf_lite_runtime  -> True only when the tflite_runtime backend is in use
tf_available = False
tf_lite_runtime = False
try:
    import tflite_runtime.interpreter as tflite
except ImportError:
    try:
        import tensorflow as tf
    except ImportError:
        print(
            "❌ TensorFlow not available. Please install tensorflow-lite-runtime or tensorflow"
        )
    else:
        tf_available = True
        print("✅ Using tensorflow (full)")
else:
    tf_available = True
    tf_lite_runtime = True
    print("✅ Using tensorflow-lite-runtime")


# Single-worker thread pools: anything submitted to one of them runs serially.
# NOTE(review): neither executor is referenced in the code visible in this
# file -- confirm whether they are still needed.
predictExecutor = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="YAMNet-Predict"
)
prepareExecutor = concurrent.futures.ThreadPoolExecutor(
    max_workers=1, thread_name_prefix="YAMNet-Prepare"
)


class YAMNetPlugin(
    scrypted_sdk.ScryptedDeviceBase,
    scrypted_sdk.Settings,
    scrypted_sdk.DeviceProvider,
    ObjectDetection,
):
    """Scrypted plugin that classifies audio with the YAMNet TFLite model.

    The model file ``model/yamnet.tflite`` (under the parent of this file's
    directory) is loaded when the plugin is constructed. If neither
    tflite_runtime nor tensorflow could be imported, ``self.interpreter`` is
    None and inference is unavailable.
    """

    def __init__(self, nativeId: str | None = None):
        super().__init__(nativeId=nativeId)

        # Parse explicitly: bool() on a persisted "false" string would be True.
        self.debug_log = self._as_bool(self.storage.getItem("debug_log"))

        if tf_available:
            model_path = os.path.join(
                os.path.dirname(os.path.dirname(__file__)), "model", "yamnet.tflite"
            )
            # Use whichever backend the module-level import selected.
            if tf_lite_runtime:
                self.interpreter = tflite.Interpreter(model_path=model_path)
            else:
                self.interpreter = tf.lite.Interpreter(model_path=model_path)

            self.interpreter.allocate_tensors()

            # Cached so run_inference can address tensors by index.
            self.input_details = self.interpreter.get_input_details()
            self.output_details = self.interpreter.get_output_details()

            print("YAMNet model loaded successfully")
            print(f"Input shape: {self.input_details[0]['shape']}")
            print(f"Input type: {self.input_details[0]['dtype']}")
        else:
            print(
                "TensorFlow not available. Please install tensorflow-lite-runtime or tensorflow."
            )
            self.interpreter = None

        self.sample_rate = 16000  # YAMNet expects 16kHz audio
        # Seconds per frame: 0.975 s x 16000 Hz = 15600 samples.
        self.frame_duration = 0.975

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a stored/submitted setting value as a boolean.

        Strings are matched case-insensitively against "true", "1", "yes" and
        "on"; any other string (including "false" and "") is False. Non-string
        values fall back to ``bool(value)``, so None is False.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("true", "1", "yes", "on")
        return bool(value)

    def get_input_format(self) -> str:
        """Return the input format advertised in the detection model."""
        return "audio/pcm"

    def getModelSettings(self, settings: Any = None) -> List[Setting]:
        """Return model-specific settings (none are currently exposed)."""
        return []

    async def getDetectionModel(self, settings: Any = None) -> ObjectDetectionModel:
        """Return model information for the ObjectDetection interface.

        ``classes`` is the sorted, de-duplicated set of labels read from
        ``model/yamnet_map_reduced.txt``. The file may list one label per line
        or ``"<index> <label>"`` per line; in the second layout the numeric
        index is dropped so the class names match the labels used in
        detections.
        """
        reduced_map_path = os.path.join(
            os.path.dirname(os.path.dirname(__file__)),
            "model",
            "yamnet_map_reduced.txt",
        )
        with open(reduced_map_path, "r", encoding="utf-8") as f:
            lines = [line.strip() for line in f if line.strip()]

        # Same layout test as YAMNetAudioTf.load_map: a numeric first token
        # on the first entry means every line is "<index> <label>".
        indexed = bool(lines) and lines[0].split(" ", maxsplit=1)[0].isdigit()
        if indexed:
            labels = (line.partition(" ")[2].strip() for line in lines)
        else:
            labels = iter(lines)
        reduced_classes = sorted({label for label in labels if label})

        model: ObjectDetectionModel = {
            "name": "YAMNet Audio Classification",
            "classes": reduced_classes,
            "inputFormat": self.get_input_format(),
            "settings": self.getModelSettings(settings),
        }
        return model

    def get_audio_tf(self):
        """Return the lazily created YAMNetAudioTf detector (built on first use).

        NOTE(review): ``YAMNetAudioTf()`` constructs its own YAMNetPlugin, so the
        model is loaded a second time for this detector; sharing this
        instance's interpreter would avoid the duplicate load.
        """
        if not hasattr(self, "_audio_tf"):
            self._audio_tf = YAMNetAudioTf()
        return self._audio_tf

    @staticmethod
    def _merge_detections(detections) -> List[Dict[str, Any]]:
        """Collapse ``(class_name, score, bbox)`` tuples into one entry per class.

        Scores for the same class are summed, so a merged score can exceed
        1.0. The bounding box comes from that class's highest-scoring entry.
        Entry order follows the first occurrence of each class.
        """
        merged: Dict[str, Dict[str, Any]] = {}
        best_score: Dict[str, float] = {}
        for macro_class, conf, bbox in detections:
            conf = float(conf)
            box = tuple(float(x) for x in bbox)
            entry = merged.get(macro_class)
            if entry is None:
                merged[macro_class] = {
                    "className": macro_class,
                    "score": conf,
                    "boundingBox": box,
                }
                best_score[macro_class] = conf
                continue
            entry["score"] += conf
            if conf > best_score[macro_class]:
                entry["boundingBox"] = box
                best_score[macro_class] = conf
        return list(merged.values())

    def _run_detection_sync(self, audio_buffer) -> ObjectsDetected:
        """Classify ``audio_buffer`` and return an ObjectsDetected-shaped dict.

        Returns ``{"detections": [...], "inputDimensions": (sample_rate, n)}``
        where ``n`` is the number of converted samples. Failures inside the
        detector yield an empty detection list. Any other failure is printed
        and yields ``inputDimensions == (sample_rate, 0)``.
        """
        try:
            self.debug(
                f"🔍 Debug: run_detection_audio - audio_buffer type = {type(audio_buffer)}"
            )
            self.debug(
                f"🔍 Debug: run_detection_audio - audio_buffer length = {len(audio_buffer)}"
            )
            audio_tf = self.get_audio_tf()
            # NOTE(review): np.asarray does not decode raw PCM bytes into
            # samples. If audio_buffer is a bytes-like buffer (presumably what
            # detectObjects passes), this yields wrong data or raises. Decoding
            # needs the PCM sample format, which is not visible here; TODO
            # confirm it and switch to np.frombuffer accordingly.
            audio_data = np.asarray(audio_buffer, dtype=np.float32)
            self.debug(
                f"🔍 Debug: audio_data shape = {getattr(audio_data, 'shape', 'N/A')}, dtype = {getattr(audio_data, 'dtype', 'N/A')}"
            )
            try:
                self.debug("🔍 Debug: chiamata YAMNetAudioTf.detect")
                detections = audio_tf.detect(audio_data)
                self.debug(
                    f"🔍 Debug: YAMNetAudioTf.detect restituisce {len(detections)} detection"
                )
                objects = self._merge_detections(detections)
            except Exception as e:
                print(f"Errore in YAMNetAudioTf.detect: {e}")
                traceback.print_exc()
                objects = []
            return {
                "detections": objects,
                # Sample count directly; avoids a float round-trip through seconds.
                "inputDimensions": (self.sample_rate, len(audio_data)),
            }
        except Exception as e:
            print(f"Error in audio detection: {e}")
            traceback.print_exc()
            return {"detections": [], "inputDimensions": (self.sample_rate, 0)}

    async def run_detection_audio(
        self, audio_buffer, detection_session: ObjectDetectionSession = None
    ) -> ObjectsDetected:
        """Classify an audio buffer (``detection_session`` is currently unused)."""
        return self._run_detection_sync(audio_buffer)

    async def generateObjectDetections(
        self, audioFrames: Any, session: ObjectDetectionGeneratorSession = None
    ) -> Any:
        """Generate object detections from an audio stream.

        Not implemented: the body is empty, so this returns None.
        """

    async def detectObjects(
        self, mediaObject: MediaObject, session: ObjectDetectionSession = None
    ) -> ObjectsDetected:
        """Entry point of the ObjectDetection interface.

        Converts ``mediaObject`` to a buffer (requesting ``"*/*"``) and
        passes it to ``run_detection_audio``. Any error is printed and an
        empty result with ``inputDimensions == (sample_rate, 0)`` is returned.
        """
        try:
            self.debug(f"🔍 Debug: mediaObject type = {type(mediaObject)}")
            self.debug(
                f"🔍 Debug: mediaObject mimeType = {getattr(mediaObject, 'mimeType', 'N/A')}"
            )

            audio_buffer = await scrypted_sdk.mediaManager.convertMediaObjectToBuffer(
                mediaObject, "*/*"
            )

            self.debug(f"🔍 Debug: audio_buffer type = {type(audio_buffer)}")
            self.debug(
                f"🔍 Debug: audio_buffer length = {len(audio_buffer) if hasattr(audio_buffer, '__len__') else 'N/A'}"
            )
            self.debug(
                f"🔍 Debug: audio_buffer first 100 bytes = {audio_buffer[:100] if len(audio_buffer) > 0 else 'Empty'}"
            )

            return await self.run_detection_audio(audio_buffer, session)

        except Exception as e:
            print(f"Error detecting objects from media: {e}")
            traceback.print_exc()
            return {"detections": [], "inputDimensions": (self.sample_rate, 0)}

    def run_inference(
        self, audio_data: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run the interpreter directly on ``audio_data``.

        Returns:
            ``(scores, embeddings, spectrogram)`` from the first three output
            tensors, in order. Outputs the model does not provide are replaced
            by ``np.zeros((1, 1))``.

        Raises:
            RuntimeError: if no interpreter is loaded.
            ValueError: if the model reports no outputs.
        """
        if not tf_available or self.interpreter is None:
            raise RuntimeError("TensorFlow Lite interpreter not available")

        try:
            self.debug(f"🔍 Debug: run_inference - input shape = {audio_data.shape}")
            self.debug(
                f"🔍 Debug: run_inference - output_details count = {len(self.output_details)}"
            )

            self.interpreter.set_tensor(self.input_details[0]["index"], audio_data)
            self.interpreter.invoke()

            if not self.output_details:
                raise ValueError(
                    f"Unexpected number of outputs: {len(self.output_details)}"
                )
            outputs = [
                self.interpreter.get_tensor(detail["index"])
                for detail in self.output_details[:3]
            ]
            # Pad missing embeddings/spectrogram with placeholders.
            outputs += [np.zeros((1, 1)) for _ in range(3 - len(outputs))]
            scores, embeddings, spectrogram = outputs

            self.debug(f"🔍 Debug: run_inference - scores shape = {scores.shape}")
            self.debug(
                f"🔍 Debug: run_inference - embeddings shape = {embeddings.shape}"
            )
            self.debug(
                f"🔍 Debug: run_inference - spectrogram shape = {spectrogram.shape}"
            )

            return scores, embeddings, spectrogram
        except Exception as e:
            print(f"Error running inference: {e}")
            traceback.print_exc()
            raise

    async def getSettings(self) -> List[Setting]:
        """Return the plugin settings (currently only the debug-log toggle)."""
        return [
            {
                "title": "Debug Log",
                "description": "Abilita log dettagliati di debug per la classificazione audio",
                "type": "boolean",
                "value": bool(getattr(self, "debug_log", False)),
                "key": "debug_log",
                "immediate": True,
            },
        ]

    async def putSetting(self, key: str, value: SettingValue):
        """Persist a setting, refresh cached state, and notify Scrypted.

        The raw ``value`` is stored. For ``debug_log``, the in-memory flag is
        refreshed using the same boolean parsing as ``__init__``.
        """
        self.storage.setItem(key, value)
        if key == "debug_log":
            self.debug_log = self._as_bool(value)
        await self.onDeviceEvent(scrypted_sdk.ScryptedInterface.Settings.value, None)

    def debug(self, *args, **kwargs):
        """Print the arguments when debug logging is on (default on if unset)."""
        if getattr(self, "debug_log", True):
            print(*args, **kwargs)

    async def getDevice(self, nativeId: str) -> Any:
        """Return this instance; the plugin exposes no sub-devices."""
        return self

    def classify_samples(self, audio_buffer: bytes) -> Dict[str, Any]:
        """Classify ``audio_buffer`` synchronously.

        Runs the same pipeline as ``run_detection_audio`` and returns its
        ``{"detections": ..., "inputDimensions": ...}`` dict.
        """
        # The previous target, classify_audio_buffer, is not defined anywhere.
        return self._run_detection_sync(audio_buffer)

    async def classify_samples_async(self, audio_buffer: bytes) -> Dict[str, Any]:
        """Async variant of ``classify_samples``."""
        # The previous target, classify_audio_async, is not defined anywhere.
        return await self.run_detection_audio(audio_buffer)


class YAMNetAudioTf:
    """Run YAMNet's TFLite interpreter on a waveform and map class ids to labels.

    The interpreter is borrowed from a YAMNetPlugin-like object: the one
    passed as ``yamnet``, or, when omitted, a freshly built ``YAMNetPlugin()``
    (which loads its own copy of the model).
    """

    # Maximum detections per call; also the row count of _detect_raw's array.
    MAX_DETECTIONS = 20
    # Label slots pre-filled with "unknown" by load_map (YAMNet's 521 classes).
    NUM_CLASSES = 521

    def __init__(self, threshold=0.01, yamnet=None):
        """Create a detector.

        Args:
            threshold: default minimum score for a class to be reported.
            yamnet: optional object exposing an ``interpreter`` attribute
                (normally an existing YAMNetPlugin). When None, a new
                YAMNetPlugin is constructed.
        """
        self.yamnet = yamnet if yamnet is not None else YAMNetPlugin()

        self.labels = self.load_map()
        self.threshold = threshold
        self._setup_interpreter()

    def _setup_interpreter(self):
        """Cache the plugin's interpreter and its input/output tensor details.

        Raises:
            RuntimeError: if the plugin has no interpreter (e.g. TensorFlow
                could not be imported).
        """
        interpreter = getattr(self.yamnet, "interpreter", None)
        if interpreter is None:
            raise RuntimeError("TensorFlow Lite interpreter not available")
        self.interpreter = interpreter
        self.tensor_input_details = interpreter.get_input_details()
        self.tensor_output_details = interpreter.get_output_details()

    def load_map(self, map_path=None):
        """Load the class-id -> label map.

        Args:
            map_path: file to read. Defaults to ``model/yamnet_map_reduced.txt``
                under the parent of this file's directory.

        Returns:
            dict mapping every id in ``range(NUM_CLASSES)`` to ``"unknown"``,
            overridden by the file's entries. If the first line starts with a
            numeric token, each line is parsed as ``"<index> <label>"`` and
            blank or unlabelled lines are skipped. Otherwise line ``i`` is the
            label for id ``i``. An empty file yields the all-"unknown" map.
        """
        if map_path is None:
            map_path = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                "model",
                "yamnet_map_reduced.txt",
            )

        labels = {index: "unknown" for index in range(self.NUM_CLASSES)}
        with open(map_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
        if not lines:
            return labels

        if lines[0].split(" ", maxsplit=1)[0].isdigit():
            for line in lines:
                index, _sep, label = line.partition(" ")
                label = label.strip()
                # Skip blank/malformed lines instead of failing on the whole file.
                if not index.isdigit() or not label:
                    continue
                labels[int(index)] = label
        else:
            labels.update({index: line.strip() for index, line in enumerate(lines)})
        return labels

    def _detect_raw(self, tensor_input, threshold=None):
        """Invoke the model and return up to MAX_DETECTIONS scored classes.

        Returns:
            float32 array of shape ``(MAX_DETECTIONS, 6)``. Each filled row is
            ``[class_id, score, -1, -1, -1, -1]`` (no real boxes for audio).
            Rows are sorted by descending score. Only scores that are positive
            and ``>= threshold`` are kept (``self.threshold`` when None).
            Unused rows stay all-zero.
        """
        threshold = self.threshold if threshold is None else threshold
        self.interpreter.set_tensor(self.tensor_input_details[0]["index"], tensor_input)
        self.interpreter.invoke()
        detections = np.zeros((self.MAX_DETECTIONS, 6), np.float32)

        output = self.interpreter.get_tensor(self.tensor_output_details[0]["index"])
        # First row of the score tensor; atleast_2d also accepts a 1-D output.
        # NOTE(review): presumably one row per audio frame -- only the first
        # is scored; confirm for multi-frame inputs.
        res = np.atleast_2d(output)[0]
        # A full argsort also works when the model has <= MAX_DETECTIONS classes.
        order = np.argsort(-res, kind="stable")[: self.MAX_DETECTIONS]

        for row, class_id in enumerate(order):
            score = float(res[class_id])
            # Sorted descending, so the first rejected score ends the list.
            if score <= 0 or score < threshold:
                break
            detections[row] = [class_id, score, -1, -1, -1, -1]
        return detections

    def detect(self, tensor_input, threshold=None):
        """Classify ``tensor_input``.

        Args:
            tensor_input: value passed unchanged to the interpreter's input.
            threshold: minimum score to report; defaults to ``self.threshold``.

        Returns:
            list of ``(label, score, (x0, y0, x1, y1))`` tuples in descending
            score order. The box entries are the ``-1`` placeholders produced
            by ``_detect_raw``.
        """
        threshold = self.threshold if threshold is None else threshold
        detections = []
        for row in self._detect_raw(tensor_input, threshold):
            score = float(row[1])
            # Zero-score rows are unused padding; never report them, even
            # when threshold <= 0.
            if score <= 0 or score < threshold:
                break
            label = self.labels.get(int(row[0]), "unknown")
            detections.append((label, score, (row[2], row[3], row[4], row[5])))
        return detections
