#!/usr/bin/env python3
"""
MCP Server for Image/Video Understanding using Doubao Vision Model

This server provides tools for analyzing images and videos using the Doubao vision model.
The API key is read from the DOUBAO_API_KEY environment variable (supplied by
the MCP client's launch configuration); it is never hardcoded in the server.

This is a simplified implementation that works with Python 3.9+
"""

import asyncio
import json
import sys
from typing import Any, Dict, List, Optional, Union
import base64
import os

import httpx


class SimpleMCPServer:
    """Minimal MCP server: JSON-RPC 2.0 over newline-delimited JSON on stdio.

    Handlers are registered with the ``tool``, ``resource`` and ``prompt``
    decorators and kept in plain dicts.  ``handle_request`` answers
    ``initialize``, ``ping``, ``tools/list``, ``tools/call``,
    ``resources/list`` and ``prompts/list``; every other method (including
    ``resources/read`` and ``prompts/get``) gets a "Method not found" error.
    ``run_stdio`` drives the read / dispatch / write loop.
    """

    # Pre-defined JSON-RPC 2.0 error codes.
    _PARSE_ERROR = -32700
    _INVALID_REQUEST = -32600
    _METHOD_NOT_FOUND = -32601
    _INVALID_PARAMS = -32602
    _INTERNAL_ERROR = -32603

    def __init__(self, name: str):
        self.name = name
        # Registries filled by the decorators below, keyed by tool name,
        # resource URI template and prompt name respectively.
        self.tools = {}
        self.resources = {}
        self.prompts = {}

    def tool(self, name: Optional[str] = None):
        """Return a decorator that registers a callable as a tool.

        Args:
            name: Tool name; defaults to the function's ``__name__``.

        The function's docstring becomes the tool description.  The
        decorated function is returned unchanged.
        """
        def decorator(func):
            tool_name = name or func.__name__
            self.tools[tool_name] = {
                'name': tool_name,
                'description': func.__doc__ or '',
                'function': func
            }
            return func
        return decorator

    def resource(self, uri_template: str):
        """Return a decorator that registers a callable under a URI template.

        The function's docstring becomes the resource description.
        """
        def decorator(func):
            self.resources[uri_template] = {
                'uriTemplate': uri_template,
                'description': func.__doc__ or '',
                'function': func
            }
            return func
        return decorator

    def prompt(self, name: Optional[str] = None):
        """Return a decorator that registers a callable as a prompt.

        Args:
            name: Prompt name; defaults to the function's ``__name__``.
        """
        def decorator(func):
            prompt_name = name or func.__name__
            self.prompts[prompt_name] = {
                'name': prompt_name,
                'description': func.__doc__ or '',
                'function': func
            }
            return func
        return decorator

    @staticmethod
    def _result(request_id: Any, result: Dict[str, Any]) -> Dict[str, Any]:
        """Build a JSON-RPC success response."""
        return {'jsonrpc': '2.0', 'id': request_id, 'result': result}

    @staticmethod
    def _error(request_id: Any, code: int, message: str) -> Dict[str, Any]:
        """Build a JSON-RPC error response."""
        return {
            'jsonrpc': '2.0',
            'id': request_id,
            'error': {'code': code, 'message': message}
        }

    @staticmethod
    def _builtin_tool_schemas() -> Dict[str, Dict[str, Any]]:
        """Return the input schemas for the tools this server knows by name.

        A fresh dict is built on every call so callers may mutate the result.
        """
        return {
            'analyze_image': {
                'type': 'object',
                'properties': {
                    'image_input': {
                        'type': 'string',
                        'description': 'URL of the image or local file path to analyze'
                    },
                    'custom_prompt': {
                        'type': 'string',
                        'description': 'Custom prompt for analysis (optional)'
                    },
                    'model': {
                        'type': 'string',
                        'description': 'Model name to use',
                        'default': 'doubao-seed-1-6-flash-250615'
                    }
                },
                'required': ['image_input']
            },
            'analyze_video': {
                'type': 'object',
                'properties': {
                    'video_url': {
                        'type': 'string',
                        'description': 'URL of the video to analyze'
                    },
                    'custom_prompt': {
                        'type': 'string',
                        'description': 'Custom prompt for analysis (optional)'
                    },
                    'model': {
                        'type': 'string',
                        'description': 'Model name to use',
                        'default': 'doubao-seed-1-6-flash-250615'
                    }
                },
                'required': ['video_url']
            }
        }

    async def _call_tool(self, request_id: Any, params: Any) -> Dict[str, Any]:
        """Validate a ``tools/call`` request, run the tool and wrap its result."""
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return self._error(
                request_id, self._INVALID_PARAMS,
                'Invalid params: expected an object'
            )

        tool_name = params.get('name')
        # The isinstance check also keeps unhashable names away from the dict.
        if not isinstance(tool_name, str) or tool_name not in self.tools:
            return self._error(
                request_id, self._INVALID_PARAMS, f'Unknown tool: {tool_name}'
            )

        arguments = params.get('arguments')
        if arguments is None:
            arguments = {}
        if not isinstance(arguments, dict):
            return self._error(
                request_id, self._INVALID_PARAMS,
                "Invalid params: 'arguments' must be an object"
            )

        try:
            outcome = self.tools[tool_name]['function'](**arguments)
            # Accept both coroutine functions and plain callables.
            if hasattr(outcome, '__await__'):
                outcome = await outcome
            if isinstance(outcome, dict):
                # Keep non-ASCII text readable for the client; the outer
                # response is still ASCII-escaped when it is written.
                text = json.dumps(outcome, ensure_ascii=False)
            else:
                text = str(outcome)
        except Exception as e:
            # Boundary handler: a failing tool must not take the server down.
            return self._error(
                request_id, self._INTERNAL_ERROR,
                f'Tool execution failed: {str(e)}'
            )

        return self._result(
            request_id, {'content': [{'type': 'text', 'text': text}]}
        )

    async def handle_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch one decoded JSON-RPC request and return the response dict.

        Args:
            request: The decoded request object.

        Returns:
            A JSON-RPC response dict carrying either ``result`` or ``error``.
            A non-object request yields "Invalid Request"; an unsupported
            method yields "Method not found".
        """
        if not isinstance(request, dict):
            return self._error(
                None, self._INVALID_REQUEST,
                'Invalid Request: expected a JSON object'
            )

        request_id = request.get('id')
        method = request.get('method')

        if method == 'initialize':
            return self._result(request_id, {
                'protocolVersion': '2024-11-05',
                'capabilities': {
                    'tools': {},
                    'resources': {},
                    'prompts': {}
                },
                'serverInfo': {
                    'name': self.name,
                    'version': '1.0.0'
                }
            })

        if method == 'ping':
            return self._result(request_id, {})

        if method == 'tools/list':
            schemas = self._builtin_tool_schemas()
            return self._result(request_id, {
                'tools': [
                    {
                        'name': tool['name'],
                        'description': tool['description'],
                        # Tools without a known schema accept any object.
                        'inputSchema': schemas.get(tool['name'], {
                            'type': 'object',
                            'properties': {},
                            'required': []
                        })
                    }
                    for tool in self.tools.values()
                ]
            })

        if method == 'tools/call':
            return await self._call_tool(request_id, request.get('params'))

        if method == 'resources/list':
            return self._result(request_id, {
                'resources': [
                    {
                        'uri': resource['uriTemplate'],
                        'name': resource['uriTemplate'],
                        'description': resource['description']
                    }
                    for resource in self.resources.values()
                ]
            })

        if method == 'prompts/list':
            return self._result(request_id, {
                'prompts': [
                    {
                        'name': prompt['name'],
                        'description': prompt['description']
                    }
                    for prompt in self.prompts.values()
                ]
            })

        return self._error(
            request_id, self._METHOD_NOT_FOUND, f'Method not found: {method}'
        )

    async def run_stdio(self):
        """Serve requests from stdin until EOF, one JSON document per line.

        Blank lines are skipped.  Undecodable or malformed lines are answered
        with a -32700 parse error.  Messages without an ``id`` member are
        notifications: they are dispatched but never answered.
        """
        loop = asyncio.get_running_loop()

        def send(message: Dict[str, Any]) -> None:
            print(json.dumps(message), flush=True)

        while True:
            try:
                # readline blocks, so it runs in the default thread pool.
                line = await loop.run_in_executor(None, sys.stdin.readline)
            except UnicodeDecodeError as e:
                send(self._error(None, self._PARSE_ERROR, f'Parse error: {str(e)}'))
                continue
            if not line:
                break  # EOF: the client closed the pipe.

            payload = line.strip()
            if not payload:
                continue

            try:
                request = json.loads(payload)
            except json.JSONDecodeError as e:
                send(self._error(None, self._PARSE_ERROR, f'Parse error: {str(e)}'))
                continue

            is_notification = isinstance(request, dict) and 'id' not in request
            try:
                response = await self.handle_request(request)
            except Exception as e:
                # Last-resort guard so one bad request cannot end the loop.
                request_id = request.get('id') if isinstance(request, dict) else None
                response = self._error(
                    request_id, self._INTERNAL_ERROR, f'Internal error: {str(e)}'
                )

            if not is_notification:
                send(response)


# Module-level server instance. The @mcp.tool / @mcp.resource / @mcp.prompt
# decorators below register their handlers on it, and the __main__ guard at
# the bottom of the file runs it over stdio.
mcp = SimpleMCPServer("Vision Understanding Server")


def encode_image_to_base64(image_path: str) -> Optional[str]:
    """
    Encode a local image file as a base64 ``data:`` URL.

    The MIME type is chosen from the file extension (case-insensitive).
    Extensions outside the built-in table are looked up with ``mimetypes``
    and accepted only when they map to an ``image/*`` type; anything still
    unknown is labelled ``image/jpeg``.

    Args:
        image_path: Path to the local image file

    Returns:
        ``data:<mime>;base64,<payload>`` string, or None when the path is not
        a readable regular file
    """
    # Local import: stdlib only, keeps this block free of file-level changes.
    import mimetypes

    known_types = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.bmp': 'image/bmp',
        '.webp': 'image/webp'
    }

    try:
        # isfile rejects directories up front instead of relying on open()
        # failing for them.
        if not os.path.isfile(image_path):
            return None

        with open(image_path, "rb") as image_file:
            payload = base64.b64encode(image_file.read()).decode('ascii')

        _, ext = os.path.splitext(image_path.lower())
        mime_type = known_types.get(ext)
        if mime_type is None:
            # Guess from a synthetic name so that only the extension matters.
            guessed, _ = mimetypes.guess_type("file" + ext)
            if guessed and guessed.startswith('image/'):
                mime_type = guessed
            else:
                mime_type = 'image/jpeg'
    except (OSError, TypeError, ValueError):
        # Unreadable file or unusable path value: report failure as None.
        return None

    return f"data:{mime_type};base64,{payload}"


@mcp.tool()
async def analyze_image(
    image_input: str,
    custom_prompt: Optional[str] = None,
    model: str = "doubao-seed-1-6-flash-250615"
) -> Dict[str, Any]:
    """
    Analyze an image using Doubao vision model.

    Failures are reported through the returned dictionary (``success`` False
    and ``error_message`` set) rather than by raising.

    Args:
        image_input: URL of the image or local file path to analyze
        custom_prompt: Custom prompt for analysis (optional)
        model: Model name to use (default: doubao-seed-1-6-flash-250615)

    Returns:
        Dictionary with keys ``description``, ``model_used``, ``success``,
        ``error_message``, ``image_input`` and ``input_type`` ("url", "file"
        or "unknown" when the input was never classified)
    """
    def build_result(input_type: str, description: str = "",
                     error_message: Optional[str] = None) -> Dict[str, Any]:
        # Single place that shapes the result so every path has the same keys.
        return {
            "description": description,
            "model_used": model,
            "success": error_message is None,
            "error_message": error_message,
            "image_input": image_input,
            "input_type": input_type
        }

    # Get API key from environment variable
    api_key = os.getenv('DOUBAO_API_KEY')
    if not api_key:
        return build_result(
            "unknown",
            error_message="DOUBAO_API_KEY environment variable not set"
        )

    if not isinstance(image_input, str) or not image_input.strip():
        return build_result(
            "unknown",
            error_message="image_input must be a non-empty URL or file path"
        )

    # Default prompt if none provided
    if custom_prompt is None:
        custom_prompt = "请详细描述这张图片的内容，包括主要对象、场景、颜色、构图等视觉元素。"

    # URL schemes are case-insensitive; anything else is treated as a path.
    if image_input.lower().startswith(('http://', 'https://')):
        input_type = "url"
        image_url = image_input
    else:
        input_type = "file"
        image_url = encode_image_to_base64(image_input)
        if not image_url:
            return build_result(
                input_type,
                error_message="Failed to encode local image file or file not found"
            )

    # Construct the request payload
    request_data = {
        "model": model,
        "messages": [
            {
                "content": [
                    {
                        "text": custom_prompt,
                        "type": "text"
                    },
                    {
                        "image_url": {
                            "url": image_url
                        },
                        "type": "image_url"
                    }
                ],
                "role": "user"
            }
        ]
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
                json=request_data,
                headers=headers
            )

        if response.status_code != 200:
            return build_result(
                input_type,
                error_message=(
                    f"API request failed with status {response.status_code}: "
                    f"{response.text}"
                )
            )

        try:
            description = response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # A 200 reply that is not JSON or lacks the expected fields.
            return build_result(
                input_type,
                error_message=f"Unexpected API response format: {e!r}"
            )

        return build_result(input_type, description=description)

    except Exception as e:
        # Tool boundary: the contract is to report failures, not raise. The
        # exception type is included because some exceptions (timeouts, for
        # example) can have an empty message.
        return build_result(
            input_type,
            error_message=f"Error during API call: {type(e).__name__}: {e}"
        )


@mcp.tool()
async def analyze_video(
    video_url: str,
    custom_prompt: Optional[str] = None,
    model: str = "doubao-seed-1-6-flash-250615"
) -> Dict[str, Any]:
    """
    Analyze a video using Doubao vision model.

    Failures are reported through the returned dictionary (``success`` False
    and ``error_message`` set) rather than by raising.

    Args:
        video_url: URL of the video to analyze
        custom_prompt: Custom prompt for analysis (optional)
        model: Model name to use (default: doubao-seed-1-6-flash-250615)

    Returns:
        Dictionary with keys ``description``, ``model_used``, ``success`` and
        ``error_message``
    """
    def build_result(description: str = "",
                     error_message: Optional[str] = None) -> Dict[str, Any]:
        # Single place that shapes the result so every path has the same keys.
        return {
            "description": description,
            "model_used": model,
            "success": error_message is None,
            "error_message": error_message
        }

    # Get API key from environment variable
    api_key = os.getenv('DOUBAO_API_KEY')
    if not api_key:
        return build_result(
            error_message="DOUBAO_API_KEY environment variable not set"
        )

    if not isinstance(video_url, str) or not video_url.strip():
        return build_result(error_message="video_url must be a non-empty URL")

    # Default prompt for video analysis
    if custom_prompt is None:
        custom_prompt = "请详细描述这个视频的内容，包括主要场景、人物动作、情节发展、视觉效果等元素。"

    # The video is passed as a "video_url" content part next to the text prompt.
    request_data = {
        "model": model,
        "messages": [
            {
                "content": [
                    {
                        "text": custom_prompt,
                        "type": "text"
                    },
                    {
                        "video_url": {
                            "url": video_url,
                            "fps": 1.0  # Default fps for video analysis
                        },
                        "type": "video_url"
                    }
                ],
                "role": "user"
            }
        ]
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    try:
        async with httpx.AsyncClient(timeout=60.0) as client:  # Longer timeout for video
            response = await client.post(
                "https://ark.cn-beijing.volces.com/api/v3/chat/completions",
                json=request_data,
                headers=headers
            )

        if response.status_code != 200:
            return build_result(
                error_message=(
                    f"API request failed with status {response.status_code}: "
                    f"{response.text}"
                )
            )

        try:
            description = response.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # A 200 reply that is not JSON or lacks the expected fields.
            return build_result(
                error_message=f"Unexpected API response format: {e!r}"
            )

        return build_result(description=description)

    except Exception as e:
        # Tool boundary: the contract is to report failures, not raise. The
        # exception type is included because some exceptions (timeouts, for
        # example) can have an empty message.
        return build_result(
            error_message=f"Error during API call: {type(e).__name__}: {e}"
        )


@mcp.resource("vision://prompts/{prompt_type}")
def get_vision_prompt(prompt_type: str) -> str:
    """
    Look up one of the predefined vision-analysis prompts.

    Args:
        prompt_type: One of general, detailed, artistic, technical, objects,
            scene or emotion

    Returns:
        The matching prompt text; the "general" prompt for any other value
    """
    catalog = dict(
        general="请描述这张图片/视频的主要内容。",
        detailed="请详细描述这张图片/视频的内容，包括主要对象、场景、颜色、构图、情感表达等视觉元素。",
        artistic="请从艺术角度分析这张图片/视频，包括构图、色彩搭配、光影效果、艺术风格等。",
        technical="请从技术角度分析这张图片/视频，包括拍摄技巧、画质、分辨率、技术特点等。",
        objects="请识别并列出图片/视频中的所有主要对象和元素。",
        scene="请描述图片/视频中的场景和环境。",
        emotion="请分析图片/视频传达的情感和氛围。",
    )

    # Unknown types fall back to the general-purpose prompt.
    if prompt_type in catalog:
        return catalog[prompt_type]
    return catalog["general"]


@mcp.prompt()
def create_vision_analysis_prompt(
    analysis_type: str = "general",
    focus_areas: str = "all"
) -> str:
    """
    Assemble a vision-analysis prompt from an analysis type and a focus area.

    Args:
        analysis_type: general, detailed, artistic or technical; any other
            value is treated as general
        focus_areas: objects, scene, emotion, composition or all; any other
            value is treated as all

    Returns:
        Opening sentence for the analysis type, followed by the focus clause
        and a fixed closing request for an accurate, detailed description
    """
    openings = {
        "general": "请描述这张图片/视频的主要内容",
        "detailed": "请详细分析这张图片/视频",
        "artistic": "请从艺术角度分析这张图片/视频",
        "technical": "请从技术角度分析这张图片/视频"
    }
    emphases = {
        "objects": "，重点关注其中的对象和元素",
        "scene": "，重点关注场景和环境",
        "emotion": "，重点关注情感表达和氛围",
        "composition": "，重点关注构图和视觉效果",
        "all": "，包括对象、场景、情感、构图等各个方面"
    }

    # Unrecognised selectors fall back to the defaults of each table.
    if analysis_type in openings:
        opening = openings[analysis_type]
    else:
        opening = openings["general"]

    if focus_areas in emphases:
        emphasis = emphases[focus_areas]
    else:
        emphasis = emphases["all"]

    closing = "。请提供准确、详细的描述。"
    return "".join((opening, emphasis, closing))


if __name__ == "__main__":
    # "stdio" is the only supported transport; any other invocation prints
    # the usage text below.
    cli_args = sys.argv[1:]
    if cli_args and cli_args[0] == "stdio":
        asyncio.run(mcp.run_stdio())
    else:
        usage_lines = (
            "Usage: python3 vision_mcp_server.py stdio",
            "This MCP server provides vision understanding capabilities using Doubao model.",
            "\nAvailable tools:",
            "- analyze_image: Analyze image content",
            "- analyze_video: Analyze video content",
            "\nAvailable resources:",
            "- vision://prompts/{prompt_type}: Get predefined prompts",
            "\nAvailable prompts:",
            "- create_vision_analysis_prompt: Create customized analysis prompts",
            "\nNote: This server requires httpx. Install with: pip3 install httpx",
        )
        for usage_line in usage_lines:
            print(usage_line)