#!/usr/bin/env python3
"""
测试修复后的价格计算逻辑
区分合并订单和单品订单进行验证
"""

import asyncio
import logging
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.database import get_db

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


async def test_price_calculations(db: AsyncSession):
    """测试价格计算的准确性"""
    
    logger.info("=" * 60)
    logger.info("测试修复后的价格计算")
    logger.info("=" * 60)
    
    # 1. 测试单品订单的价格计算
    logger.info("\n1. 单品订单价格验证:")
    single_order_test = text("""
        SELECT 
            COUNT(*) as total_single_orders,
            SUM(CASE 
                WHEN ABS(商品小计 - (CAST(数量 AS FLOAT) * CAST(订单单价 AS FLOAT))) > 0.01 
                THEN 1 ELSE 0 
            END) as calculation_errors,
            SUM(CASE 
                WHEN 商品小计 IS NULL 
                THEN 1 ELSE 0 
            END) as null_subtotals
        FROM raw_orders
        WHERE 是否合并订单 = 0
        AND 数量 IS NOT NULL 
        AND 订单单价 IS NOT NULL
        AND 数量 != ''
        AND 订单单价 != ''
    """)
    
    result = await db.execute(single_order_test)
    single_stats = result.first()
    
    logger.info(f"  总单品订单数: {single_stats[0]}")
    logger.info(f"  计算错误数: {single_stats[1]}")
    logger.info(f"  NULL小计数: {single_stats[2]}")
    
    if single_stats[1] == 0:
        logger.info("  ✅ 所有单品订单价格计算正确!")
    else:
        logger.warning(f"  ⚠️ 发现 {single_stats[1]} 个单品订单计算错误")
        
        # 显示错误示例
        error_examples = text("""
            SELECT 
                原始订单编号,
                线上宝贝名称,
                数量,
                订单单价,
                商品小计,
                CAST(数量 AS FLOAT) * CAST(订单单价 AS FLOAT) as 应计算值
            FROM raw_orders
            WHERE 是否合并订单 = 0
            AND 数量 IS NOT NULL 
            AND 订单单价 IS NOT NULL
            AND ABS(商品小计 - (CAST(数量 AS FLOAT) * CAST(订单单价 AS FLOAT))) > 0.01
            LIMIT 5
        """)
        
        result = await db.execute(error_examples)
        errors = result.fetchall()
        
        for row in errors:
            logger.warning(f"    订单: {row[0][:20]}...")
            logger.warning(f"      商品: {row[1][:30]}...")
            logger.warning(f"      计算: {row[2]} × {row[3]} = {row[5]}")
            logger.warning(f"      实际小计: {row[4]}")
    
    # 2. 测试合并订单的处理
    logger.info("\n2. 合并订单验证:")
    merged_order_test = text("""
        SELECT 
            COUNT(DISTINCT 原始订单编号) as unique_merged_orders,
            COUNT(*) as total_merged_records,
            SUM(CASE WHEN 商品小计 IS NULL THEN 1 ELSE 0 END) as null_subtotals,
            AVG(CASE WHEN 商品小计 IS NOT NULL THEN 商品小计 ELSE 0 END) as avg_subtotal
        FROM raw_orders
        WHERE 是否合并订单 = 1
    """)
    
    result = await db.execute(merged_order_test)
    merged_stats = result.first()
    
    logger.info(f"  合并订单数: {merged_stats[0]}")
    logger.info(f"  合并订单记录数: {merged_stats[1]}")
    logger.info(f"  平均每个订单商品数: {merged_stats[1]/merged_stats[0] if merged_stats[0] > 0 else 0:.1f}")
    logger.info(f"  NULL小计数: {merged_stats[2]}")
    logger.info(f"  平均商品小计: {merged_stats[3]:.2f}")
    
    # 3. 验证合并订单的总金额一致性
    logger.info("\n3. 合并订单总金额一致性检查:")
    consistency_check = text("""
        WITH order_totals AS (
            SELECT 
                原始订单编号,
                SUM(商品小计) as 计算总金额,
                MAX(CAST(订单金额 AS FLOAT)) as 标记总金额,
                COUNT(*) as 商品数量
            FROM raw_orders
            WHERE 是否合并订单 = 1
            AND 商品小计 IS NOT NULL
            GROUP BY 原始订单编号
        )
        SELECT 
            COUNT(*) as total_orders,
            SUM(CASE 
                WHEN ABS(计算总金额 - 标记总金额) < 0.01 
                THEN 1 ELSE 0 
            END) as consistent_orders,
            AVG(ABS(计算总金额 - 标记总金额)) as avg_difference
        FROM order_totals
    """)
    
    result = await db.execute(consistency_check)
    consistency = result.first()
    
    if consistency and consistency[0] > 0:
        consistency_rate = (consistency[1] / consistency[0]) * 100
        logger.info(f"  检查的合并订单数: {consistency[0]}")
        logger.info(f"  金额一致的订单数: {consistency[1]}")
        logger.info(f"  一致率: {consistency_rate:.1f}%")
        logger.info(f"  平均差异: {consistency[2]:.2f}")
        
        if consistency_rate == 100:
            logger.info("  ✅ 所有合并订单总金额完全一致!")
        elif consistency_rate > 95:
            logger.info("  ✅ 合并订单总金额基本一致")
        else:
            logger.warning(f"  ⚠️ 合并订单总金额一致性较低: {consistency_rate:.1f}%")
    
    # 4. 原始505个问题订单的验证
    logger.info("\n4. 验证原始问题订单 (数量×单价≠订单金额):")
    original_problem_check = text("""
        SELECT 
            COUNT(*) as total_records,
            SUM(CASE 
                WHEN 是否合并订单 = 1 THEN 1 ELSE 0 
            END) as merged_records,
            SUM(CASE 
                WHEN 是否合并订单 = 0 
                AND ABS(CAST(数量 AS FLOAT) * CAST(订单单价 AS FLOAT) - CAST(订单金额 AS FLOAT)) > 0.01
                THEN 1 ELSE 0 
            END) as single_order_errors
        FROM raw_orders
        WHERE 数量 IS NOT NULL 
        AND 订单单价 IS NOT NULL 
        AND 订单金额 IS NOT NULL
        AND 数量 != ''
        AND 订单单价 != ''
        AND 订单金额 != ''
    """)
    
    result = await db.execute(original_problem_check)
    problem_stats = result.first()
    
    logger.info(f"  总记录数: {problem_stats[0]}")
    logger.info(f"  识别为合并订单: {problem_stats[1]}")
    logger.info(f"  单品订单计算错误: {problem_stats[2]}")
    
    if problem_stats[2] == 0:
        logger.info("  ✅ 原始505个问题已全部解决!")
    else:
        logger.warning(f"  ⚠️ 仍有 {problem_stats[2]} 个单品订单存在问题")
    
    # 5. 总体统计
    logger.info("\n5. 总体统计:")
    overall_stats = text("""
        SELECT 
            COUNT(DISTINCT 原始订单编号) as total_orders,
            COUNT(*) as total_records,
            SUM(CASE WHEN 是否合并订单 = 1 THEN 1 ELSE 0 END) as merged_records,
            SUM(CASE WHEN 是否合并订单 = 0 THEN 1 ELSE 0 END) as single_records,
            SUM(CASE WHEN 商品小计 IS NOT NULL THEN 1 ELSE 0 END) as has_subtotal,
            SUM(CASE WHEN 商品小计 IS NULL THEN 1 ELSE 0 END) as no_subtotal
        FROM raw_orders
    """)
    
    result = await db.execute(overall_stats)
    overall = result.first()
    
    logger.info(f"  总订单数: {overall[0]}")
    logger.info(f"  总记录数: {overall[1]}")
    logger.info(f"  合并订单记录: {overall[2]} ({overall[2]/overall[1]*100:.1f}%)")
    logger.info(f"  单品订单记录: {overall[3]} ({overall[3]/overall[1]*100:.1f}%)")
    logger.info(f"  有商品小计: {overall[4]} ({overall[4]/overall[1]*100:.1f}%)")
    logger.info(f"  无商品小计: {overall[5]} ({overall[5]/overall[1]*100:.1f}%)")


async def main():
    """主函数"""
    async for db in get_db():
        try:
            await test_price_calculations(db)
            
            logger.info("\n" + "=" * 60)
            logger.info("测试完成")
            logger.info("=" * 60)
            
        except Exception as e:
            logger.error(f"测试失败: {e}", exc_info=True)
        finally:
            await db.close()
            break


if __name__ == "__main__":
    asyncio.run(main())