import Foundation

/// One speaker-attributed span produced by the pyannote diarization pipeline.
///
/// Boundaries are expressed in milliseconds; `speakerId` is the cluster label
/// the pipeline assigned to the voice heard in this span.
struct DiarizationSegment {
    /// Span boundaries in milliseconds (start first, then end).
    let startTimeMs, endTimeMs: Int64
    /// Cluster label of the speaker for this span.
    let speakerId: Int
}

/// Tunables for the pyannote-based speaker segmentation pipeline.
///
/// Every field has a default, so `PyannoteSegmentationConfig()` is a usable
/// starting point. The two model paths are overwritten by
/// `PyannoteSegmentationManager.initialize(...)`.
struct PyannoteSegmentationConfig {
    // MARK: Models

    /// Filesystem path of the pyannote segmentation model (e.g. `model.onnx`).
    var segmentationModelPath = ""
    /// Filesystem path of the speaker-embedding model used for clustering.
    var embeddingModelPath = ""

    // MARK: Clustering

    /// Expected number of speakers; `0` means auto-detect via `clusteringThreshold`.
    var numSpeakers = 0
    /// Clustering distance threshold, consulted when `numSpeakers` is 0.
    /// Smaller values yield more clusters (more speakers detected).
    var clusteringThreshold: Float = STTConstants.defaultClusteringThreshold

    // MARK: Post-processing

    /// Minimum duration of a speech segment, in seconds.
    var minDurationOn: Float = STTConstants.defaultMinDurationOn
    /// Minimum duration of a silence segment, in seconds.
    var minDurationOff: Float = STTConstants.defaultMinDurationOff

    // MARK: Runtime

    /// Number of threads used for inference.
    var numThreads: Int = STTConstants.defaultAuxiliaryNumThreads
    /// ONNX execution provider name, e.g. "cpu" or "coreml".
    var provider: String = STTConstants.defaultProvider
}

/// Receives events emitted by `PyannoteSegmentationManager`.
///
/// Class-bound (`AnyObject`) so the manager can hold its delegate `weak`.
protocol PyannoteSegmentationDelegate: AnyObject {
    /// Called when a processed buffer contains more than one distinct speaker.
    ///
    /// Invoked synchronously from `PyannoteSegmentationManager.process(samples:sampleRate:)`.
    ///
    /// - Parameters:
    ///   - speakers: Distinct speaker IDs found in the buffer, sorted ascending.
    ///   - windowPairs: One `(startTimeMs, speakerId)` tuple per diarized segment,
    ///     in the order the segments were produced.
    func sendMultiSpeakerDetected(speakers: [Int], windowPairs: [(Int, Int)])
}

/// Manages speaker diarization using sherpa-onnx's native pyannote segmentation pipeline.
///
/// This is a post-processing approach: it processes a complete audio buffer
/// and returns a speaker assignment for each detected speech segment.
///
/// Key features:
/// - Native sherpa-onnx OfflineSpeakerDiarization integration
/// - Automatic speaker count detection (via clustering threshold)
/// - Fixed speaker count support (via numSpeakers parameter)
/// - Multi-speaker event emission through `PyannoteSegmentationDelegate`
///
/// Thread safety: every access to the native diarizer handle and to `config`
/// (creation, processing, reconfiguration, destruction) is serialized by an
/// internal lock, so `release()` cannot free the handle while `process` is
/// using it. `process` runs synchronously on the caller's thread; delegate
/// callbacks are delivered on that same thread after the lock is released.
///
/// NOTE(review): the native calls follow this project's placeholder bindings.
/// The upstream sherpa-onnx `c-api.h` uses different names and signatures:
/// `...GetSampleRate`, `...ResultSortByStartTime` + `...DestroySegment` /
/// `...DestroyResult`, and `SetConfig` takes a full diarization config.
/// Confirm these against the bindings actually linked.
class PyannoteSegmentationManager {
    private weak var delegate: PyannoteSegmentationDelegate?

    /// Native diarizer handle; nil until `initialize(...)` succeeds and after `release()`.
    /// Guarded by `lock`.
    private var diarizer: OpaquePointer?

    /// Serializes all access to `diarizer` and `config`. NSLock is not
    /// recursive, so code holding it must not call other locking methods.
    private let lock = NSLock()

    /// Configuration matching the live diarizer. Guarded by `lock`.
    private var config: PyannoteSegmentationConfig = PyannoteSegmentationConfig()

    /// Whether a native diarizer is currently available.
    var isEnabled: Bool {
        lock.lock()
        defer { lock.unlock() }
        return diarizer != nil
    }

    init(delegate: PyannoteSegmentationDelegate?) {
        self.delegate = delegate
    }

    // MARK: - Lifecycle

    /// Initialize (or re-initialize) the pyannote diarization pipeline.
    ///
    /// A diarizer from an earlier call is destroyed once the new one has been
    /// created. On failure the existing state (previous diarizer and
    /// configuration) is left untouched.
    ///
    /// - Parameters:
    ///   - segmentationModelPath: Path to pyannote segmentation model
    ///   - embeddingModelPath: Path to speaker embedding model
    ///   - segmentationConfig: Optional configuration. When nil, the current
    ///     configuration is reused (the defaults on the first call). The two
    ///     model paths above always override the paths stored in it.
    /// - Returns: true if a diarizer was created
    func initialize(
        segmentationModelPath: String,
        embeddingModelPath: String,
        segmentationConfig: PyannoteSegmentationConfig? = nil
    ) -> Bool {
        lock.lock()
        defer { lock.unlock() }

        // Stage the new configuration; it is only committed on success.
        var staged = segmentationConfig ?? config
        staged.segmentationModelPath = segmentationModelPath
        staged.embeddingModelPath = embeddingModelPath

        guard FileManager.default.fileExists(atPath: segmentationModelPath) else {
            NSLog("PyannoteSegmentation: Segmentation model not found: \(segmentationModelPath)")
            return false
        }
        guard FileManager.default.fileExists(atPath: embeddingModelPath) else {
            NSLog("PyannoteSegmentation: Embedding model not found: \(embeddingModelPath)")
            return false
        }

        let startTime = Date()

        // The native config stores raw `const char *` pointers. A pointer from
        // `withCString` is only valid inside its closure, so the create call
        // must run while all three closures are still active.
        let created: OpaquePointer? = staged.segmentationModelPath.withCString { segModel -> OpaquePointer? in
            staged.embeddingModelPath.withCString { embModel -> OpaquePointer? in
                staged.provider.withCString { provider -> OpaquePointer? in
                    var nativeConfig = Self.makeNativeConfig(
                        staged,
                        segmentationModel: segModel,
                        embeddingModel: embModel,
                        provider: provider
                    )
                    return SherpaOnnxCreateOfflineSpeakerDiarization(&nativeConfig)
                }
            }
        }

        guard let handle = created else {
            NSLog("PyannoteSegmentation: Failed to create diarizer")
            return false
        }

        let initTime = Date().timeIntervalSince(startTime) * 1000

        // Free any diarizer left over from an earlier initialize() call.
        if let previous = diarizer {
            SherpaOnnxDestroyOfflineSpeakerDiarization(previous)
        }
        diarizer = handle
        config = staged

        NSLog("PyannoteSegmentation: Initialized in \(Int(initTime))ms (numSpeakers=\(config.numSpeakers), threshold=\(config.clusteringThreshold))")
        return true
    }

    /// Get the expected sample rate for the diarizer.
    /// - Returns: Sample rate in Hz reported by the diarizer, or
    ///   `STTConstants.defaultSampleRate` when not initialized
    func getSampleRate() -> Int {
        lock.lock()
        defer { lock.unlock() }

        guard let diar = diarizer else { return STTConstants.defaultSampleRate }
        return Int(SherpaOnnxOfflineSpeakerDiarizationSampleRate(diar))
    }

    // MARK: - Processing

    /// Process audio samples and return speaker diarization results.
    ///
    /// Runs synchronously on the calling thread. The native work holds the
    /// internal lock, so a concurrent `release()` or reconfiguration waits
    /// until processing finishes. If more than one speaker is found, the
    /// delegate is notified after the lock is released.
    ///
    /// - Parameters:
    ///   - samples: Audio samples (mono, float, -1 to 1 range)
    ///   - sampleRate: Sample rate in Hz (should match getSampleRate())
    /// - Returns: Diarization segments with speaker assignments, or nil when the
    ///   manager is not initialized, the input is empty or invalid, or native
    ///   processing fails
    func process(samples: [Float], sampleRate: Int) -> [DiarizationSegment]? {
        lock.lock()
        let diarized = diarizeLocked(samples: samples, sampleRate: sampleRate)
        lock.unlock()

        guard let segments = diarized else { return nil }
        report(segments)
        return segments
    }

    /// Process audio and return the primary speaker ID.
    /// Useful for simple speaker tracking during live recording.
    /// - Parameters:
    ///   - samples: Audio samples
    ///   - sampleRate: Sample rate
    /// - Returns: Speaker with the most total speech time (ties go to the lowest
    ///   ID), or -1 on error or when no segments were found
    func processPrimarySpeaker(samples: [Float], sampleRate: Int) -> Int {
        guard let segments = process(samples: samples, sampleRate: sampleRate),
              !segments.isEmpty else {
            return -1
        }

        let totals = segments.reduce(into: [Int: Int64]()) { acc, segment in
            acc[segment.speakerId, default: 0] += segment.endTimeMs - segment.startTimeMs
        }

        // Deterministic tie-break: dictionary iteration order is unspecified.
        let primary = totals.max { lhs, rhs in
            lhs.value != rhs.value ? lhs.value < rhs.value : lhs.key > rhs.key
        }
        return primary?.key ?? -1
    }

    // MARK: - Configuration

    /// Set the expected number of speakers.
    /// Use 0 for auto-detection (uses clustering threshold).
    /// The value is stored even before initialization and applied on the next
    /// `initialize(...)` call that does not pass its own configuration.
    func setNumSpeakers(_ numSpeakers: Int) {
        lock.lock()
        defer { lock.unlock() }

        config.numSpeakers = numSpeakers
        guard applyClusteringLocked() else { return }
        NSLog("PyannoteSegmentation: Set numSpeakers to \(numSpeakers)")
    }

    /// Set the clustering threshold for auto speaker count detection.
    /// Only used when numSpeakers = 0.
    /// Smaller values = more clusters (more speakers detected).
    func setClusteringThreshold(_ threshold: Float) {
        lock.lock()
        defer { lock.unlock() }

        config.clusteringThreshold = threshold
        guard applyClusteringLocked() else { return }
        NSLog("PyannoteSegmentation: Set clustering threshold to \(threshold)")
    }

    /// Check if the manager is initialized
    func isInitialized() -> Bool {
        lock.lock()
        defer { lock.unlock() }
        return diarizer != nil
    }

    /// Release all resources. Safe to call more than once.
    func release() {
        lock.lock()
        defer { lock.unlock() }

        if let diar = diarizer {
            SherpaOnnxDestroyOfflineSpeakerDiarization(diar)
            diarizer = nil
        }

        NSLog("PyannoteSegmentation: Released")
    }

    deinit {
        release()
    }

    // MARK: - Private helpers

    /// Runs the native pipeline and converts its output. Caller must hold `lock`.
    private func diarizeLocked(samples: [Float], sampleRate: Int) -> [DiarizationSegment]? {
        guard let diar = diarizer else {
            NSLog("PyannoteSegmentation: Process called but diarizer not initialized")
            return nil
        }
        guard !samples.isEmpty else {
            NSLog("PyannoteSegmentation: Empty samples array")
            return nil
        }
        // Guards the duration computation below against division by zero.
        guard sampleRate > 0 else {
            NSLog("PyannoteSegmentation: Invalid sample rate \(sampleRate)")
            return nil
        }
        // The native API takes an int32 count; avoid a trapping conversion.
        guard let sampleCount = Int32(exactly: samples.count) else {
            NSLog("PyannoteSegmentation: Too many samples (\(samples.count))")
            return nil
        }

        let expectedRate = Int(SherpaOnnxOfflineSpeakerDiarizationSampleRate(diar))
        if expectedRate != sampleRate {
            NSLog("PyannoteSegmentation: Warning: input sample rate \(sampleRate)Hz differs from diarizer rate \(expectedRate)Hz")
        }

        let durationMs = (samples.count * 1000) / sampleRate
        NSLog("PyannoteSegmentation: Processing \(samples.count) samples (\(durationMs)ms) @ \(sampleRate)Hz")

        let startTime = Date()

        let resultPtr = samples.withUnsafeBufferPointer { buffer -> OpaquePointer? in
            SherpaOnnxOfflineSpeakerDiarizationProcess(diar, buffer.baseAddress, sampleCount)
        }

        guard let segPtr = resultPtr else {
            NSLog("PyannoteSegmentation: Processing returned nil")
            return nil
        }
        defer {
            SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segPtr)
        }

        // Clamp at 0: a negative count would make `0..<numSegments` trap.
        let numSegments = max(0, Int(SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(segPtr)))
        let processingTime = Date().timeIntervalSince(startTime) * 1000
        NSLog("PyannoteSegmentation: Processed in \(Int(processingTime))ms, found \(numSegments) segments")

        // Native start/end are in seconds; convert to integer milliseconds (truncating).
        return (0..<numSegments).map { index -> DiarizationSegment in
            let native = SherpaOnnxOfflineSpeakerDiarizationResultGetSegment(segPtr, Int32(index))
            return DiarizationSegment(
                startTimeMs: Int64(native.start * 1000),
                endTimeMs: Int64(native.end * 1000),
                speakerId: Int(native.speaker)
            )
        }
    }

    /// Logs the result and emits the multi-speaker event. Must be called
    /// WITHOUT holding `lock`, so a delegate may safely call back into the manager.
    private func report(_ segments: [DiarizationSegment]) {
        guard !segments.isEmpty else {
            NSLog("PyannoteSegmentation: No segments detected")
            return
        }

        let speakers = Set(segments.map { $0.speakerId }).sorted()
        NSLog("PyannoteSegmentation: Speakers detected: \(speakers)")

        for (index, seg) in segments.enumerated() {
            NSLog("PyannoteSegmentation: Segment \(index): \(seg.startTimeMs)-\(seg.endTimeMs)ms -> Speaker \(seg.speakerId)")
        }

        if speakers.count > 1 {
            let windowPairs = segments.map { (Int($0.startTimeMs), $0.speakerId) }
            delegate?.sendMultiSpeakerDetected(speakers: speakers, windowPairs: windowPairs)
        }
    }

    /// Pushes `config`'s clustering parameters to the live diarizer. Caller must hold `lock`.
    /// - Returns: false when there is no diarizer to update.
    private func applyClusteringLocked() -> Bool {
        guard let diar = diarizer else { return false }
        var clustering = Self.makeClusteringConfig(
            numSpeakers: config.numSpeakers,
            threshold: config.clusteringThreshold
        )
        SherpaOnnxOfflineSpeakerDiarizationSetConfig(diar, &clustering)
        return true
    }

    /// Builds the native clustering config (`numSpeakers == 0` requests
    /// threshold-based auto-detection, per `PyannoteSegmentationConfig`).
    private static func makeClusteringConfig(numSpeakers: Int, threshold: Float) -> SherpaOnnxFastClusteringConfig {
        var clustering = SherpaOnnxFastClusteringConfig()
        clustering.num_clusters = Int32(numSpeakers)
        clustering.threshold = threshold
        return clustering
    }

    /// Builds the native diarization config from `cfg`.
    ///
    /// The returned struct BORROWS the three string pointers. It must only be
    /// used while those pointers are valid, i.e. inside the `withCString`
    /// closures that produced them.
    private static func makeNativeConfig(
        _ cfg: PyannoteSegmentationConfig,
        segmentationModel: UnsafePointer<CChar>,
        embeddingModel: UnsafePointer<CChar>,
        provider: UnsafePointer<CChar>
    ) -> SherpaOnnxOfflineSpeakerDiarizationConfig {
        var segmentation = SherpaOnnxOfflineSpeakerSegmentationModelConfig()
        segmentation.pyannote.model = segmentationModel
        segmentation.num_threads = Int32(cfg.numThreads)
        segmentation.debug = 0
        segmentation.provider = provider

        var embedding = SherpaOnnxSpeakerEmbeddingExtractorConfig()
        embedding.model = embeddingModel
        embedding.num_threads = Int32(cfg.numThreads)
        embedding.debug = 0
        embedding.provider = provider

        var diarization = SherpaOnnxOfflineSpeakerDiarizationConfig()
        diarization.segmentation = segmentation
        diarization.embedding = embedding
        diarization.clustering = makeClusteringConfig(
            numSpeakers: cfg.numSpeakers,
            threshold: cfg.clusteringThreshold
        )
        diarization.min_duration_on = cfg.minDurationOn
        diarization.min_duration_off = cfg.minDurationOff
        return diarization
    }
}
