"""
Advanced Monitoring and Logging System for Hotel MCP Server.

This module provides structured logging, performance metrics, and monitoring
capabilities that extend the existing logging infrastructure in a DRY way.
"""

import inspect
import json
import logging
import time
from collections import defaultdict, deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Optional, Union

# Plain stdlib logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)


class MetricType(Enum):
    """Kinds of metric a ``PerformanceMetric`` can carry.

    The value determines how ``MetricsCollector.record_metric`` aggregates a
    sample in addition to storing it in the raw per-name history.
    """

    COUNTER = "counter"  # values are summed per metric name
    TIMER = "timer"  # individual values are kept for count/avg/min/max/percentiles
    GAUGE = "gauge"  # only the most recent value per name is kept
    HISTOGRAM = "histogram"  # no dedicated aggregation; stored in raw history only


class LogLevel(Enum):
    """Severity levels for structured log entries.

    Values are spelled like the stdlib ``logging`` level names (DEBUG, INFO,
    ...). TRACE has no stdlib ``logging`` counterpart, so code that maps these
    names onto ``logging`` constants must handle it separately.
    """

    TRACE = "TRACE"
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARNING"
    ERROR = "ERROR"
    CRITICAL = "CRITICAL"


@dataclass
class PerformanceMetric:
    """A single metric sample passed to ``MetricsCollector.record_metric``.

    Field order defines the generated ``__init__`` signature.
    """

    name: str  # metric key, e.g. "tool.<tool_name>.duration"
    value: Union[int, float]
    metric_type: MetricType  # selects how the collector aggregates this sample
    timestamp: datetime = field(default_factory=datetime.now)  # naive local time at creation
    tags: Dict[str, str] = field(default_factory=dict)  # free-form labels; new dict per instance
    unit: str = "ms"  # unit label; defaults to milliseconds


@dataclass
class StructuredLogEntry:
    """One structured log record, consumed by ``StructuredLogger.log_structured``.

    Field order defines the generated ``__init__`` signature.
    """

    level: LogLevel
    message: str
    timestamp: datetime = field(default_factory=datetime.now)  # naive local time at creation
    context: Dict[str, Any] = field(default_factory=dict)  # merged on top of the logger's context stack
    tags: Dict[str, str] = field(default_factory=dict)  # serialized; also reused as metric tags
    request_id: Optional[str] = None  # omitted from the JSON output when falsy
    user_id: Optional[str] = None  # omitted from the JSON output when falsy
    tool_name: Optional[str] = None  # omitted when falsy; also names the duration metric
    duration_ms: Optional[float] = None  # when not None, serialized and recorded as a timer metric


class MetricsCollector:
    """
    In-memory aggregator for ``PerformanceMetric`` samples.

    Every recorded sample is appended to a bounded per-name history. On top of
    that, COUNTER values are summed, GAUGE values keep only the latest value,
    and TIMER values are retained (at most ``max_history`` per name) so that
    ``get_summary`` can report count/avg/min/max/p95/p99. HISTOGRAM samples are
    only stored in the raw history. No locking is performed.
    """

    def __init__(self, max_history: int = 1000):
        """
        Initialize empty aggregates.

        Args:
            max_history: Maximum number of entries kept per metric name, both
                in the raw history and in each timer's value list.
        """
        self.max_history = max_history
        self.metrics: Dict[str, deque] = defaultdict(lambda: deque(maxlen=max_history))
        self.counters: Dict[str, int] = defaultdict(int)
        self.gauges: Dict[str, float] = defaultdict(float)
        self.timers: Dict[str, List[float]] = defaultdict(list)

        logger.info("Metrics collector initialized")

    def record_metric(self, metric: "PerformanceMetric") -> None:
        """Store *metric* in its history and update the matching aggregate."""
        self.metrics[metric.name].append(metric)

        if metric.metric_type == MetricType.COUNTER:
            self.counters[metric.name] += metric.value
        elif metric.metric_type == MetricType.GAUGE:
            self.gauges[metric.name] = metric.value
        elif metric.metric_type == MetricType.TIMER:
            values = self.timers[metric.name]
            values.append(metric.value)
            # Drop only the oldest overflow entries, in place. Unlike slicing
            # with values[-max_history:], this also honours max_history == 0
            # (where [-0:] would keep the whole list and grow without bound).
            overflow = len(values) - self.max_history
            if overflow > 0:
                del values[:overflow]

    def get_summary(self) -> Dict[str, Any]:
        """
        Return a JSON-friendly snapshot of the aggregated metrics.

        Returns:
            Dict with ``timestamp`` (ISO-8601 string), ``counters``, ``gauges``
            and ``timers``. Each timer with at least one value maps to
            ``count``, ``avg``, ``min``, ``max``, ``p95`` and ``p99``.
        """
        summary: Dict[str, Any] = {
            "timestamp": datetime.now().isoformat(),
            "counters": dict(self.counters),
            "gauges": dict(self.gauges),
            "timers": {},
        }

        for name, values in self.timers.items():
            if not values:
                continue
            # Sort once; min/max and both percentiles read from this copy.
            ordered = sorted(values)
            summary["timers"][name] = {
                "count": len(values),
                "avg": sum(values) / len(values),
                "min": ordered[0],
                "max": ordered[-1],
                "p95": self._percentile(ordered, 95),
                "p99": self._percentile(ordered, 99),
            }

        return summary

    def _percentile(self, values: List[float], percentile: int) -> float:
        """
        Return the nearest-rank *percentile* of *values*.

        The result is the element at 1-based rank ``ceil(percentile * n / 100)``
        of the sorted values, i.e. the smallest sample such that at least
        *percentile* percent of samples are <= it. Out-of-range percentiles
        clamp to the min/max. *values* need not be sorted.

        Returns:
            The percentile value, or 0.0 when *values* is empty.
        """
        if not values:
            return 0.0
        ordered = sorted(values)
        count = len(ordered)
        # Integer ceiling division: ceil(count * percentile / 100).
        rank = -(-(count * percentile) // 100)
        index = min(max(int(rank) - 1, 0), count - 1)
        return ordered[index]


class StructuredLogger:
    """
    JSON structured logger built on the stdlib ``logging`` module.

    Each entry is serialized with ``json.dumps`` (non-JSON values fall back to
    ``str``) and emitted as a single message on ``logging.getLogger(name)``.
    Key/value pairs pushed with ``context()`` are merged into every entry
    logged inside that block, and entries carrying ``duration_ms`` are also
    recorded as TIMER metrics on the attached ``MetricsCollector``.
    """

    # Numeric level used for LogLevel.TRACE. The stdlib ``logging`` module
    # defines no TRACE level, so use a value one step below DEBUG (10).
    TRACE_LEVEL = 5

    def __init__(self, name: str, metrics_collector: Optional[MetricsCollector] = None):
        """
        Args:
            name: Name passed to ``logging.getLogger``.
            metrics_collector: Collector receiving duration metrics; a private
                ``MetricsCollector`` is created when omitted.
        """
        self.logger = logging.getLogger(name)
        self.metrics = (
            metrics_collector if metrics_collector is not None else MetricsCollector()
        )
        # Context dicts pushed by context(); merged outermost-first so inner
        # blocks override outer ones.
        self.context_stack: List[Dict[str, Any]] = []

    def _to_logging_level(self, level: LogLevel) -> int:
        """Map a ``LogLevel`` onto the numeric level understood by ``logging``."""
        if level is LogLevel.TRACE:
            return self.TRACE_LEVEL
        return getattr(logging, level.value)

    def _render(self, entry: StructuredLogEntry) -> str:
        """Serialize *entry*, merged with the active context stack, to JSON."""
        merged_context: Dict[str, Any] = {}
        for ctx in self.context_stack:
            merged_context.update(ctx)
        merged_context.update(entry.context)

        structured_data: Dict[str, Any] = {
            "timestamp": entry.timestamp.isoformat(),
            "level": entry.level.value,
            "message": entry.message,
            "context": merged_context,
            "tags": entry.tags,
        }

        # Optional fields are only emitted when set.
        if entry.request_id:
            structured_data["request_id"] = entry.request_id
        if entry.user_id:
            structured_data["user_id"] = entry.user_id
        if entry.tool_name:
            structured_data["tool_name"] = entry.tool_name
        if entry.duration_ms is not None:
            structured_data["duration_ms"] = entry.duration_ms

        return json.dumps(structured_data, default=str)

    def log_structured(self, entry: StructuredLogEntry) -> None:
        """
        Emit *entry* as one JSON log message and record its duration, if any.

        Context keys from ``entry.context`` override those from enclosing
        ``context()`` blocks. When ``entry.duration_ms`` is not None, a TIMER
        metric named ``tool.<tool_name or 'unknown'>.duration`` is recorded,
        even if the log level itself is filtered out.
        """
        log_level = self._to_logging_level(entry.level)

        # Logger.log() discards disabled levels anyway; checking first avoids
        # paying for context merging and JSON serialization on dropped records.
        if self.logger.isEnabledFor(log_level):
            self.logger.log(log_level, self._render(entry))

        if entry.duration_ms is not None:
            self.metrics.record_metric(
                PerformanceMetric(
                    name=f"tool.{entry.tool_name or 'unknown'}.duration",
                    value=entry.duration_ms,
                    metric_type=MetricType.TIMER,
                    tags=entry.tags,
                )
            )

    @contextmanager
    def context(self, **kwargs):
        """Add *kwargs* to the context of every entry logged inside this block."""
        self.context_stack.append(kwargs)
        try:
            yield
        finally:
            self.context_stack.pop()

    def _emit(self, level: LogLevel, message: str, context: Dict[str, Any]) -> None:
        """Build a ``StructuredLogEntry`` for *level* and log it."""
        self.log_structured(
            StructuredLogEntry(level=level, message=message, context=context)
        )

    def info(self, message: str, **kwargs) -> None:
        """Log an INFO message; keyword arguments become entry context."""
        self._emit(LogLevel.INFO, message, kwargs)

    def error(self, message: str, **kwargs) -> None:
        """Log an ERROR message; keyword arguments become entry context."""
        self._emit(LogLevel.ERROR, message, kwargs)

    def warning(self, message: str, **kwargs) -> None:
        """Log a WARNING message; keyword arguments become entry context."""
        self._emit(LogLevel.WARNING, message, kwargs)

    def trace(self, message: str, **kwargs) -> None:
        """Log a TRACE message (numeric level ``TRACE_LEVEL``) with context."""
        self._emit(LogLevel.TRACE, message, kwargs)

    def debug(self, message: str, **kwargs) -> None:
        """Log a DEBUG message; keyword arguments become entry context."""
        self._emit(LogLevel.DEBUG, message, kwargs)

    def critical(self, message: str, **kwargs) -> None:
        """Log a CRITICAL message; keyword arguments become entry context."""
        self._emit(LogLevel.CRITICAL, message, kwargs)


def _tool_call_context(
    tool_name: str, args: tuple, kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Build the log context shared by every record of one tool invocation."""
    context: Dict[str, Any] = {
        "tool_name": tool_name,
        "args_count": len(args),
        "kwargs_keys": list(kwargs.keys()),
    }
    # Surface common tool parameters, but only when passed by keyword;
    # positional arguments are not inspected.
    for key in ("site_id", "language", "limit"):
        if key in kwargs:
            context[key] = kwargs[key]
    return context


def _record_tool_success(tool_name: str, context: Dict[str, Any], start: float) -> None:
    """Log a successful invocation and bump the ``tool.<name>.executions`` counter."""
    duration_ms = (time.perf_counter() - start) * 1000
    global_logger.log_structured(
        StructuredLogEntry(
            level=LogLevel.INFO,
            message=f"Tool executed successfully: {tool_name}",
            context=context,
            tool_name=tool_name,
            duration_ms=duration_ms,
            tags={"status": "success"},
        )
    )
    global_metrics.record_metric(
        PerformanceMetric(
            name=f"tool.{tool_name}.executions",
            value=1,
            metric_type=MetricType.COUNTER,
            tags={"status": "success"},
        )
    )


def _record_tool_failure(
    tool_name: str, context: Dict[str, Any], start: float, exc: Exception
) -> None:
    """Log a failed invocation with error details and bump ``tool.<name>.errors``."""
    duration_ms = (time.perf_counter() - start) * 1000
    error_type = type(exc).__name__
    global_logger.log_structured(
        StructuredLogEntry(
            level=LogLevel.ERROR,
            message=f"Tool execution failed: {tool_name}",
            context={**context, "error": str(exc), "error_type": error_type},
            tool_name=tool_name,
            duration_ms=duration_ms,
            tags={"status": "error"},
        )
    )
    global_metrics.record_metric(
        PerformanceMetric(
            name=f"tool.{tool_name}.errors",
            value=1,
            metric_type=MetricType.COUNTER,
            tags={"status": "error", "error_type": error_type},
        )
    )


def performance_monitor(tool_name: str):
    """
    Decorator factory that times a tool call and reports its outcome.

    On success an INFO structured entry (with ``duration_ms``) is logged and
    the ``tool.<tool_name>.executions`` counter is incremented. On an
    ``Exception`` an ERROR entry with the error text/type is logged, the
    ``tool.<tool_name>.errors`` counter is incremented, and the exception is
    re-raised unchanged.

    Both plain functions and ``async def`` coroutine functions are supported.
    For the latter, the wrapper is itself a coroutine function, so timing
    covers the awaited execution rather than just coroutine creation.
    Durations use the monotonic ``time.perf_counter`` clock.

    Args:
        tool_name: Name used in log messages, context and metric names.
    """

    def decorator(func):
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                context = _tool_call_context(tool_name, args, kwargs)
                start = time.perf_counter()
                try:
                    result = await func(*args, **kwargs)
                except Exception as exc:
                    _record_tool_failure(tool_name, context, start, exc)
                    raise
                # Outside the try: a failure while logging success must not be
                # reported as a tool failure.
                _record_tool_success(tool_name, context, start)
                return result

            return async_wrapper

        @wraps(func)
        def wrapper(*args, **kwargs):
            context = _tool_call_context(tool_name, args, kwargs)
            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as exc:
                _record_tool_failure(tool_name, context, start, exc)
                raise
            # Outside the try: a failure while logging success must not be
            # reported as a tool failure.
            _record_tool_success(tool_name, context, start)
            return result

        return wrapper

    return decorator


# Process-wide singletons: performance_monitor and the module-level helpers
# report through this one collector/logger pair.
global_metrics = MetricsCollector()
global_logger = StructuredLogger(
    "hotel_mcp.monitoring",
    metrics_collector=global_metrics,
)


def get_monitoring_summary() -> Dict[str, Any]:
    """
    Return the global metrics snapshot alongside basic monitor metadata.

    Returns:
        Dict with ``metrics`` (from ``global_metrics.get_summary()``) and
        ``system_info`` (current timestamp, logger name, history size).
    """
    metrics_snapshot = global_metrics.get_summary()
    info = dict(
        timestamp=datetime.now().isoformat(),
        logger_name=global_logger.logger.name,
        metrics_history_size=global_metrics.max_history,
    )
    return dict(metrics=metrics_snapshot, system_info=info)


def log_system_event(event_type: str, message: str, **context) -> None:
    """
    Log an INFO-level system event through the global structured logger.

    *event_type* is stored as the first context key, followed by any extra
    keyword arguments.
    """
    event_context = {"event_type": event_type}
    event_context.update(context)
    global_logger.info(message, **event_context)


def record_business_metric(name: str, value: Union[int, float], **tags) -> None:
    """
    Record a business-level value as a GAUGE named ``business.<name>``.

    Keyword arguments become the metric's tags.
    """
    # NOTE(review): no unit is passed, so PerformanceMetric's default ("ms")
    # applies, which is likely wrong for business values. Confirm intent.
    gauge = PerformanceMetric(
        name=f"business.{name}",
        value=value,
        metric_type=MetricType.GAUGE,
        tags=tags,
    )
    global_metrics.record_metric(gauge)
