#!/usr/bin/env python3
"""
综合测试系统核心逻辑
包括：导入、去重、解析、产品主表生成、采购订单生成等
"""

import asyncio
import logging
from datetime import datetime, timedelta
from typing import Dict, Any, List
import json
from pathlib import Path

from sqlalchemy import select, text, func
from sqlalchemy.ext.asyncio import AsyncSession

# 设置日志
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class CoreLogicTester:
    """核心逻辑测试器"""
    
    def __init__(self):
        self.test_results = []
        self.passed = 0
        self.failed = 0
        self.warnings = 0
    
    def add_result(self, test_name: str, status: str, message: str = None, details: Dict = None):
        """添加测试结果"""
        result = {
            'test': test_name,
            'status': status,
            'message': message,
            'details': details,
            'timestamp': datetime.now().isoformat()
        }
        self.test_results.append(result)
        
        if status == 'PASS':
            self.passed += 1
            logger.info(f"✅ {test_name}: PASS {message or ''}")
        elif status == 'FAIL':
            self.failed += 1
            logger.error(f"❌ {test_name}: FAIL {message or ''}")
        elif status == 'WARNING':
            self.warnings += 1
            logger.warning(f"⚠️ {test_name}: WARNING {message or ''}")
    
    async def test_database_connection(self):
        """测试数据库连接"""
        try:
            from backend.app.core.database import get_db
            
            async for db in get_db():
                # 测试基本查询
                result = await db.execute(text("SELECT 1"))
                if result.scalar() == 1:
                    self.add_result("数据库连接", "PASS", "数据库连接正常")
                else:
                    self.add_result("数据库连接", "FAIL", "数据库查询异常")
                await db.close()
                break
        except Exception as e:
            self.add_result("数据库连接", "FAIL", f"无法连接数据库: {str(e)}")
    
    async def test_import_deduplication(self, db: AsyncSession):
        """测试导入去重逻辑"""
        try:
            # 检查去重约束
            result = await db.execute(text("""
                SELECT COUNT(*) as total_orders,
                       COUNT(DISTINCT 原始订单编号 || '-' || 线上宝贝名称 || '-' || 
                             线上销售属性 || '-' || 数量 || '-' || 付款时间) as unique_orders
                FROM raw_orders
            """))
            stats = result.first()
            
            if stats[0] == stats[1]:
                self.add_result("导入去重", "PASS", 
                              f"去重逻辑正常，共{stats[0]}条唯一记录")
            else:
                self.add_result("导入去重", "WARNING", 
                              f"可能存在重复: 总记录{stats[0]}, 唯一记录{stats[1]}")
                
            # 检查重复记录详情
            dup_check = await db.execute(text("""
                SELECT 原始订单编号, COUNT(*) as cnt
                FROM raw_orders
                GROUP BY 原始订单编号, 线上宝贝名称, 线上销售属性, 数量, 付款时间
                HAVING COUNT(*) > 1
                LIMIT 5
            """))
            dups = dup_check.fetchall()
            
            if dups:
                self.add_result("重复记录检查", "WARNING", 
                              f"发现{len(dups)}组重复记录", 
                              {'duplicates': [dict(row) for row in dups]})
            
        except Exception as e:
            self.add_result("导入去重", "FAIL", f"测试失败: {str(e)}")
    
    async def test_text_parsing(self, db: AsyncSession):
        """测试文本解析逻辑"""
        try:
            # 测试品牌提取
            brand_test = await db.execute(text("""
                SELECT COUNT(*) as total,
                       SUM(CASE WHEN brand IS NOT NULL AND brand != '' THEN 1 ELSE 0 END) as has_brand,
                       COUNT(DISTINCT brand) as unique_brands
                FROM products_master
            """))
            brand_stats = brand_test.first()
            
            if brand_stats and brand_stats[0] > 0:
                brand_rate = (brand_stats[1] / brand_stats[0]) * 100
                self.add_result("品牌解析", "PASS" if brand_rate > 90 else "WARNING",
                              f"品牌提取率: {brand_rate:.1f}%, 唯一品牌数: {brand_stats[2]}")
            
            # 测试颜色提取
            color_test = await db.execute(text("""
                SELECT COUNT(*) as total,
                       SUM(CASE WHEN color IS NOT NULL AND color != '' THEN 1 ELSE 0 END) as has_color
                FROM products_master
            """))
            color_stats = color_test.first()
            
            if color_stats and color_stats[0] > 0:
                color_rate = (color_stats[1] / color_stats[0]) * 100
                self.add_result("颜色解析", "PASS" if color_rate > 80 else "WARNING",
                              f"颜色提取率: {color_rate:.1f}%")
            
            # 测试尺码提取
            size_test = await db.execute(text("""
                SELECT COUNT(*) as total,
                       SUM(CASE WHEN size IS NOT NULL AND size != '' THEN 1 ELSE 0 END) as has_size
                FROM products_master
            """))
            size_stats = size_test.first()
            
            if size_stats and size_stats[0] > 0:
                size_rate = (size_stats[1] / size_stats[0]) * 100
                self.add_result("尺码解析", "PASS" if size_rate > 70 else "WARNING",
                              f"尺码提取率: {size_rate:.1f}%")
            
        except Exception as e:
            self.add_result("文本解析", "FAIL", f"测试失败: {str(e)}")
    
    async def test_product_master_generation(self, db: AsyncSession):
        """测试产品主表生成逻辑"""
        try:
            # 检查产品主表完整性
            master_check = await db.execute(text("""
                SELECT 
                    COUNT(*) as total_products,
                    COUNT(DISTINCT product_key) as unique_keys,
                    SUM(CASE WHEN product_name IS NULL THEN 1 ELSE 0 END) as null_names,
                    SUM(CASE WHEN product_key IS NULL THEN 1 ELSE 0 END) as null_keys,
                    SUM(CASE WHEN total_quantity IS NULL OR total_quantity = 0 THEN 1 ELSE 0 END) as zero_qty
                FROM products_master
            """))
            stats = master_check.first()
            
            if stats:
                if stats[0] == stats[1] and stats[3] == 0:
                    self.add_result("产品主表生成", "PASS",
                                  f"生成{stats[0]}个产品，所有product_key唯一")
                else:
                    self.add_result("产品主表生成", "WARNING",
                                  f"产品数: {stats[0]}, 唯一key: {stats[1]}, NULL key: {stats[3]}")
                
                if stats[4] > 0:
                    self.add_result("产品数量检查", "WARNING",
                                  f"{stats[4]}个产品数量为0或NULL")
            
            # 检查SKU生成逻辑
            sku_check = await db.execute(text("""
                SELECT COUNT(DISTINCT product_key) as unique_skus,
                       MIN(LENGTH(product_key)) as min_len,
                       MAX(LENGTH(product_key)) as max_len
                FROM products_master
                WHERE product_key IS NOT NULL
            """))
            sku_stats = sku_check.first()
            
            if sku_stats and sku_stats[1] == 64 and sku_stats[2] == 64:
                self.add_result("SKU生成", "PASS", 
                              f"所有SKU长度正确(64位SHA-256)")
            else:
                self.add_result("SKU生成", "WARNING",
                              f"SKU长度异常: 最小{sku_stats[1]}, 最大{sku_stats[2]}")
            
        except Exception as e:
            self.add_result("产品主表生成", "FAIL", f"测试失败: {str(e)}")
    
    async def test_procurement_order_generation(self, db: AsyncSession):
        """测试采购订单生成逻辑"""
        try:
            # 检查采购订单状态分布
            status_check = await db.execute(text("""
                SELECT 
                    procurement_status,
                    COUNT(*) as count
                FROM procurement_orders
                GROUP BY procurement_status
            """))
            status_dist = {row[0]: row[1] for row in status_check.fetchall()}
            
            total_orders = sum(status_dist.values())
            if total_orders > 0:
                self.add_result("采购订单生成", "PASS",
                              f"共生成{total_orders}个采购订单",
                              {'status_distribution': status_dist})
            else:
                self.add_result("采购订单生成", "WARNING", "未找到采购订单")
            
            # 检查采购方式分布
            method_check = await db.execute(text("""
                SELECT 
                    procurement_method,
                    COUNT(*) as count
                FROM procurement_orders
                GROUP BY procurement_method
                ORDER BY count DESC
            """))
            method_dist = {row[0]: row[1] for row in method_check.fetchall()}
            
            if method_dist:
                self.add_result("采购方式分配", "PASS",
                              f"采购方式分布正常",
                              {'method_distribution': method_dist})
            
            # 检查FIFO逻辑
            fifo_check = await db.execute(text("""
                SELECT COUNT(*) as total
                FROM procurement_orders p1
                WHERE EXISTS (
                    SELECT 1 FROM procurement_orders p2
                    WHERE p2.product_key = p1.product_key
                    AND p2.付款时间 < p1.付款时间
                    AND p2.procurement_status = 'PENDING'
                    AND p1.procurement_status != 'PENDING'
                )
            """))
            fifo_violations = fifo_check.scalar()
            
            if fifo_violations == 0:
                self.add_result("FIFO逻辑", "PASS", "FIFO顺序正确")
            else:
                self.add_result("FIFO逻辑", "WARNING",
                              f"发现{fifo_violations}个FIFO违规")
            
        except Exception as e:
            self.add_result("采购订单生成", "FAIL", f"测试失败: {str(e)}")
    
    async def test_merged_order_handling(self, db: AsyncSession):
        """测试合并订单处理逻辑"""
        try:
            # 检查合并订单识别
            merged_check = await db.execute(text("""
                SELECT 
                    COUNT(DISTINCT 原始订单编号) as merged_orders,
                    COUNT(*) as merged_records,
                    AVG(cnt) as avg_items_per_order
                FROM (
                    SELECT 原始订单编号, COUNT(*) as cnt
                    FROM raw_orders
                    WHERE 是否合并订单 = 1
                    GROUP BY 原始订单编号
                ) t
            """))
            merged_stats = merged_check.first()
            
            if merged_stats and merged_stats[0] > 0:
                self.add_result("合并订单识别", "PASS",
                              f"识别{merged_stats[0]}个合并订单，平均{merged_stats[2]:.1f}个商品/订单")
            
            # 检查商品小计计算
            subtotal_check = await db.execute(text("""
                SELECT 
                    COUNT(*) as total,
                    SUM(CASE WHEN 商品小计 IS NOT NULL THEN 1 ELSE 0 END) as has_subtotal,
                    SUM(CASE 
                        WHEN 商品小计 IS NOT NULL 
                        AND ABS(商品小计 - (CAST(数量 AS FLOAT) * CAST(订单单价 AS FLOAT))) < 0.01
                        THEN 1 ELSE 0 
                    END) as correct_subtotal
                FROM raw_orders
                WHERE 数量 IS NOT NULL AND 订单单价 IS NOT NULL
            """))
            subtotal_stats = subtotal_check.first()
            
            if subtotal_stats and subtotal_stats[0] > 0:
                accuracy = (subtotal_stats[2] / subtotal_stats[0]) * 100
                self.add_result("商品小计计算", "PASS" if accuracy > 99 else "WARNING",
                              f"计算准确率: {accuracy:.1f}%")
            
        except Exception as e:
            self.add_result("合并订单处理", "FAIL", f"测试失败: {str(e)}")
    
    async def test_data_integrity(self, db: AsyncSession):
        """测试数据完整性"""
        try:
            # 检查原始订单与采购订单的关联
            orphan_check = await db.execute(text("""
                SELECT COUNT(*) as orphan_orders
                FROM procurement_orders
                WHERE original_order_id NOT IN (SELECT id FROM raw_orders)
            """))
            orphan_count = orphan_check.scalar()
            
            if orphan_count == 0:
                self.add_result("数据关联完整性", "PASS", "所有采购订单都有对应的原始订单")
            else:
                self.add_result("数据关联完整性", "FAIL",
                              f"发现{orphan_count}个孤立的采购订单")
            
            # 检查产品主表与原始订单的一致性
            consistency_check = await db.execute(text("""
                SELECT 
                    COUNT(DISTINCT ro.线上宝贝名称) as raw_products,
                    COUNT(DISTINCT pm.product_name) as master_products
                FROM raw_orders ro
                LEFT JOIN products_master pm ON pm.product_name = ro.线上宝贝名称
            """))
            consistency = consistency_check.first()
            
            if consistency:
                self.add_result("产品名称一致性", "INFO",
                              f"原始产品数: {consistency[0]}, 主表产品数: {consistency[1]}")
            
            # 检查必需字段完整性
            required_fields_check = await db.execute(text("""
                SELECT 
                    SUM(CASE WHEN 原始订单编号 IS NULL THEN 1 ELSE 0 END) as null_order_id,
                    SUM(CASE WHEN 线上宝贝名称 IS NULL THEN 1 ELSE 0 END) as null_product,
                    SUM(CASE WHEN 付款时间 IS NULL THEN 1 ELSE 0 END) as null_payment_time
                FROM raw_orders
            """))
            nulls = required_fields_check.first()
            
            if nulls and sum(nulls) == 0:
                self.add_result("必需字段完整性", "PASS", "所有必需字段都有值")
            else:
                self.add_result("必需字段完整性", "WARNING",
                              f"NULL值: 订单号{nulls[0]}, 产品名{nulls[1]}, 付款时间{nulls[2]}")
            
        except Exception as e:
            self.add_result("数据完整性", "FAIL", f"测试失败: {str(e)}")
    
    async def test_performance_metrics(self, db: AsyncSession):
        """测试性能指标"""
        try:
            # 测试查询性能
            start_time = datetime.now()
            await db.execute(text("""
                SELECT COUNT(*) FROM raw_orders
                WHERE 付款时间 >= date('now', '-30 days')
            """))
            query_time = (datetime.now() - start_time).total_seconds()
            
            if query_time < 1:
                self.add_result("查询性能", "PASS", f"30天订单查询: {query_time:.3f}秒")
            else:
                self.add_result("查询性能", "WARNING", f"查询较慢: {query_time:.3f}秒")
            
            # 检查索引使用
            index_check = await db.execute(text("""
                SELECT name FROM sqlite_master 
                WHERE type = 'index' AND tbl_name = 'raw_orders'
            """))
            indexes = [row[0] for row in index_check.fetchall()]
            
            if len(indexes) >= 3:
                self.add_result("索引配置", "PASS", f"配置了{len(indexes)}个索引")
            else:
                self.add_result("索引配置", "WARNING", f"索引较少: {len(indexes)}个")
            
        except Exception as e:
            self.add_result("性能测试", "FAIL", f"测试失败: {str(e)}")
    
    async def test_business_rules(self, db: AsyncSession):
        """测试业务规则"""
        try:
            # 测试价格合理性
            price_check = await db.execute(text("""
                SELECT 
                    COUNT(*) as total,
                    SUM(CASE WHEN CAST(订单单价 AS FLOAT) < 0 THEN 1 ELSE 0 END) as negative_price,
                    SUM(CASE WHEN CAST(订单单价 AS FLOAT) > 100000 THEN 1 ELSE 0 END) as excessive_price
                FROM raw_orders
                WHERE 订单单价 IS NOT NULL
            """))
            price_stats = price_check.first()
            
            if price_stats and price_stats[1] == 0 and price_stats[2] == 0:
                self.add_result("价格合理性", "PASS", "所有价格在合理范围内")
            else:
                self.add_result("价格合理性", "WARNING",
                              f"负价格: {price_stats[1]}, 超高价格: {price_stats[2]}")
            
            # 测试时间逻辑
            time_check = await db.execute(text("""
                SELECT COUNT(*) as invalid_time
                FROM raw_orders
                WHERE 付款时间 > datetime('now')
                   OR 付款时间 < datetime('2020-01-01')
            """))
            invalid_time = time_check.scalar()
            
            if invalid_time == 0:
                self.add_result("时间逻辑", "PASS", "所有订单时间合理")
            else:
                self.add_result("时间逻辑", "WARNING", f"{invalid_time}个订单时间异常")
            
            # 测试数量合理性
            qty_check = await db.execute(text("""
                SELECT 
                    COUNT(*) as total,
                    SUM(CASE WHEN 数量 <= 0 THEN 1 ELSE 0 END) as zero_qty,
                    SUM(CASE WHEN 数量 > 1000 THEN 1 ELSE 0 END) as excessive_qty
                FROM raw_orders
                WHERE 数量 IS NOT NULL
            """))
            qty_stats = qty_check.first()
            
            if qty_stats and qty_stats[1] == 0:
                self.add_result("数量合理性", "PASS", "所有订单数量合理")
            else:
                self.add_result("数量合理性", "WARNING",
                              f"零数量: {qty_stats[1]}, 超大数量: {qty_stats[2]}")
            
        except Exception as e:
            self.add_result("业务规则", "FAIL", f"测试失败: {str(e)}")
    
    async def test_status_synchronization(self, db: AsyncSession):
        """测试状态同步逻辑"""
        try:
            # 检查原始订单和采购订单的状态同步
            sync_check = await db.execute(text("""
                SELECT 
                    COUNT(*) as total,
                    SUM(CASE WHEN ro.交易状态 != po.交易状态 THEN 1 ELSE 0 END) as trade_mismatch,
                    SUM(CASE WHEN ro.退款状态 != po.退款状态 THEN 1 ELSE 0 END) as refund_mismatch
                FROM procurement_orders po
                JOIN raw_orders ro ON po.original_order_id = ro.id
            """))
            sync_stats = sync_check.first()
            
            if sync_stats and sync_stats[0] > 0:
                sync_rate = ((sync_stats[0] - sync_stats[1] - sync_stats[2]) / sync_stats[0]) * 100
                if sync_rate == 100:
                    self.add_result("状态同步", "PASS", "订单状态完全同步")
                else:
                    self.add_result("状态同步", "WARNING",
                                  f"同步率: {sync_rate:.1f}%, 交易状态不一致: {sync_stats[1]}, 退款状态不一致: {sync_stats[2]}")
            
            # 检查已取消订单的处理
            cancelled_check = await db.execute(text("""
                SELECT COUNT(*) as cancelled_but_active
                FROM procurement_orders po
                JOIN raw_orders ro ON po.original_order_id = ro.id
                WHERE ro.交易状态 = '已取消'
                  AND po.procurement_status = 'PENDING'
            """))
            cancelled_active = cancelled_check.scalar()
            
            if cancelled_active > 0:
                self.add_result("已取消订单处理", "INFO",
                              f"{cancelled_active}个已取消订单仍在待采购状态（按需求保留）")
            
        except Exception as e:
            self.add_result("状态同步", "FAIL", f"测试失败: {str(e)}")
    
    def print_summary(self):
        """打印测试总结"""
        print("\n" + "=" * 80)
        print("📊 核心逻辑测试总结")
        print("=" * 80)
        
        # 按状态分组
        by_status = {}
        for result in self.test_results:
            status = result['status']
            if status not in by_status:
                by_status[status] = []
            by_status[status].append(result)
        
        # 显示各状态的测试
        for status in ['PASS', 'WARNING', 'FAIL', 'INFO']:
            if status in by_status:
                print(f"\n{status}:")
                for test in by_status[status]:
                    icon = {'PASS': '✅', 'FAIL': '❌', 'WARNING': '⚠️', 'INFO': 'ℹ️'}[status]
                    print(f"  {icon} {test['test']}: {test['message'] or ''}")
                    if test.get('details'):
                        for key, value in test['details'].items():
                            print(f"      {key}: {value}")
        
        # 总体统计
        total = self.passed + self.failed + self.warnings
        if total > 0:
            pass_rate = (self.passed / total) * 100
            print(f"\n" + "-" * 80)
            print(f"总测试数: {total}")
            print(f"通过: {self.passed} ({self.passed/total*100:.1f}%)")
            print(f"警告: {self.warnings} ({self.warnings/total*100:.1f}%)")
            print(f"失败: {self.failed} ({self.failed/total*100:.1f}%)")
            
            if pass_rate == 100:
                print("\n🎉 所有核心逻辑测试通过！")
            elif pass_rate >= 80:
                print("\n✅ 大部分核心逻辑正常")
            elif self.failed == 0:
                print("\n⚠️ 存在一些警告，但核心功能正常")
            else:
                print("\n❌ 发现严重问题，需要修复")
        
        print("=" * 80)
        
        # 保存测试报告
        report_file = Path("test_core_logic_report.json")
        with open(report_file, 'w', encoding='utf-8') as f:
            json.dump({
                'summary': {
                    'total': total,
                    'passed': self.passed,
                    'warnings': self.warnings,
                    'failed': self.failed,
                    'pass_rate': pass_rate if total > 0 else 0
                },
                'results': self.test_results,
                'timestamp': datetime.now().isoformat()
            }, f, ensure_ascii=False, indent=2)
        print(f"\n测试报告已保存到: {report_file}")


async def main():
    """主测试函数"""
    print("=" * 80)
    print("🚀 开始核心逻辑综合测试")
    print(f"时间: {datetime.now():%Y-%m-%d %H:%M:%S}")
    print("=" * 80)
    
    tester = CoreLogicTester()
    
    # 先测试数据库连接
    await tester.test_database_connection()
    
    if tester.failed == 0:
        # 如果数据库连接成功，继续其他测试
        from backend.app.core.database import get_db
        
        async for db in get_db():
            try:
                print("\n⏳ 正在执行核心逻辑测试...")
                
                # 执行各项测试
                await tester.test_import_deduplication(db)
                await tester.test_text_parsing(db)
                await tester.test_product_master_generation(db)
                await tester.test_procurement_order_generation(db)
                await tester.test_merged_order_handling(db)
                await tester.test_data_integrity(db)
                await tester.test_performance_metrics(db)
                await tester.test_business_rules(db)
                await tester.test_status_synchronization(db)
                
            except Exception as e:
                logger.error(f"测试过程出错: {e}", exc_info=True)
            finally:
                await db.close()
                break
    
    # 打印总结
    tester.print_summary()


if __name__ == "__main__":
    asyncio.run(main())