"""
Intelligent Caching System for Hotel MCP Server.

This module provides a DRY caching layer that improves performance by caching
frequently accessed data with intelligent invalidation strategies.
"""

import hashlib
import inspect
import json
import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, Optional, Union

logger = logging.getLogger(__name__)


class CacheStrategy(Enum):
    """Eviction policy an IntelligentCache applies when it exceeds max_size."""

    TTL = "ttl"  # evict the entry that was created earliest
    LRU = "lru"  # evict the entry read/written longest ago
    LFU = "lfu"  # evict the entry with the fewest recorded accesses
    ADAPTIVE = "adaptive"  # evict an already-expired entry if any, else LRU


@dataclass
class CacheEntry:
    """One cached value together with the bookkeeping used for expiry and eviction."""

    key: str
    value: Any
    # Both timestamps are naive local times from datetime.now().
    created_at: datetime = field(default_factory=datetime.now)
    last_accessed: datetime = field(default_factory=datetime.now)
    access_count: int = 0  # bumped by touch()
    ttl_seconds: Optional[int] = None  # None -> the entry never expires
    tags: Dict[str, str] = field(default_factory=dict)

    @property
    def is_expired(self) -> bool:
        """True once strictly more than ``ttl_seconds`` have passed since creation."""
        if self.ttl_seconds is not None:
            deadline = self.created_at + timedelta(seconds=self.ttl_seconds)
            return deadline < datetime.now()
        return False

    @property
    def age_seconds(self) -> float:
        """Seconds elapsed since ``created_at``."""
        elapsed = datetime.now() - self.created_at
        return elapsed.total_seconds()

    def touch(self) -> None:
        """Record a read: increment ``access_count`` and refresh ``last_accessed``."""
        self.access_count += 1
        self.last_accessed = datetime.now()


class IntelligentCache:
    """
    Intelligent cache with multiple eviction strategies and performance monitoring.

    Designed to be DRY and reusable across different data types and access patterns.

    Storage is an ``OrderedDict`` kept in recency order: a successful ``get`` and
    every ``set`` move the key to the end, so the first key is always the least
    recently used one. Each ``set`` enforces ``max_size`` by evicting entries one
    at a time according to ``strategy``. Hit/miss/eviction counters live in
    ``_stats`` and are reported by ``get_stats``.

    NOTE(review): no locking is performed; guard externally if an instance is
    shared between threads.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: int = 300,  # 5 minutes
        strategy: CacheStrategy = CacheStrategy.ADAPTIVE,
    ):
        """Initialize intelligent cache.

        Args:
            max_size: Maximum number of entries kept; ``set`` evicts beyond it.
            default_ttl: TTL in seconds used when ``set`` is called with ``ttl=None``.
            strategy: Eviction policy applied by ``_evict_one``.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.strategy = strategy

        # Insertion order doubles as recency order (see move_to_end in get/set).
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._stats = {"hits": 0, "misses": 0, "evictions": 0, "expired_removals": 0}

        logger.info(
            "Intelligent cache initialized: max_size=%s, strategy=%s",
            max_size,
            strategy.value,
        )

    def _generate_key(self, *args, **kwargs) -> str:
        """Return a deterministic hex digest identifying the given call arguments.

        Arguments are JSON-serialized with ``default=str``; values without a JSON
        form are therefore keyed by their ``str()``. NOTE: objects whose ``str()``
        is the default ``<... at 0x...>`` form produce a different key per
        instance. MD5 is used purely as a compact fingerprint, not for security.
        """
        key_data = {"args": args, "kwargs": sorted(kwargs.items())}
        key_string = json.dumps(key_data, sort_keys=True, default=str)
        return hashlib.md5(key_string.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the value cached under ``key``, or ``None`` on a miss.

        An expired entry is deleted and counted both as a miss and as an expired
        removal. On a hit the entry's access metadata is updated and the key
        becomes the most recently used. A stored ``None`` value is
        indistinguishable from a miss for callers.
        """
        entry = self._cache.get(key)
        if entry is None:
            self._stats["misses"] += 1
            return None

        if entry.is_expired:
            self.delete(key)
            self._stats["misses"] += 1
            self._stats["expired_removals"] += 1
            return None

        entry.touch()
        # Most recently used entries live at the end of the OrderedDict.
        self._cache.move_to_end(key)

        self._stats["hits"] += 1
        return entry.value

    def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[int] = None,
        tags: Optional[Dict[str, str]] = None,
    ) -> None:
        """Store ``value`` under ``key``, replacing any existing entry.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Lifetime in seconds; ``None`` means use ``self.default_ttl``.
            tags: Optional labels attached to the entry (copied, so later
                mutation of the caller's dict does not affect the entry).
        """
        effective_ttl = ttl if ttl is not None else self.default_ttl

        self._cache[key] = CacheEntry(
            key=key, value=value, ttl_seconds=effective_ttl, tags=dict(tags or {})
        )
        # Re-assigning an existing key keeps its old position; force it to the end.
        self._cache.move_to_end(key)

        self._evict_if_needed()

    def delete(self, key: str) -> bool:
        """Remove ``key`` from the cache; return True if it was present."""
        # Stored values are always CacheEntry objects, never None.
        return self._cache.pop(key, None) is not None

    def clear(self) -> None:
        """Remove every entry. Statistics counters are not reset."""
        self._cache.clear()
        logger.info("Cache cleared")

    def _evict_if_needed(self) -> None:
        """Evict entries until at most ``max_size`` remain."""
        # The emptiness check prevents an endless loop when max_size < 0,
        # since _evict_one does nothing on an empty cache.
        while self._cache and len(self._cache) > self.max_size:
            self._evict_one()

    def _evict_one(self) -> None:
        """Remove one entry chosen by ``self.strategy`` and count the eviction.

        Raises:
            ValueError: if ``self.strategy`` is not a known CacheStrategy.
        """
        if not self._cache:
            return

        if self.strategy == CacheStrategy.LRU:
            # Front of the OrderedDict is the least recently used key.
            victim = next(iter(self._cache))

        elif self.strategy == CacheStrategy.LFU:
            # Ties resolve to the earliest key in recency order.
            victim = min(self._cache, key=lambda k: self._cache[k].access_count)

        elif self.strategy == CacheStrategy.TTL:
            # Entry created earliest.
            victim = min(self._cache, key=lambda k: self._cache[k].created_at)

        elif self.strategy == CacheStrategy.ADAPTIVE:
            # First expired entry in recency order; otherwise fall back to LRU.
            victim = next(
                (k for k, entry in self._cache.items() if entry.is_expired),
                next(iter(self._cache)),
            )

        else:
            # Without this, _evict_if_needed would spin forever removing nothing.
            raise ValueError(f"Unsupported cache strategy: {self.strategy!r}")

        del self._cache[victim]
        self._stats["evictions"] += 1

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of size, hit rate, counters and oldest entry age."""
        total_requests = self._stats["hits"] + self._stats["misses"]
        hit_rate = self._stats["hits"] / total_requests if total_requests > 0 else 0

        return {
            "size": len(self._cache),
            "max_size": self.max_size,
            "hit_rate": hit_rate,
            "stats": self._stats.copy(),
            "strategy": self.strategy.value,
            # The oldest entry is the one with the largest age.
            "oldest_entry_age": max(
                (entry.age_seconds for entry in self._cache.values()), default=0
            ),
        }

    def cleanup_expired(self) -> int:
        """Delete every expired entry and return how many were removed."""
        expired_keys = [k for k, v in self._cache.items() if v.is_expired]
        for key in expired_keys:
            del self._cache[key]

        self._stats["expired_removals"] += len(expired_keys)
        return len(expired_keys)


def cached(
    ttl: int = 300,
    key_prefix: str = "",
    cache_instance: Optional[IntelligentCache] = None,
):
    """
    DRY decorator for caching function results.

    Results are stored under a key built from ``key_prefix``, the function's
    module and qualified name, and a hash of the call arguments. Same-named
    functions in different modules therefore do not share entries.

    Args:
        ttl: Lifetime in seconds of each stored result.
        key_prefix: Optional string prepended to this function's keys.
        cache_instance: Cache to use; the module-level ``global_cache`` when None.

    The wrapped function gains:
        ``cache_clear()``: remove only this function's entries from the cache.
        ``cache_stats()``: statistics of the whole underlying cache, which is
        shared with every other user of that cache instance.

    A ``None`` result is not stored. ``IntelligentCache.get`` returns ``None``
    for a miss, so such an entry could never be served and would only take
    space.
    """

    def decorator(func: Callable) -> Callable:
        # Resolved at decoration time, so global_cache must already exist then.
        cache = cache_instance if cache_instance is not None else global_cache
        # Module + qualname keeps same-named functions from different modules apart.
        func_key = f"{key_prefix}{func.__module__}.{func.__qualname__}"
        key_namespace = f"{func_key}:"

        @wraps(func)
        def wrapper(*args, **kwargs):
            cache_key = key_namespace + cache._generate_key(*args, **kwargs)

            cached_result = cache.get(cache_key)
            if cached_result is not None:
                logger.debug("Cache hit for %s", func.__name__)
                return cached_result

            logger.debug("Cache miss for %s, executing function", func.__name__)
            result = func(*args, **kwargs)

            if result is not None:
                cache.set(
                    cache_key,
                    result,
                    ttl=ttl,
                    tags={"function": func.__name__, "module": func.__module__},
                )

            return result

        def cache_clear() -> None:
            """Remove every entry this decorated function has stored."""
            # Reads the cache's private dict; collect keys first, then delete.
            own_keys = [k for k in cache._cache if k.startswith(key_namespace)]
            for k in own_keys:
                cache.delete(k)

        wrapper.cache_clear = cache_clear
        wrapper.cache_stats = lambda: cache.get_stats()

        return wrapper

    return decorator


def cache_by_site_and_language(ttl: int = 300):
    """
    Specialized DRY decorator for caching by site_id and language.

    Common pattern for hotel MCP tools. Results go into the module-level
    ``global_cache``, tagged with ``function``, ``site_id`` and ``language``, so
    they can be dropped via ``invalidate_cache_by_tags``. ``site_id`` and
    ``language`` are taken from the call whether they were passed by keyword or
    positionally. When they were not passed, they are recorded as ``"default"``
    and ``"es"``.

    A ``None`` result is not stored: ``global_cache.get`` returns ``None`` for a
    miss, so such an entry could never be served.

    Args:
        ttl: Lifetime in seconds of each stored result.
    """

    def decorator(func: Callable) -> Callable:
        # Module + qualname keeps same-named functions from different modules apart.
        func_key = f"{func.__module__}.{func.__qualname__}"
        try:
            signature: Optional[inspect.Signature] = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature; use kwargs only.
            signature = None

        def explicit_arguments(args: tuple, kwargs: dict) -> Dict[str, Any]:
            """Map parameter names to the values explicitly passed in this call."""
            if signature is None:
                return kwargs
            try:
                bound = signature.bind(*args, **kwargs)
            except TypeError:
                # The call does not match the signature; func() will raise itself.
                return kwargs
            # kwargs underneath so a **kwargs catch-all in func still exposes
            # site_id/language; defaults are deliberately NOT applied.
            return {**kwargs, **bound.arguments}

        @wraps(func)
        def wrapper(*args, **kwargs):
            passed = explicit_arguments(args, kwargs)
            site_id = passed.get("site_id", "default")
            language = passed.get("language", "es")

            cache_key = (
                f"{func_key}:site_{site_id}:lang_{language}:"
                f"{global_cache._generate_key(*args, **kwargs)}"
            )

            cached_result = global_cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            result = func(*args, **kwargs)
            if result is not None:
                global_cache.set(
                    cache_key,
                    result,
                    ttl=ttl,
                    tags={
                        "function": func.__name__,
                        "site_id": site_id,
                        "language": language,
                    },
                )

            return result

        return wrapper

    return decorator


# Module-wide cache shared by ``cached`` (when no ``cache_instance`` is given),
# ``cache_by_site_and_language``, ``get_cache_summary`` and
# ``invalidate_cache_by_tags``.
global_cache = IntelligentCache(
    max_size=1000,
    default_ttl=300,  # seconds, i.e. 5 minutes
    strategy=CacheStrategy.ADAPTIVE,
)


def get_cache_summary() -> Dict[str, Any]:
    """Get comprehensive cache summary for monitoring.

    Side effect: expired entries are purged from ``global_cache`` first, so the
    reported statistics (size, expired_removals) describe the cache after that
    cleanup rather than a state containing already-dead entries.

    Returns:
        Dict with ``cache_stats`` (``global_cache.get_stats()``),
        ``expired_cleaned`` (number of expired entries just removed) and
        ``timestamp`` (local time, ISO 8601).
    """
    expired_cleaned = global_cache.cleanup_expired()
    return {
        "cache_stats": global_cache.get_stats(),
        "expired_cleaned": expired_cleaned,
        "timestamp": datetime.now().isoformat(),
    }


def invalidate_cache_by_tags(**tags) -> int:
    """Invalidate cache entries by tags.

    Removes every ``global_cache`` entry whose tags contain *all* of the given
    key/value pairs, e.g. ``invalidate_cache_by_tags(site_id=1, language="en")``.
    An entry that lacks a requested tag never matches, even when the requested
    value is ``None``.

    Calling it with no tags is a no-op: an empty filter would otherwise match
    every entry and silently wipe the whole cache. Use ``global_cache.clear()``
    when that is really intended.

    Returns:
        Number of entries removed.
    """
    if not tags:
        logger.warning("invalidate_cache_by_tags called without tags; nothing invalidated")
        return 0

    # Collect first, delete afterwards: mutating the OrderedDict while
    # iterating over it would raise RuntimeError.
    matching_keys = [
        key
        for key, entry in global_cache._cache.items()
        if all(
            tag_key in entry.tags and entry.tags[tag_key] == tag_value
            for tag_key, tag_value in tags.items()
        )
    ]
    invalidated = sum(1 for key in matching_keys if global_cache.delete(key))

    logger.info("Invalidated %d cache entries by tags: %s", invalidated, tags)
    return invalidated
