"""
审计日志服务
负责记录用户操作、数据变更和系统事件的审计日志
"""

import logging
import json
import uuid
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Any, Union, Tuple
from dataclasses import dataclass
from enum import Enum

from sqlalchemy.orm import Session
from sqlalchemy import and_, or_, func, desc, asc, text
from sqlalchemy.sql import func as sql_func

from app.models import AuditLog

# Module-level logger named after this module; every method of the service below logs through it.
logger = logging.getLogger(__name__)


class AuditAction(str, Enum):
    """Kinds of operation that can be recorded in the audit log.

    Mixing in ``str`` makes every member a string equal to its value
    (e.g. ``AuditAction.CREATE == "create"``), so members can be passed
    wherever the audit service expects a plain ``str`` action.
    """
    CREATE = "create"
    READ = "read"
    UPDATE = "update"
    DELETE = "delete"
    LOGIN = "login"
    LOGOUT = "logout"
    EXPORT = "export"
    IMPORT = "import"
    APPROVE = "approve"
    REJECT = "reject"
    PROCESS = "process"


class AuditSeverity(str, Enum):
    """Severity levels attached to audit log records.

    A ``str`` mixin, so members compare equal to their string values.
    ``AuditLoggingService.log_action`` defaults to ``INFO`` and
    ``AuditLoggingService.log_security_event`` defaults to ``HIGH``.
    """
    LOW = "low"
    INFO = "info"
    WARNING = "warning"
    HIGH = "high"
    CRITICAL = "critical"


@dataclass
class AuditSummary:
    """Aggregate statistics over a set of audit log records."""
    # Number of records matched by the summary's filters.
    total_events: int
    # Record count per action value.
    action_breakdown: Dict[str, int]
    # Record count per severity value.
    severity_breakdown: Dict[str, int]
    # Record count per user_name; get_audit_summary excludes records whose
    # user_name is NULL.
    user_breakdown: Dict[str, int]
    # Record count per resource_type.
    resource_breakdown: Dict[str, int]
    # Records whose status is "failed" or "error" (as counted by get_audit_summary).
    error_count: int
    # Percentage (0-100) of records that are not failed/error, as computed by
    # get_audit_summary; 100.0 when there are no events, 0.0 on a query error.
    success_rate: float


@dataclass
class UserActivity:
    """Per-user activity statistics produced by the user activity report."""
    # Identifier of the user the statistics belong to.
    user_id: str
    # Display name; get_user_activity_report substitutes 'Unknown' when missing.
    user_name: str
    # Number of audit records counted for this user.
    total_actions: int
    # Timestamp of the most recent audit record for this user.
    last_activity: datetime
    # Record count per action value.
    action_breakdown: Dict[str, int]
    # Heuristic risk score, capped at 100.0 by get_user_activity_report.
    risk_score: float


class AuditLoggingService:
    """Write and query audit log records through a SQLAlchemy session.

    Write helpers (``log_action`` and the wrappers built on it) add an
    ``AuditLog`` row and commit the injected session. Read helpers list,
    search, summarise, export and purge those rows.

    NOTE: writes call ``self.db.commit()`` / ``self.db.rollback()`` on the
    session passed in. Any other pending changes on that session are
    therefore committed or discarded together with the audit row.
    """

    def __init__(self, db: Session):
        # Session supplied by the caller; used for every query and commit.
        self.db = db

    # ------------------------------------------------------------------
    # Write side
    # ------------------------------------------------------------------

    def log_action(
        self,
        action: str,
        resource_type: str,
        endpoint: str,
        method: str,
        status_code: int,
        user_id: Optional[str] = None,
        user_name: Optional[str] = None,
        resource_id: Optional[str] = None,
        description: Optional[str] = None,
        old_values: Optional[Dict[str, Any]] = None,
        new_values: Optional[Dict[str, Any]] = None,
        request_data: Optional[Dict[str, Any]] = None,
        session_id: Optional[str] = None,
        request_id: Optional[str] = None,
        client_ip: Optional[str] = None,
        user_agent: Optional[str] = None,
        referer: Optional[str] = None,
        business_category: Optional[str] = None,
        severity: str = AuditSeverity.INFO,
        tags: Optional[List[str]] = None,
        error_message: Optional[str] = None
    ) -> AuditLog:
        """Create, persist and commit one audit log record.

        The record's ``status`` is derived from ``status_code``:

        - ``"error"`` for codes >= 500;
        - ``"failed"`` for codes >= 400;
        - ``"success"`` otherwise.

        ``tags`` are stored as one comma-separated string, or ``None`` when
        empty.

        Returns:
            The committed ``AuditLog`` instance.

        Raises:
            Exception: whatever building/adding/committing the row raises.
                The error is logged with its traceback and the session is
                rolled back before re-raising.
        """
        try:
            if status_code >= 500:
                status = "error"
            elif status_code >= 400:
                status = "failed"
            else:
                status = "success"

            # NOTE: a tag that itself contains "," cannot be split back out.
            tags_str = ",".join(tags) if tags else None

            audit_log = AuditLog(
                session_id=session_id,
                request_id=request_id,
                user_id=user_id,
                user_name=user_name,
                action=action,
                resource_type=resource_type,
                resource_id=resource_id,
                endpoint=endpoint,
                method=method,
                description=description,
                old_values=old_values,
                new_values=new_values,
                request_data=request_data,
                status=status,
                status_code=status_code,
                error_message=error_message,
                client_ip=client_ip,
                user_agent=user_agent,
                referer=referer,
                business_category=business_category,
                severity=severity,
                tags=tags_str
            )

            self.db.add(audit_log)
            self.db.commit()

            # Called on every audited request: skip building the message
            # unless DEBUG is actually enabled.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
                    f"Logged audit action: {action} on {resource_type} "
                    f"by {user_name or user_id or 'anonymous'}"
                )

            return audit_log

        except Exception as e:
            logger.exception(f"Error logging audit action: {e}")
            self.db.rollback()
            raise

    def log_data_change(
        self,
        resource_type: str,
        resource_id: str,
        action: str,
        old_data: Dict[str, Any],
        new_data: Dict[str, Any],
        user_id: Optional[str] = None,
        user_name: Optional[str] = None,
        description: Optional[str] = None,
        business_category: Optional[str] = None,
        severity: str = AuditSeverity.INFO
    ) -> AuditLog:
        """Record a change to one resource as an audit log entry.

        A field counts as changed when its value differs between
        ``old_data`` and ``new_data``; a key present on only one side also
        counts. Changed fields are listed, sorted, in the default
        description and appended as tags after ``"data_change"``.

        The entry is written through ``log_action`` with endpoint
        ``"/data-change"``, method ``"UPDATE"`` and status code 200.

        ``old_data`` / ``new_data`` may be ``None`` (e.g. on create or
        delete). They are diffed as empty dicts but stored exactly as given.

        Raises:
            Exception: propagated from ``log_action`` after being logged.
        """
        try:
            before = old_data or {}
            after = new_data or {}

            # Sorted so the description and tags are reproducible: set
            # iteration order of str keys varies between interpreter runs.
            changed_fields = sorted(
                (
                    key
                    for key in set(before) | set(after)
                    if before.get(key) != after.get(key)
                ),
                key=str,
            )

            if not description:
                description = f"Changed fields: {', '.join(changed_fields)}"

            return self.log_action(
                action=action,
                resource_type=resource_type,
                resource_id=resource_id,
                endpoint="/data-change",
                method="UPDATE",
                status_code=200,
                user_id=user_id,
                user_name=user_name,
                description=description,
                old_values=old_data,
                new_values=new_data,
                business_category=business_category,
                severity=severity,
                tags=["data_change"] + changed_fields
            )

        except Exception as e:
            logger.exception(f"Error logging data change: {e}")
            raise

    def log_user_activity(
        self,
        user_id: str,
        user_name: str,
        action: str,
        endpoint: str,
        method: str,
        session_id: Optional[str] = None,
        client_ip: Optional[str] = None,
        user_agent: Optional[str] = None,
        success: bool = True
    ) -> AuditLog:
        """Record a user-session activity.

        The record uses resource type ``"user_session"`` and business
        category ``"user_management"``.

        - A successful activity is logged with status 200 and ``INFO``
          severity.
        - An unsuccessful one is logged with status 401 and ``WARNING``
          severity.
        """
        return self.log_action(
            action=action,
            resource_type="user_session",
            endpoint=endpoint,
            method=method,
            status_code=200 if success else 401,
            user_id=user_id,
            user_name=user_name,
            session_id=session_id,
            client_ip=client_ip,
            user_agent=user_agent,
            business_category="user_management",
            severity=AuditSeverity.INFO if success else AuditSeverity.WARNING
        )

    def log_security_event(
        self,
        event_type: str,
        description: str,
        client_ip: Optional[str] = None,
        user_id: Optional[str] = None,
        user_name: Optional[str] = None,
        severity: str = AuditSeverity.HIGH,
        additional_data: Optional[Dict[str, Any]] = None
    ) -> AuditLog:
        """Record a security event.

        ``event_type`` is stored as the action and ``additional_data`` as
        ``request_data``. The record uses resource type and business
        category ``"security"`` and is tagged ``["security", event_type]``.
        """
        return self.log_action(
            action=event_type,
            resource_type="security",
            endpoint="/security-event",
            method="EVENT",
            status_code=200,
            user_id=user_id,
            user_name=user_name,
            description=description,
            request_data=additional_data,
            client_ip=client_ip,
            business_category="security",
            severity=severity,
            tags=["security", event_type]
        )

    # ------------------------------------------------------------------
    # Read side
    # ------------------------------------------------------------------

    def get_audit_logs(
        self,
        user_id: Optional[str] = None,
        action: Optional[str] = None,
        resource_type: Optional[str] = None,
        business_category: Optional[str] = None,
        severity: Optional[str] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        page: int = 1,
        page_size: int = 50,
        order_by: str = "timestamp",
        order_desc: bool = True
    ) -> Tuple[List[AuditLog], int]:
        """Return one page of audit logs plus the total number of matches.

        Falsy filter arguments are ignored. ``page`` is 1-based; values
        below 1 are treated as the first page. ``order_by`` names an
        ``AuditLog`` attribute.

        Returns:
            ``(items, total)``, or ``([], 0)`` if anything goes wrong,
            including an unknown ``order_by``. The error is logged.
        """
        try:
            # Clamp so page <= 0 cannot produce a negative OFFSET, which
            # e.g. PostgreSQL rejects.
            page = max(page, 1)

            query = self.db.query(AuditLog)

            conditions = []
            if user_id:
                conditions.append(AuditLog.user_id == user_id)
            if action:
                conditions.append(AuditLog.action == action)
            if resource_type:
                conditions.append(AuditLog.resource_type == resource_type)
            if business_category:
                conditions.append(AuditLog.business_category == business_category)
            if severity:
                conditions.append(AuditLog.severity == severity)
            if start_time:
                conditions.append(AuditLog.timestamp >= start_time)
            if end_time:
                conditions.append(AuditLog.timestamp <= end_time)

            if conditions:
                query = query.filter(and_(*conditions))

            # Resolve the sort column first; an unknown name raises
            # AttributeError and ends in the empty result below.
            sort_column = getattr(AuditLog, order_by)

            # Count before ORDER BY is attached: ordering cannot change the
            # total and would otherwise be carried into the COUNT subquery.
            total = query.count()

            query = query.order_by(
                desc(sort_column) if order_desc else asc(sort_column)
            )
            items = query.offset((page - 1) * page_size).limit(page_size).all()

            return items, total

        except Exception as e:
            logger.exception(f"Error getting audit logs: {e}")
            return [], 0

    @staticmethod
    def _count_by(query, column) -> Dict[Any, int]:
        """Return ``{value: row count}`` for ``query`` grouped by ``column``."""
        rows = query.with_entities(
            column,
            func.count(AuditLog.id).label('count')
        ).group_by(column).all()
        return {value: count for value, count in rows}

    def get_audit_summary(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        user_id: Optional[str] = None
    ) -> AuditSummary:
        """Compute aggregate statistics over the matching audit logs.

        ``success_rate`` is a percentage from 0 to 100:

        - 100.0 when no events match;
        - 0.0 (with all other counts zero) if a query fails — the error is
          logged.
        """
        try:
            query = self.db.query(AuditLog)

            conditions = []
            if start_time:
                conditions.append(AuditLog.timestamp >= start_time)
            if end_time:
                conditions.append(AuditLog.timestamp <= end_time)
            if user_id:
                conditions.append(AuditLog.user_id == user_id)

            if conditions:
                query = query.filter(and_(*conditions))

            total_events = query.count()

            action_breakdown = self._count_by(query, AuditLog.action)
            severity_breakdown = self._count_by(query, AuditLog.severity)
            # Records without a user_name are left out of the per-user counts.
            user_breakdown = self._count_by(
                query.filter(AuditLog.user_name.isnot(None)),
                AuditLog.user_name
            )
            resource_breakdown = self._count_by(query, AuditLog.resource_type)

            error_count = query.filter(AuditLog.status.in_(['failed', 'error'])).count()
            success_rate = (
                (total_events - error_count) / total_events * 100
                if total_events > 0 else 100.0
            )

            return AuditSummary(
                total_events=total_events,
                action_breakdown=action_breakdown,
                severity_breakdown=severity_breakdown,
                user_breakdown=user_breakdown,
                resource_breakdown=resource_breakdown,
                error_count=error_count,
                success_rate=success_rate
            )

        except Exception as e:
            logger.exception(f"Error getting audit summary: {e}")
            return AuditSummary(
                total_events=0,
                action_breakdown={},
                severity_breakdown={},
                user_breakdown={},
                resource_breakdown={},
                error_count=0,
                success_rate=0.0
            )

    def get_user_activity_report(
        self,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 50
    ) -> List[UserActivity]:
        """Report the most active users, ordered by number of actions.

        Rows are grouped by ``(user_id, user_name)``, so a user_id seen with
        several names yields several entries. Each entry's
        ``action_breakdown`` and ``risk_score`` are computed over all of
        that user_id's records in the time range.

        Returns:
            Up to ``limit`` ``UserActivity`` entries, or ``[]`` on error
            (logged).
        """
        try:
            query = self.db.query(AuditLog).filter(
                AuditLog.user_id.isnot(None)
            )

            if start_time:
                query = query.filter(AuditLog.timestamp >= start_time)
            if end_time:
                query = query.filter(AuditLog.timestamp <= end_time)

            user_stats = query.with_entities(
                AuditLog.user_id,
                AuditLog.user_name,
                func.count(AuditLog.id).label('total_actions'),
                func.max(AuditLog.timestamp).label('last_activity')
            ).group_by(
                AuditLog.user_id, AuditLog.user_name
            ).order_by(
                desc('total_actions')
            ).limit(limit).all()

            if not user_stats:
                return []

            # One grouped query for every selected user's per-action counts,
            # instead of one query per user (N+1).
            user_ids = list(dict.fromkeys(stat.user_id for stat in user_stats))
            breakdowns: Dict[str, Dict[str, int]] = {}
            action_rows = query.filter(
                AuditLog.user_id.in_(user_ids)
            ).with_entities(
                AuditLog.user_id,
                AuditLog.action,
                func.count(AuditLog.id).label('count')
            ).group_by(AuditLog.user_id, AuditLog.action).all()
            for row in action_rows:
                breakdowns.setdefault(row.user_id, {})[row.action] = row.count

            # Simple heuristic: high-risk actions weigh 0.5, others 0.1.
            # NOTE(review): 'admin' is not an AuditAction member; presumably
            # some callers log it directly -- confirm.
            high_risk_actions = {'delete', 'export', 'admin'}

            user_activities = []
            for stat in user_stats:
                # Copy so entries sharing a user_id do not share one dict.
                action_breakdown = dict(breakdowns.get(stat.user_id, {}))
                risk_score = sum(
                    (
                        count * (0.5 if action in high_risk_actions else 0.1)
                        for action, count in action_breakdown.items()
                    ),
                    0.0,
                )

                user_activities.append(UserActivity(
                    user_id=stat.user_id,
                    user_name=stat.user_name or 'Unknown',
                    total_actions=stat.total_actions,
                    last_activity=stat.last_activity,
                    action_breakdown=action_breakdown,
                    risk_score=min(risk_score, 100.0)
                ))

            return user_activities

        except Exception as e:
            logger.exception(f"Error getting user activity report: {e}")
            return []

    def get_security_events(
        self,
        severity: Optional[str] = None,
        event_type: Optional[str] = None,
        hours_back: int = 24,
        limit: int = 100
    ) -> List[AuditLog]:
        """Return recent records in business category ``'security'``.

        Results cover the last ``hours_back`` hours, newest first, and are
        optionally filtered by severity and by action (``event_type``).
        Returns ``[]`` on error (logged).
        """
        try:
            # NOTE(review): naive local time; assumes AuditLog.timestamp is
            # stored in the same clock -- confirm against the model.
            since_time = datetime.now() - timedelta(hours=hours_back)

            query = self.db.query(AuditLog).filter(
                and_(
                    AuditLog.business_category == 'security',
                    AuditLog.timestamp >= since_time
                )
            )

            if severity:
                query = query.filter(AuditLog.severity == severity)

            if event_type:
                query = query.filter(AuditLog.action == event_type)

            return query.order_by(desc(AuditLog.timestamp)).limit(limit).all()

        except Exception as e:
            logger.exception(f"Error getting security events: {e}")
            return []

    def get_failed_operations(
        self,
        hours_back: int = 24,
        limit: int = 100
    ) -> List[AuditLog]:
        """Return recent records whose status is ``'failed'`` or ``'error'``.

        Results cover the last ``hours_back`` hours, newest first. Returns
        ``[]`` on error (logged).
        """
        try:
            # NOTE(review): naive local time; see get_security_events.
            since_time = datetime.now() - timedelta(hours=hours_back)

            return self.db.query(AuditLog).filter(
                and_(
                    AuditLog.status.in_(['failed', 'error']),
                    AuditLog.timestamp >= since_time
                )
            ).order_by(desc(AuditLog.timestamp)).limit(limit).all()

        except Exception as e:
            logger.exception(f"Error getting failed operations: {e}")
            return []

    def search_logs(
        self,
        search_term: str,
        search_fields: Optional[List[str]] = None,
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 100
    ) -> List[AuditLog]:
        """Case-insensitive substring search over selected AuditLog fields.

        ``search_term`` is matched literally: ``%`` and ``_`` in it are not
        treated as wildcards. ``search_fields`` defaults to description,
        user_name, endpoint and resource_type; names that are not
        ``AuditLog`` attributes are ignored.

        Returns:
            Up to ``limit`` matches, newest first. Returns ``[]`` when none
            of the requested fields exist, or on error (logged).
        """
        try:
            if not search_fields:
                search_fields = ['description', 'user_name', 'endpoint', 'resource_type']

            # Escape LIKE metacharacters so the term matches literally. '/'
            # is the escape character (as in SQLAlchemy's own autoescape),
            # avoiding backslash-quoting differences between backends. The
            # escape char itself must be escaped first.
            escaped_term = (
                str(search_term)
                .replace('/', '//')
                .replace('%', '/%')
                .replace('_', '/_')
            )
            pattern = f"%{escaped_term}%"

            search_conditions = [
                getattr(AuditLog, field).ilike(pattern, escape='/')
                for field in search_fields
                if hasattr(AuditLog, field)
            ]

            if not search_conditions:
                # Searching no valid field must match nothing rather than
                # silently returning every recent log.
                logger.warning(f"No searchable AuditLog fields in {search_fields}")
                return []

            query = self.db.query(AuditLog).filter(or_(*search_conditions))

            if start_time:
                query = query.filter(AuditLog.timestamp >= start_time)
            if end_time:
                query = query.filter(AuditLog.timestamp <= end_time)

            return query.order_by(desc(AuditLog.timestamp)).limit(limit).all()

        except Exception as e:
            logger.exception(f"Error searching logs: {e}")
            return []

    def cleanup_old_logs(self, days_to_keep: int = 90) -> int:
        """Delete audit logs older than ``days_to_keep`` days and commit.

        Returns:
            The number of deleted rows. Returns 0 when ``days_to_keep`` is
            negative (nothing is deleted), or on error: the error is logged
            and the session rolled back.
        """
        try:
            if days_to_keep < 0:
                # A negative retention would move the cutoff into the future
                # and wipe every audit record.
                logger.error(
                    f"Refusing to clean up audit logs: days_to_keep must be "
                    f"non-negative, got {days_to_keep}"
                )
                return 0

            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            deleted_count = self.db.query(AuditLog).filter(
                AuditLog.timestamp < cutoff_date
            ).delete()

            self.db.commit()

            logger.info(f"Cleaned up {deleted_count} old audit log records")
            return deleted_count

        except Exception as e:
            logger.exception(f"Error cleaning up old logs: {e}")
            self.db.rollback()
            return 0

    def export_logs(
        self,
        start_time: datetime,
        end_time: datetime,
        format: str = "json",
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Export the audit logs in ``[start_time, end_time]`` as plain dicts.

        Records are ordered by timestamp. ``old_values``, ``new_values`` and
        ``request_data`` are included only when truthy. ``format`` is only
        echoed back in the result; ``data`` is always a list of dicts.

        Returns:
            A dict with ``total_records``, ``export_time``, ``time_range``,
            ``format`` and ``data``, or ``{'error': message}`` on failure.
        """
        try:
            query = self.db.query(AuditLog).filter(
                and_(
                    AuditLog.timestamp >= start_time,
                    AuditLog.timestamp <= end_time
                )
            )

            if user_id:
                query = query.filter(AuditLog.user_id == user_id)

            logs = query.order_by(AuditLog.timestamp).all()

            export_data = []
            for log in logs:
                log_dict = {
                    'id': log.id,
                    # Guard against a NULL timestamp rather than failing the
                    # whole export.
                    'timestamp': log.timestamp.isoformat() if log.timestamp else None,
                    'user_id': log.user_id,
                    'user_name': log.user_name,
                    'action': log.action,
                    'resource_type': log.resource_type,
                    'resource_id': log.resource_id,
                    'endpoint': log.endpoint,
                    'method': log.method,
                    'description': log.description,
                    'status': log.status,
                    'status_code': log.status_code,
                    'client_ip': log.client_ip,
                    'severity': log.severity,
                    'business_category': log.business_category,
                    'tags': log.tags
                }

                if log.old_values:
                    log_dict['old_values'] = log.old_values
                if log.new_values:
                    log_dict['new_values'] = log.new_values
                if log.request_data:
                    log_dict['request_data'] = log.request_data

                export_data.append(log_dict)

            return {
                'total_records': len(export_data),
                'export_time': datetime.now().isoformat(),
                'time_range': {
                    'start': start_time.isoformat(),
                    'end': end_time.isoformat()
                },
                'format': format,
                'data': export_data
            }

        except Exception as e:
            logger.exception(f"Error exporting logs: {e}")
            return {'error': str(e)}