"""
性能监控和审计日志中间件
自动收集API请求的性能指标和操作日志
"""

import time
import uuid
import logging
from typing import Callable, Dict, Any, Optional
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from app.core.database import get_sync_db
from app.services.performance_monitoring_service import PerformanceMonitoringService
from app.services.audit_logging_service import AuditLoggingService, AuditAction

logger = logging.getLogger(__name__)


class MonitoringMiddleware(BaseHTTPMiddleware):
    """Monitoring middleware: per-request performance metrics plus audit logging.

    For every request whose path is not excluded, the downstream app is timed
    and, depending on the enable flags, a performance metric
    (``PerformanceMonitoringService``) and an audit log entry
    (``AuditLoggingService``) are written. Errors raised while recording are
    logged and never change the response returned to the client.
    """

    def __init__(
        self,
        app,
        enable_performance_monitoring: bool = True,
        enable_audit_logging: bool = True,
        excluded_paths: Optional[list] = None
    ):
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            enable_performance_monitoring: Record a performance metric per request.
            enable_audit_logging: Record an audit log entry per request.
            excluded_paths: Paths that bypass monitoring. Each entry matches the
                path itself and everything beneath it, e.g. ``"/static"``
                matches ``/static`` and ``/static/app.js`` but not
                ``/statistics``. A falsy value (``None`` or ``[]``) selects the
                default list of docs, health, metrics and static paths.
        """
        super().__init__(app)
        self.enable_performance_monitoring = enable_performance_monitoring
        self.enable_audit_logging = enable_audit_logging
        self.excluded_paths = excluded_paths or [
            "/docs", "/redoc", "/openapi.json", "/favicon.ico",
            "/health", "/metrics", "/static"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request and record monitoring data for it.

        A fresh UUID is stored on ``request.state.request_id`` for every
        request, including excluded ones, so downstream handlers can use it.
        If the downstream app raises, the failure is logged, recorded with
        status 500, and the exception is re-raised unchanged.
        """
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        # perf_counter is monotonic: the measured duration cannot be distorted
        # by wall-clock adjustments the way time.time() can.
        start = time.perf_counter()

        endpoint = str(request.url.path)
        if self._should_skip_monitoring(endpoint):
            return await call_next(request)

        method = request.method
        client_ip = self._get_client_ip(request)
        user_agent = request.headers.get("user-agent")
        request_size = await self._get_request_size(request)

        response = None
        error_message = None
        try:
            response = await call_next(request)
        except Exception as e:
            error_message = str(e)
            logger.error("Request %s failed: %s", request_id, e)
            raise
        finally:
            response_time_ms = int((time.perf_counter() - start) * 1000)

            # No response object means the downstream app raised: record 500.
            if response is not None:
                status_code = response.status_code
                response_size = self._get_response_size(response)
            else:
                status_code = 500
                response_size = 0

            # NOTE: recording is awaited before the response is returned and the
            # recording path performs synchronous database I/O, so it adds to
            # request latency and blocks the event loop while it runs.
            try:
                await self._record_monitoring_data(
                    request_id=request_id,
                    endpoint=endpoint,
                    method=method,
                    status_code=status_code,
                    response_time_ms=response_time_ms,
                    request_size=request_size,
                    response_size=response_size,
                    client_ip=client_ip,
                    user_agent=user_agent,
                    error_message=error_message,
                    request=request
                )
            except Exception as e:
                # A monitoring failure must never affect the client response.
                logger.error("Failed to record monitoring data: %s", e)

        return response

    def _should_skip_monitoring(self, endpoint: str) -> bool:
        """Return True if ``endpoint`` equals, or lies beneath, an excluded path.

        Matching is segment-aware. A plain ``startswith`` check would let
        ``"/static"`` also swallow unrelated routes such as ``/statistics``.
        """
        for excluded in self.excluded_paths:
            base = excluded.rstrip("/")
            if endpoint == excluded or endpoint == base or endpoint.startswith(base + "/"):
                return True
        return False

    def _get_client_ip(self, request: Request) -> Optional[str]:
        """Best-effort client IP: X-Forwarded-For, then X-Real-IP, then the peer address.

        SECURITY NOTE: both proxy headers are client-supplied unless a trusted
        reverse proxy overwrites them, so the recorded value can be spoofed.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            # The left-most entry is the originating client.
            return forwarded_for.split(",")[0].strip()

        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip

        client = getattr(request, "client", None)
        return client.host if client else None

    async def _get_request_size(self, request: Request) -> int:
        """Return the request body size from Content-Length, or 0 if absent/invalid.

        The body itself is not read, so a chunked request without the header
        reports 0.
        """
        try:
            content_length = request.headers.get("content-length")
            if content_length:
                return int(content_length)
        except (ValueError, TypeError):
            pass
        return 0

    def _get_response_size(self, response: Response) -> int:
        """Return the response size from Content-Length, or 0 if absent/invalid."""
        try:
            if hasattr(response, 'headers'):
                content_length = response.headers.get("content-length")
                if content_length:
                    return int(content_length)
        except (ValueError, TypeError):
            pass
        return 0

    async def _record_monitoring_data(
        self,
        request_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        response_time_ms: int,
        request_size: int,
        response_size: int,
        client_ip: Optional[str],
        user_agent: Optional[str],
        error_message: Optional[str],
        request: Request
    ):
        """Open a DB session and write whichever records are enabled.

        Errors are logged and swallowed. The session is closed if it was
        successfully obtained.
        """
        db = None
        try:
            db = next(get_sync_db())

            if self.enable_performance_monitoring:
                await self._record_performance_metric(
                    db, request_id, endpoint, method, status_code,
                    response_time_ms, request_size, response_size,
                    client_ip, user_agent, error_message
                )

            if self.enable_audit_logging:
                await self._record_audit_log(
                    db, request_id, endpoint, method, status_code,
                    client_ip, user_agent, error_message, request
                )

        except Exception as e:
            logger.error("Error in monitoring data recording: %s", e)
        finally:
            if db is not None:
                db.close()

    async def _record_performance_metric(
        self,
        db,
        request_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        response_time_ms: int,
        request_size: int,
        response_size: int,
        client_ip: Optional[str],
        user_agent: Optional[str],
        error_message: Optional[str]
    ):
        """Write one performance metric via ``PerformanceMonitoringService``; errors are logged."""
        try:
            performance_service = PerformanceMonitoringService(db)
            performance_service.record_performance_metric(
                endpoint=endpoint,
                method=method,
                status_code=status_code,
                request_id=request_id,
                response_time_ms=response_time_ms,
                request_size_bytes=request_size,
                response_size_bytes=response_size,
                client_ip=client_ip,
                user_agent=user_agent,
                error_message=error_message
            )
        except Exception as e:
            logger.error("Error recording performance metric: %s", e)

    async def _record_audit_log(
        self,
        db,
        request_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        client_ip: Optional[str],
        user_agent: Optional[str],
        error_message: Optional[str],
        request: Request
    ):
        """Write one audit entry via ``AuditLoggingService``; errors are logged.

        User identity is read from ``request.state`` (``user_id``,
        ``user_name``, ``session_id``) and is None when not set, presumably by
        an auth layer that runs before this middleware records.
        """
        try:
            audit_service = AuditLoggingService(db)

            audit_service.log_action(
                action=self._determine_action(method, endpoint),
                resource_type=self._determine_resource_type(endpoint),
                endpoint=endpoint,
                method=method,
                status_code=status_code,
                request_id=request_id,
                user_id=getattr(request.state, 'user_id', None),
                user_name=getattr(request.state, 'user_name', None),
                session_id=getattr(request.state, 'session_id', None),
                client_ip=client_ip,
                user_agent=user_agent,
                business_category=self._determine_business_category(endpoint),
                severity=self._determine_severity(method, status_code, endpoint),
                error_message=error_message
            )

        except Exception as e:
            logger.error("Error recording audit log: %s", e)

    def _determine_action(self, method: str, endpoint: str) -> str:
        """Map HTTP method (and, for POST, path keywords) to an audit action.

        Methods other than GET/POST/PUT/PATCH/DELETE map to the lowercased
        method name.
        """
        path = endpoint.lower()
        if method == "POST":
            # Ordered: the first keyword found in the path wins.
            for keyword, action in (
                ("login", AuditAction.LOGIN),
                ("logout", AuditAction.LOGOUT),
                ("import", AuditAction.IMPORT),
                ("export", AuditAction.EXPORT),
                ("approve", AuditAction.APPROVE),
                ("reject", AuditAction.REJECT),
            ):
                if keyword in path:
                    return action
            return AuditAction.CREATE
        if method == "GET":
            return AuditAction.READ
        if method in ("PUT", "PATCH"):
            return AuditAction.UPDATE
        if method == "DELETE":
            return AuditAction.DELETE
        return method.lower()

    def _determine_resource_type(self, endpoint: str) -> str:
        """Derive the resource type from known path fragments or the path itself.

        Fallback: the segment after a leading ``api`` segment, else the first
        segment, else ``"unknown"`` for an empty path such as ``/``.
        """
        endpoint_lower = endpoint.lower()

        # Ordered: the first fragment found in the path wins.
        for fragment, resource in (
            ("/orders", "order"),
            ("/products", "product"),
            ("/procurement", "procurement"),
            ("/offset", "offset"),
            ("/users", "user"),
            ("/import", "import"),
            ("/export", "export"),
            ("/audit", "audit"),
            ("/monitoring", "monitoring"),
        ):
            if fragment in endpoint_lower:
                return resource

        # Drop empty segments so "/" and "//" yield "unknown" instead of "".
        parts = [segment for segment in endpoint.split("/") if segment]
        if not parts:
            return "unknown"
        if parts[0] == "api" and len(parts) >= 2:
            return parts[1]
        return parts[0]

    def _determine_business_category(self, endpoint: str) -> str:
        """Classify the endpoint into a business category by substring keywords.

        NOTE(review): plain substring matching means e.g. ``/catalog``
        contains "log" and is classified as "system_management"; tighten the
        keywords if that matters.
        """
        endpoint_lower = endpoint.lower()

        if any(keyword in endpoint_lower for keyword in ["login", "logout", "auth", "session"]):
            return "authentication"
        elif any(keyword in endpoint_lower for keyword in ["user", "profile", "permission"]):
            return "user_management"
        elif any(keyword in endpoint_lower for keyword in ["order", "purchase", "procurement"]):
            return "order_management"
        elif any(keyword in endpoint_lower for keyword in ["product", "inventory"]):
            return "product_management"
        elif any(keyword in endpoint_lower for keyword in ["import", "export", "batch"]):
            return "data_management"
        elif any(keyword in endpoint_lower for keyword in ["audit", "log", "monitoring"]):
            return "system_management"
        elif any(keyword in endpoint_lower for keyword in ["offset", "fifo", "refund"]):
            return "financial_management"
        else:
            return "general"

    def _determine_severity(self, method: str, status_code: int, endpoint: str) -> str:
        """Return the audit severity from the status code, then the HTTP method.

        5xx -> "critical", 4xx -> "warning"; otherwise DELETE -> "high",
        PUT/PATCH/POST -> "info", any other method (GET, HEAD, ...) -> "low".
        ``endpoint`` is accepted for interface compatibility but does not
        currently change the result.
        """
        if status_code >= 500:
            return "critical"
        if status_code >= 400:
            return "warning"

        if method == "DELETE":
            return "high"
        if method in ("PUT", "PATCH", "POST"):
            return "info"
        return "low"


class AuditMiddleware(BaseHTTPMiddleware):
    """Dedicated audit-logging middleware.

    Writes one ``AuditLoggingService`` entry per non-excluded request. The
    action is derived from the HTTP method and the resource type from the URL
    path.
    """

    def __init__(self, app, excluded_paths: Optional[list] = None):
        """Configure the middleware.

        Args:
            app: The wrapped ASGI application.
            excluded_paths: Paths that bypass auditing. Each entry matches the
                path itself and everything beneath it (``"/health"`` matches
                ``/health`` and ``/health/db`` but not ``/healthcare``). A falsy
                value (``None`` or ``[]``) selects the default docs/health list.
        """
        super().__init__(app)
        self.excluded_paths = excluded_paths or [
            "/docs", "/redoc", "/openapi.json", "/favicon.ico", "/health"
        ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Run the request and write an audit entry for it.

        The entry is written even when the downstream app raises; it is then
        recorded with status 500 and the exception is re-raised unchanged.
        """
        endpoint = str(request.url.path)
        if self._should_skip(endpoint):
            return await call_next(request)

        request_id = str(uuid.uuid4())
        method = request.method
        # SECURITY NOTE: X-Forwarded-For is client-controlled unless a trusted
        # proxy rewrites it, so the recorded IP can be spoofed.
        client_ip = request.headers.get("x-forwarded-for", "").split(",")[0].strip()
        if not client_ip:
            client_ip = getattr(request.client, 'host', None) if request.client else None

        response = None
        try:
            response = await call_next(request)
        finally:
            # Audit failed operations too; a missing response means the app raised.
            self._write_audit_entry(
                request=request,
                request_id=request_id,
                endpoint=endpoint,
                method=method,
                status_code=response.status_code if response is not None else 500,
                client_ip=client_ip,
            )

        return response

    def _write_audit_entry(
        self,
        request: Request,
        request_id: str,
        endpoint: str,
        method: str,
        status_code: int,
        client_ip: Optional[str],
    ) -> None:
        """Persist one audit entry; any error is logged and swallowed.

        NOTE: this runs synchronous database I/O on the event loop thread
        before the response is returned.
        """
        db = None
        try:
            db = next(get_sync_db())
            audit_service = AuditLoggingService(db)
            audit_service.log_action(
                action=self._get_action(method),
                resource_type=self._get_resource_type(endpoint),
                endpoint=endpoint,
                method=method,
                status_code=status_code,
                request_id=request_id,
                client_ip=client_ip,
                user_agent=request.headers.get("user-agent")
            )
        except Exception as e:
            logger.error("Audit logging failed: %s", e)
        finally:
            if db is not None:
                db.close()

    def _should_skip(self, path: str) -> bool:
        """Return True if ``path`` equals, or lies beneath, an excluded path.

        Segment-aware, so that ``"/health"`` does not also exclude ``/healthcare``.
        """
        for excluded in self.excluded_paths:
            base = excluded.rstrip("/")
            if path == excluded or path == base or path.startswith(base + "/"):
                return True
        return False

    def _get_action(self, method: str) -> str:
        """Map an HTTP method to an audit action; unknown methods are lowercased."""
        action_map = {
            "GET": "read",
            "POST": "create",
            "PUT": "update",
            "PATCH": "update",
            "DELETE": "delete"
        }
        return action_map.get(method, method.lower())

    def _get_resource_type(self, endpoint: str) -> str:
        """Return the segment after a leading ``api`` segment, else the first segment.

        Returns ``"unknown"`` for a path with no segments (e.g. ``/``).
        """
        # Drop empty segments so "/" yields "unknown" rather than "".
        parts = [segment for segment in endpoint.split("/") if segment]
        if not parts:
            return "unknown"
        if parts[0] == "api" and len(parts) >= 2:
            return parts[1]
        return parts[0]