#!/usr/bin/env python3
"""
业务逻辑测试
测试系统的核心业务逻辑和数据完整性
"""

import asyncio
import sys
import os
from datetime import datetime, timedelta
from decimal import Decimal

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'backend'))
os.environ['DATABASE_URL'] = 'sqlite+aiosqlite:///./backend/ordersys.db'

from app.core.database import AsyncSessionLocal
from app.models.raw_orders import RawOrder
from app.models.products_master import ProductMaster
from app.models.procurement_orders import ProcurementOrder, ProcurementStatus
from app.services.products_master_service import ProductsMasterService
from app.services.procurement_order_service_v2 import ProcurementOrderServiceV2
from sqlalchemy import select, func, and_, or_, distinct


class BusinessLogicTests:
    """业务逻辑测试类"""
    
    def __init__(self):
        self.passed = 0
        self.failed = 0
        self.warnings = 0
    
    def log_test(self, name: str, passed: bool, message: str = None):
        """记录测试结果"""
        if passed:
            print(f"  ✅ {name}")
            self.passed += 1
        else:
            print(f"  ❌ {name}: {message}")
            self.failed += 1
        if message and not passed:
            print(f"     原因: {message}")
    
    def log_warning(self, message: str):
        """记录警告"""
        print(f"  ⚠️ {message}")
        self.warnings += 1
    
    async def test_order_deduplication(self, db):
        """测试订单去重逻辑"""
        print("\n📋 测试1: 订单去重逻辑")
        
        # 查找重复的订单组合
        duplicate_query = """
            SELECT 原始订单编号, 线上宝贝名称, 线上销售属性, 数量, 付款时间, COUNT(*) as cnt
            FROM raw_orders
            GROUP BY 原始订单编号, 线上宝贝名称, 线上销售属性, 数量, 付款时间
            HAVING COUNT(*) > 1
        """
        
        from sqlalchemy import text
        result = await db.execute(text(duplicate_query))
        duplicates = result.fetchall()
        
        if duplicates:
            self.log_test("订单去重", False, f"发现 {len(duplicates)} 组重复订单")
            for dup in duplicates[:3]:  # 显示前3个
                print(f"     - 订单号: {dup[0]}, 重复数: {dup[5]}")
        else:
            self.log_test("订单去重", True)
        
        # 检查唯一约束是否生效
        try:
            constraint_query = """
                SELECT name FROM sqlite_master 
                WHERE type = 'index' AND name LIKE '%uk_raw_order%'
            """
            result = await db.execute(text(constraint_query))
            constraints = result.fetchall()
            
            if constraints:
                self.log_test("唯一约束存在", True, f"找到 {len(constraints)} 个约束")
            else:
                self.log_test("唯一约束存在", False, "未找到唯一约束")
        except Exception as e:
            self.log_test("唯一约束检查", False, str(e))
    
    async def test_product_sku_generation(self, db):
        """测试产品SKU生成逻辑"""
        print("\n📋 测试2: 产品SKU生成")
        
        # 获取产品样本
        result = await db.execute(
            select(ProductMaster)
            .where(ProductMaster.sku_key.is_not(None))
            .limit(10)
        )
        products = result.scalars().all()
        
        if not products:
            self.log_warning("没有产品数据可供测试")
            return
        
        # 测试SKU唯一性
        sku_counts = {}
        for product in products:
            if product.sku_key:
                if product.sku_key in sku_counts:
                    sku_counts[product.sku_key] += 1
                else:
                    sku_counts[product.sku_key] = 1
        
        duplicate_skus = [sku for sku, count in sku_counts.items() if count > 1]
        
        if duplicate_skus:
            self.log_test("SKU唯一性", False, f"发现 {len(duplicate_skus)} 个重复SKU")
        else:
            self.log_test("SKU唯一性", True)
        
        # 测试SKU格式（应该是64字符的SHA-256哈希）
        invalid_skus = []
        for product in products:
            if product.sku_key and len(product.sku_key) != 64:
                invalid_skus.append(product.sku_key)
        
        if invalid_skus:
            self.log_test("SKU格式正确", False, f"{len(invalid_skus)} 个SKU格式错误")
        else:
            self.log_test("SKU格式正确", True)
        
        # 测试SKU生成的一致性（相同属性应该生成相同SKU）
        if len(products) >= 2:
            p1 = products[0]
            # 查找具有相同属性的产品
            result = await db.execute(
                select(ProductMaster)
                .where(and_(
                    ProductMaster.品牌 == p1.品牌,
                    ProductMaster.货号 == p1.货号,
                    ProductMaster.颜色 == p1.颜色,
                    ProductMaster.尺寸 == p1.尺寸,
                    ProductMaster.id != p1.id
                ))
                .limit(1)
            )
            similar_product = result.scalar_one_or_none()
            
            if similar_product:
                if similar_product.sku_key == p1.sku_key:
                    self.log_test("SKU一致性", True)
                else:
                    self.log_test("SKU一致性", False, "相同属性生成了不同SKU")
            else:
                self.log_test("SKU一致性", True, "无相似产品可比较")
    
    async def test_procurement_method_assignment(self, db):
        """测试采购方式分配逻辑"""
        print("\n📋 测试3: 采购方式分配")
        
        # 获取采购订单的方式分布
        result = await db.execute(
            select(
                ProcurementOrder.procurement_method,
                func.count(ProcurementOrder.id).label('count')
            )
            .group_by(ProcurementOrder.procurement_method)
        )
        method_distribution = dict(result.all())
        
        if not method_distribution:
            self.log_warning("没有采购订单数据")
            return
        
        # 检查是否有有效的采购方式
        valid_methods = ['NY', 'LA', 'MC', 'GN', 'AT', 'AP', 'SS']
        invalid_methods = [m for m in method_distribution.keys() if m not in valid_methods]
        
        if invalid_methods:
            self.log_test("采购方式有效性", False, f"发现无效方式: {invalid_methods}")
        else:
            self.log_test("采购方式有效性", True)
        
        # 显示分布
        print("  采购方式分布:")
        for method, count in sorted(method_distribution.items(), key=lambda x: x[1], reverse=True):
            percentage = (count / sum(method_distribution.values())) * 100
            print(f"    - {method}: {count} ({percentage:.1f}%)")
        
        # 检查默认方式NY是否存在
        if 'NY' in method_distribution:
            self.log_test("默认方式(NY)存在", True)
        else:
            self.log_test("默认方式(NY)存在", False, "未找到NY方式的订单")
        
        # 测试特定品牌的采购方式分配
        # 例如：检查某些品牌是否正确分配到特定采购方式
        brand_method_query = """
            SELECT po.品牌, po.procurement_method, COUNT(*) as cnt
            FROM procurement_orders po
            WHERE po.品牌 IS NOT NULL
            GROUP BY po.品牌, po.procurement_method
            ORDER BY cnt DESC
            LIMIT 10
        """
        
        from sqlalchemy import text
        result = await db.execute(text(brand_method_query))
        brand_methods = result.fetchall()
        
        if brand_methods:
            print("  品牌采购方式分配（前10）:")
            for brand, method, count in brand_methods:
                print(f"    - {brand}: {method} ({count})")
    
    async def test_order_status_consistency(self, db):
        """测试订单状态一致性"""
        print("\n📋 测试4: 订单状态一致性")
        
        # 检查原始订单和采购订单的状态一致性
        inconsistent_query = """
            SELECT COUNT(*)
            FROM procurement_orders po
            JOIN raw_orders ro ON po.original_order_id = ro.id
            WHERE po.交易状态 != ro.交易状态
               OR po.退款状态 != ro.退款状态
        """
        
        from sqlalchemy import text
        result = await db.execute(text(inconsistent_query))
        inconsistent_count = result.scalar()
        
        if inconsistent_count > 0:
            self.log_test("状态一致性", False, f"{inconsistent_count} 个订单状态不一致")
        else:
            self.log_test("状态一致性", True)
        
        # 检查是否有已取消但仍在采购的订单
        cancelled_but_pending = await db.execute(
            select(func.count(ProcurementOrder.id))
            .where(and_(
                ProcurementOrder.交易状态 == '已取消',
                ProcurementOrder.procurement_status == ProcurementStatus.PENDING
            ))
        )
        cancelled_count = cancelled_but_pending.scalar()
        
        if cancelled_count > 0:
            self.log_warning(f"发现 {cancelled_count} 个已取消但仍待采购的订单")
        else:
            self.log_test("取消订单处理", True)
        
        # 检查退款成功的订单
        refunded_query = """
            SELECT COUNT(*)
            FROM procurement_orders
            WHERE 退款状态 = '退款成功'
        """
        
        result = await db.execute(text(refunded_query))
        refunded_count = result.scalar()
        
        if refunded_count > 0:
            self.log_test("退款订单保留", True, f"保留了 {refunded_count} 个退款订单")
        else:
            self.log_test("退款订单保留", True, "暂无退款订单")
    
    async def test_price_calculations(self, db):
        """测试价格计算逻辑"""
        print("\n📋 测试5: 价格计算")
        
        # 检查订单金额计算是否正确（数量 * 单价 = 金额）
        price_error_query = """
            SELECT COUNT(*)
            FROM raw_orders
            WHERE ABS(数量 * 订单单价 - 订单金额) > 0.01
              AND 数量 IS NOT NULL
              AND 订单单价 IS NOT NULL
              AND 订单金额 IS NOT NULL
        """
        
        from sqlalchemy import text
        result = await db.execute(text(price_error_query))
        error_count = result.scalar()
        
        if error_count > 0:
            self.log_test("订单金额计算", False, f"{error_count} 个订单金额计算错误")
        else:
            self.log_test("订单金额计算", True)
        
        # 检查产品平均价格
        result = await db.execute(
            select(
                ProductMaster.avg_price,
                ProductMaster.min_price,
                ProductMaster.max_price,
                ProductMaster.total_quantity
            )
            .where(ProductMaster.avg_price.is_not(None))
            .limit(10)
        )
        products_with_price = result.all()
        
        invalid_prices = 0
        for avg_price, min_price, max_price, quantity in products_with_price:
            if avg_price and min_price and max_price:
                if not (min_price <= avg_price <= max_price):
                    invalid_prices += 1
        
        if invalid_prices > 0:
            self.log_test("价格范围有效性", False, f"{invalid_prices} 个产品价格范围错误")
        else:
            self.log_test("价格范围有效性", True)
        
        # 检查建议采购价
        result = await db.execute(
            select(func.count(ProcurementOrder.id))
            .where(and_(
                ProcurementOrder.建议采购价.is_not(None),
                ProcurementOrder.建议采购价 > 0
            ))
        )
        with_suggested_price = result.scalar()
        
        result = await db.execute(select(func.count(ProcurementOrder.id)))
        total_procurement = result.scalar()
        
        if total_procurement > 0:
            percentage = (with_suggested_price / total_procurement) * 100
            print(f"  建议采购价覆盖率: {percentage:.1f}%")
            if percentage > 50:
                self.log_test("建议采购价", True, f"覆盖率 {percentage:.1f}%")
            else:
                self.log_test("建议采购价", False, f"覆盖率仅 {percentage:.1f}%")
    
    async def test_data_relationships(self, db):
        """测试数据关系完整性"""
        print("\n📋 测试6: 数据关系完整性")
        
        # 检查孤立的采购订单（没有对应原始订单）
        orphaned_query = """
            SELECT COUNT(*)
            FROM procurement_orders po
            LEFT JOIN raw_orders ro ON po.original_order_id = ro.id
            WHERE ro.id IS NULL
        """
        
        from sqlalchemy import text
        result = await db.execute(text(orphaned_query))
        orphaned_count = result.scalar()
        
        if orphaned_count > 0:
            self.log_test("采购订单关联", False, f"{orphaned_count} 个孤立采购订单")
        else:
            self.log_test("采购订单关联", True)
        
        # 检查产品主表和原始订单的映射关系
        mapping_query = """
            SELECT COUNT(DISTINCT pm.id) as product_count,
                   COUNT(DISTINCT psm.raw_order_id) as order_count
            FROM products_master pm
            LEFT JOIN product_source_mappings psm ON pm.id = psm.product_master_id
        """
        
        result = await db.execute(text(mapping_query))
        mapping_stats = result.fetchone()
        
        if mapping_stats:
            product_count, order_count = mapping_stats
            if order_count > 0:
                avg_orders_per_product = order_count / product_count if product_count > 0 else 0
                print(f"  产品映射统计:")
                print(f"    - 产品数: {product_count}")
                print(f"    - 映射订单数: {order_count}")
                print(f"    - 平均每产品订单数: {avg_orders_per_product:.1f}")
                self.log_test("产品映射关系", True)
            else:
                self.log_test("产品映射关系", False, "无映射关系")
    
    async def test_fifo_logic(self, db):
        """测试FIFO（先进先出）逻辑"""
        print("\n📋 测试7: FIFO逻辑")
        
        # 获取同一产品的多个采购订单，按付款时间排序
        fifo_query = """
            SELECT product_key, 原始订单编号, 付款时间, procurement_status
            FROM procurement_orders
            WHERE product_key IS NOT NULL
            ORDER BY product_key, 付款时间
        """
        
        from sqlalchemy import text
        result = await db.execute(text(fifo_query))
        orders = result.fetchall()
        
        if not orders:
            self.log_warning("没有足够的数据测试FIFO")
            return
        
        # 按产品分组检查FIFO
        products_fifo = {}
        for product_key, order_id, payment_time, status in orders:
            if product_key not in products_fifo:
                products_fifo[product_key] = []
            products_fifo[product_key].append({
                'order_id': order_id,
                'payment_time': payment_time,
                'status': status
            })
        
        # 检查是否有违反FIFO的情况
        fifo_violations = 0
        for product_key, product_orders in products_fifo.items():
            if len(product_orders) < 2:
                continue
            
            # 检查是否有较早的订单还是PENDING，而较晚的订单已经ORDERED/RECEIVED
            for i in range(len(product_orders) - 1):
                for j in range(i + 1, len(product_orders)):
                    if (product_orders[i]['status'] == 'PENDING' and 
                        product_orders[j]['status'] in ['ORDERED', 'RECEIVED']):
                        fifo_violations += 1
                        break
        
        if fifo_violations > 0:
            self.log_warning(f"发现 {fifo_violations} 个可能违反FIFO的情况")
        else:
            self.log_test("FIFO逻辑", True)
    
    def print_summary(self):
        """打印测试总结"""
        print("\n" + "=" * 60)
        print("📊 业务逻辑测试总结")
        print("=" * 60)
        
        total = self.passed + self.failed
        if total > 0:
            pass_rate = (self.passed / total) * 100
            print(f"测试结果:")
            print(f"  - 通过: {self.passed}")
            print(f"  - 失败: {self.failed}")
            print(f"  - 警告: {self.warnings}")
            print(f"  - 通过率: {pass_rate:.1f}%")
            
            if pass_rate == 100 and self.warnings == 0:
                print("\n🎉 所有业务逻辑测试通过，无警告！")
            elif pass_rate == 100:
                print(f"\n✅ 所有测试通过，但有 {self.warnings} 个警告需要关注")
            elif pass_rate >= 80:
                print("\n✅ 大部分业务逻辑正常")
            else:
                print("\n⚠️ 多个业务逻辑测试失败，需要修复")
        
        print("=" * 60)


async def main():
    """主测试函数"""
    print("=" * 60)
    print("🚀 开始业务逻辑测试")
    print(f"时间: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print("=" * 60)
    
    tests = BusinessLogicTests()
    
    async with AsyncSessionLocal() as db:
        await tests.test_order_deduplication(db)
        await tests.test_product_sku_generation(db)
        await tests.test_procurement_method_assignment(db)
        await tests.test_order_status_consistency(db)
        await tests.test_price_calculations(db)
        await tests.test_data_relationships(db)
        await tests.test_fifo_logic(db)
    
    tests.print_summary()


if __name__ == "__main__":
    asyncio.run(main())