"""
Phase 3 产品聚合服务
从标准化订单数据生成产品主表和采购需求
"""

import logging
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Dict, List, Optional, Any
from collections import defaultdict

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func, and_, or_, desc, asc
from sqlalchemy.orm import selectinload

from app.models.normalized_orders import OrderItemNorm, ProcurementStatus
from app.models.products import (
    Product, PendingPurchase, PurchaseList, PurchaseListItem,
    ProductCategory, ProcurementPriority
)
from app.utils.classification import ProductClassifier
from app.utils.image_processor import ImageProcessor

logger = logging.getLogger(__name__)


class ProductAggregationService:
    """产品聚合服务"""
    
    def __init__(self) -> None:
        """Create the helper objects used while aggregating products."""
        # Used by _create_product to obtain a (category, confidence) pair.
        self.classifier = ProductClassifier()
        # Used by _create_product to deserialize/serialize order image links.
        self.image_processor = ImageProcessor()
    
    async def aggregate_products(
        self,
        db: AsyncSession,
        force_rebuild: bool = False,
        batch_size: int = 1000
    ) -> Dict[str, Any]:
        """
        从标准化订单聚合生成产品数据
        
        Args:
            db: 数据库会话
            force_rebuild: 是否强制重新构建所有产品
            batch_size: 批处理大小
            
        Returns:
            聚合结果统计
        """
        logger.info(f"开始产品聚合，force_rebuild={force_rebuild}, batch_size={batch_size}")
        start_time = datetime.now()
        
        # 获取需要聚合的标准化订单数据
        if force_rebuild:
            # 删除现有产品数据
            await self._clear_products(db)
            
        # 查询标准化订单数据
        norm_orders = await self._get_normalized_orders(db, force_rebuild)
        
        if not norm_orders:
            logger.info("没有需要聚合的标准化订单数据")
            return {
                "status": "success",
                "processed_orders": 0,
                "created_products": 0,
                "updated_products": 0,
                "duration": (datetime.now() - start_time).total_seconds()
            }
        
        logger.info(f"获取到 {len(norm_orders)} 条标准化订单数据")
        
        # 按 product_key 分组聚合
        product_groups = self._group_orders_by_product_key(norm_orders)
        
        created_count = 0
        updated_count = 0
        
        # 分批处理产品
        product_keys = list(product_groups.keys())
        for i in range(0, len(product_keys), batch_size):
            batch_keys = product_keys[i:i + batch_size]
            batch_results = await self._process_product_batch(
                db, product_groups, batch_keys
            )
            created_count += batch_results["created"]
            updated_count += batch_results["updated"]
            
            # 提交批次
            await db.commit()
            logger.info(f"已处理 {i + len(batch_keys)}/{len(product_keys)} 个产品")
        
        # 更新采购需求
        await self._update_procurement_requirements(db)
        
        duration = (datetime.now() - start_time).total_seconds()
        result = {
            "status": "success",
            "processed_orders": len(norm_orders),
            "created_products": created_count,
            "updated_products": updated_count,
            "total_products": created_count + updated_count,
            "duration": duration
        }
        
        logger.info(f"产品聚合完成: {result}")
        return result
    
    async def _clear_products(self, db: AsyncSession):
        """清理现有产品数据"""
        logger.info("清理现有产品数据")
        
        # 删除采购清单明细
        await db.execute(select(PurchaseListItem).filter(PurchaseListItem.id > 0).delete())
        
        # 删除采购清单
        await db.execute(select(PurchaseList).filter(PurchaseList.id > 0).delete())
        
        # 删除待采购记录
        await db.execute(select(PendingPurchase).filter(PendingPurchase.id > 0).delete())
        
        # 删除产品
        await db.execute(select(Product).filter(Product.id > 0).delete())
        
        await db.commit()
    
    async def _get_normalized_orders(
        self,
        db: AsyncSession,
        force_rebuild: bool = False
    ) -> List[OrderItemNorm]:
        """获取需要聚合的标准化订单"""
        if force_rebuild:
            # 获取所有标准化订单
            result = await db.execute(
                select(OrderItemNorm).order_by(OrderItemNorm.付款时间.desc())
            )
        else:
            # 只获取尚未聚合的订单
            result = await db.execute(
                select(OrderItemNorm)
                .outerjoin(Product, OrderItemNorm.product_key == Product.product_key)
                .filter(Product.id.is_(None))
                .order_by(OrderItemNorm.付款时间.desc())
            )
        
        return result.scalars().all()
    
    def _group_orders_by_product_key(
        self,
        norm_orders: List[OrderItemNorm]
    ) -> Dict[str, List[OrderItemNorm]]:
        """按 product_key 分组订单"""
        groups = defaultdict(list)
        
        for order in norm_orders:
            if order.product_key:
                groups[order.product_key].append(order)
        
        return dict(groups)
    
    async def _process_product_batch(
        self,
        db: AsyncSession,
        product_groups: Dict[str, List[OrderItemNorm]],
        batch_keys: List[str]
    ) -> Dict[str, int]:
        """处理产品批次"""
        created = 0
        updated = 0
        
        for product_key in batch_keys:
            orders = product_groups[product_key]
            if not orders:
                continue
            
            # 检查产品是否已存在
            result = await db.execute(
                select(Product).filter(Product.product_key == product_key)
            )
            existing_product = result.scalar_one_or_none()
            
            if existing_product:
                # 更新现有产品
                await self._update_product(db, existing_product, orders)
                updated += 1
            else:
                # 创建新产品
                await self._create_product(db, product_key, orders)
                created += 1
        
        return {"created": created, "updated": updated}
    
    async def _create_product(
        self,
        db: AsyncSession,
        product_key: str,
        orders: List[OrderItemNorm]
    ):
        """创建新产品"""
        if not orders:
            return
        
        # 使用第一个订单的基本信息
        first_order = orders[0]
        
        # 计算聚合数据
        aggregated_data = self._calculate_aggregated_data(orders)
        
        # 自动分类
        category, confidence = self.classifier.classify_product(
            first_order.线上宝贝名称,
            first_order.品牌
        )
        
        # 处理图片
        images = []
        for order in orders:
            if order.图片链接:
                order_images = self.image_processor.deserialize_images(order.图片链接)
                images.extend(order_images)
        
        # 去重图片
        unique_images = list(dict.fromkeys(images))
        
        # 创建产品记录
        product = Product(
            product_key=product_key,
            sku_id=first_order.sku_id,
            品牌=first_order.品牌,
            货号=first_order.货号,
            线上宝贝名称=first_order.线上宝贝名称,
            颜色=first_order.颜色,
            尺寸=first_order.尺寸,
            
            # 分类信息
            category=category,
            category_confidence=confidence,
            
            # 聚合统计
            平均价格=aggregated_data["平均价格"],
            最低价格=aggregated_data["最低价格"],
            最高价格=aggregated_data["最高价格"],
            总销量=aggregated_data["总销量"],
            订单数量=aggregated_data["订单数量"],
            最近订单时间=aggregated_data["最近订单时间"],
            首次订单时间=aggregated_data["首次订单时间"],
            
            # 采购相关
            待采购数量=self._calculate_procurement_demand(orders),
            procurement_priority=self._calculate_procurement_priority(orders),
            
            # 图片
            图片数量=len(unique_images),
            主要图片=self.image_processor.serialize_images(unique_images[:5]) if unique_images else None,
            
            last_aggregated_at=datetime.now()
        )
        
        db.add(product)
        await db.flush()
        
        logger.debug(f"创建产品: {product.product_key}, 品牌: {product.品牌}, 分类: {category.value}")
    
    async def _update_product(
        self,
        db: AsyncSession,
        product: Product,
        orders: List[OrderItemNorm]
    ):
        """更新现有产品"""
        # 重新计算聚合数据
        aggregated_data = self._calculate_aggregated_data(orders)
        
        # 更新产品信息
        product.总销量 = aggregated_data["总销量"]
        product.订单数量 = aggregated_data["订单数量"]
        product.平均价格 = aggregated_data["平均价格"]
        product.最低价格 = aggregated_data["最低价格"]
        product.最高价格 = aggregated_data["最高价格"]
        product.最近订单时间 = aggregated_data["最近订单时间"]
        product.首次订单时间 = aggregated_data["首次订单时间"]
        
        # 更新采购需求
        product.待采购数量 = self._calculate_procurement_demand(orders)
        product.procurement_priority = self._calculate_procurement_priority(orders)
        
        # 更新时间
        product.updated_at = datetime.now()
        product.last_aggregated_at = datetime.now()
        
        logger.debug(f"更新产品: {product.product_key}, 总销量: {product.总销量}")
    
    def _calculate_aggregated_data(self, orders: List[OrderItemNorm]) -> Dict[str, Any]:
        """计算聚合统计数据"""
        if not orders:
            return {}
        
        prices = [float(order.订单单价) for order in orders if order.订单单价]
        quantities = [order.数量 for order in orders if order.数量]
        pay_times = [order.付款时间 for order in orders if order.付款时间]
        
        return {
            "平均价格": Decimal(str(sum(prices) / len(prices))) if prices else None,
            "最低价格": Decimal(str(min(prices))) if prices else None,
            "最高价格": Decimal(str(max(prices))) if prices else None,
            "总销量": sum(quantities),
            "订单数量": len(orders),
            "最近订单时间": max(pay_times) if pay_times else None,
            "首次订单时间": min(pay_times) if pay_times else None,
        }
    
    def _calculate_procurement_demand(self, orders: List[OrderItemNorm]) -> int:
        """计算采购需求"""
        # 统计等待采购的订单数量
        waiting_orders = [
            order for order in orders 
            if order.procurement_status == ProcurementStatus.WAITING
        ]
        return sum(order.数量 for order in waiting_orders if order.数量)
    
    def _calculate_procurement_priority(self, orders: List[OrderItemNorm]) -> ProcurementPriority:
        """计算采购优先级"""
        if not orders:
            return ProcurementPriority.LOW
        
        # 近期订单数量
        recent_date = datetime.now() - timedelta(days=7)
        recent_orders = [
            order for order in orders 
            if order.付款时间 and order.付款时间 >= recent_date
        ]
        
        total_demand = sum(order.数量 for order in orders if order.数量)
        recent_demand = sum(order.数量 for order in recent_orders if order.数量)
        
        # 根据总需求和近期需求确定优先级
        if total_demand >= 50 or recent_demand >= 20:
            return ProcurementPriority.HIGH
        elif total_demand >= 20 or recent_demand >= 10:
            return ProcurementPriority.MEDIUM
        else:
            return ProcurementPriority.LOW
    
    async def _update_procurement_requirements(self, db: AsyncSession):
        """更新采购需求"""
        logger.info("更新采购需求")
        
        # 查询所有有采购需求的产品
        result = await db.execute(
            select(Product).filter(Product.待采购数量 > 0)
        )
        products_with_demand = result.scalars().all()
        
        for product in products_with_demand:
            # 检查是否已有待采购记录
            result = await db.execute(
                select(PendingPurchase).filter(
                    PendingPurchase.product_key == product.product_key
                )
            )
            existing_pending = result.scalar_one_or_none()
            
            if existing_pending:
                # 更新现有记录
                existing_pending.需求数量 = product.待采购数量
                existing_pending.剩余需求 = max(0, product.待采购数量 - existing_pending.已分配数量)
                existing_pending.priority = product.procurement_priority
                existing_pending.last_updated = datetime.now()
            else:
                # 创建新的待采购记录
                pending_purchase = PendingPurchase(
                    product_id=product.id,
                    product_key=product.product_key,
                    需求数量=product.待采购数量,
                    剩余需求=product.待采购数量,
                    priority=product.procurement_priority,
                    建议采购价=product.平均价格
                )
                db.add(pending_purchase)
        
        await db.commit()
    
    async def get_product_stats(self, db: AsyncSession) -> Dict[str, Any]:
        """获取产品统计信息"""
        # 总产品数
        total_result = await db.execute(select(func.count(Product.id)))
        total_products = total_result.scalar() or 0
        
        # 按分类统计
        category_result = await db.execute(
            select(Product.category, func.count(Product.id))
            .group_by(Product.category)
        )
        category_stats = dict(category_result.all())
        
        # 待采购统计
        pending_result = await db.execute(
            select(func.count(Product.id), func.sum(Product.待采购数量))
            .filter(Product.待采购数量 > 0)
        )
        pending_row = pending_result.first()
        pending_products = pending_row[0] or 0
        total_pending_quantity = int(pending_row[1] or 0)
        
        # 按优先级统计待采购
        priority_result = await db.execute(
            select(Product.procurement_priority, func.count(Product.id))
            .filter(Product.待采购数量 > 0)
            .group_by(Product.procurement_priority)
        )
        priority_stats = dict(priority_result.all())
        
        return {
            "total_products": total_products,
            "category_distribution": {
                category.value: count for category, count in category_stats.items()
            },
            "procurement": {
                "pending_products": pending_products,
                "total_pending_quantity": total_pending_quantity,
                "priority_distribution": {
                    priority.value: count for priority, count in priority_stats.items()
                }
            }
        }
    
    async def search_products(
        self,
        db: AsyncSession,
        keyword: Optional[str] = None,
        brand: Optional[str] = None,
        category: Optional[ProductCategory] = None,
        has_procurement_demand: Optional[bool] = None,
        priority: Optional[ProcurementPriority] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Dict[str, Any]:
        """搜索产品"""
        query = select(Product)
        
        # 构建过滤条件
        conditions = []
        
        if keyword:
            conditions.append(
                or_(
                    Product.线上宝贝名称.contains(keyword),
                    Product.品牌.contains(keyword),
                    Product.货号.contains(keyword)
                )
            )
        
        if brand:
            conditions.append(Product.品牌 == brand)
        
        if category:
            conditions.append(Product.category == category)
        
        if has_procurement_demand is not None:
            if has_procurement_demand:
                conditions.append(Product.待采购数量 > 0)
            else:
                conditions.append(Product.待采购数量 == 0)
        
        if priority:
            conditions.append(Product.procurement_priority == priority)
        
        if conditions:
            query = query.filter(and_(*conditions))
        
        # 计算总数
        count_query = select(func.count(Product.id))
        if conditions:
            count_query = count_query.filter(and_(*conditions))
        
        total_result = await db.execute(count_query)
        total = total_result.scalar() or 0
        
        # 分页查询
        offset = (page - 1) * page_size
        query = query.offset(offset).limit(page_size).order_by(desc(Product.updated_at))
        
        result = await db.execute(query)
        products = result.scalars().all()
        
        return {
            "products": products,
            "pagination": {
                "page": page,
                "page_size": page_size,
                "total": total,
                "pages": (total + page_size - 1) // page_size
            }
        }