#!/usr/bin/env python3
"""
测试多SKU订单导出时的"已推送"标记逻辑
验证：导出一个SKU时，是否会误标记同一订单的其他SKU
"""

import asyncio
import sys
import os
from pathlib import Path

# 添加项目路径
project_dir = Path(__file__).parent / "backend"
sys.path.insert(0, str(project_dir))
os.chdir(str(Path(__file__).parent))

from sqlalchemy import select, and_, func
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime

from backend.app.core.database import get_db
from backend.app.models.procurement_orders import ProcurementOrder, ProcurementStatus
from backend.app.services.procurement_aggregation_service import ProcurementAggregationService


class MultiSkuExportTester:
    """多SKU订单导出测试器"""

    def __init__(self):
        self.service = ProcurementAggregationService()
        self.test_order_number = None
        self.original_states = []

    async def find_multi_sku_order(self, db: AsyncSession):
        """查找具有多个SKU的订单"""
        print("\n=== 步骤1: 查找多SKU订单 ===")

        # 查找有多个SKU且都未推送的订单
        subquery = select(
            ProcurementOrder.原始订单编号,
            func.count(ProcurementOrder.id).label('sku_count')
        ).where(
            and_(
                ProcurementOrder.已推送 == '否',
                ProcurementOrder.procurement_status == ProcurementStatus.PENDING
            )
        ).group_by(
            ProcurementOrder.原始订单编号
        ).having(
            func.count(ProcurementOrder.id) > 1
        ).subquery()

        # 获取第一个符合条件的订单
        query = select(subquery.c.原始订单编号, subquery.c.sku_count).order_by(
            subquery.c.sku_count.desc()
        ).limit(1)

        result = await db.execute(query)
        order = result.first()

        if order:
            self.test_order_number = order.原始订单编号
            print(f"✓ 找到测试订单: {self.test_order_number}")
            print(f"  该订单包含 {order.sku_count} 个SKU")
            return self.test_order_number
        else:
            print("✗ 未找到合适的多SKU订单")
            return None

    async def get_order_skus(self, db: AsyncSession):
        """获取订单的所有SKU详情"""
        print(f"\n=== 步骤2: 获取订单 {self.test_order_number} 的所有SKU ===")

        query = select(ProcurementOrder).where(
            ProcurementOrder.原始订单编号 == self.test_order_number
        ).order_by(ProcurementOrder.id)

        result = await db.execute(query)
        skus = result.scalars().all()

        self.original_states = []
        for sku in skus:
            state = {
                'id': sku.id,
                '商品名称': sku.线上宝贝名称,
                '销售属性': sku.线上销售属性,
                '颜色': sku.颜色,
                '尺寸': sku.尺寸,
                '数量': sku.数量,
                '已推送': sku.已推送,
                'procurement_status': sku.procurement_status
            }
            self.original_states.append(state)
            print(f"  SKU {sku.id}: {sku.线上宝贝名称}")
            print(f"    销售属性: {sku.线上销售属性}")
            print(f"    颜色: {sku.颜色}, 尺寸: {sku.尺寸}")
            print(f"    已推送: {sku.已推送}, 状态: {sku.procurement_status}")

        return self.original_states

    async def simulate_export_single_sku(self, db: AsyncSession):
        """模拟导出单个SKU"""
        if not self.original_states:
            print("✗ 没有SKU数据")
            return

        print(f"\n=== 步骤3: 模拟导出第一个SKU ===")

        # 选择第一个SKU进行导出
        first_sku = self.original_states[0]
        print(f"准备导出 SKU ID={first_sku['id']}: {first_sku['商品名称']}")
        print(f"  销售属性: {first_sku['销售属性']}")

        # 构造导出数据（模拟前端传递的数据）
        exported_products = [{
            'product_name': first_sku['商品名称'],
            'product_code': '',
            'brand': '',
            'procurement_method': 'NY',
            'quantity': first_sku['数量'],
            'skus': [{
                'color': first_sku['颜色'] or '',
                'size': first_sku['尺寸'] or '',
                'sales_attr': first_sku['销售属性'] or '',
                'quantity': first_sku['数量'],
                'order_ids': [self.test_order_number]  # 关键：这里传递的是原始订单编号
            }]
        }]

        print("\n导出数据结构:")
        print(f"  product_name: {exported_products[0]['product_name']}")
        print(f"  SKU order_ids: {exported_products[0]['skus'][0]['order_ids']}")
        print(f"  SKU sales_attr: {exported_products[0]['skus'][0]['sales_attr']}")

        # 执行更新
        print("\n执行更新...")
        result = await self.service.update_push_status_for_exported_products(
            db,
            exported_products
        )

        print(f"更新结果: {result}")
        return result

    async def verify_results(self, db: AsyncSession):
        """验证更新后的结果"""
        print(f"\n=== 步骤4: 验证更新结果 ===")

        # 重新查询所有SKU的状态
        query = select(ProcurementOrder).where(
            ProcurementOrder.原始订单编号 == self.test_order_number
        ).order_by(ProcurementOrder.id)

        result = await db.execute(query)
        skus = result.scalars().all()

        print("\n更新后的状态:")
        updated_count = 0
        unexpected_updates = []

        for i, sku in enumerate(skus):
            original = self.original_states[i]
            status_changed = original['已推送'] != sku.已推送

            print(f"  SKU {sku.id}: {sku.线上宝贝名称}")
            print(f"    销售属性: {sku.线上销售属性}")
            print(f"    已推送: {original['已推送']} -> {sku.已推送} {'✓ 已更新' if status_changed else ''}")

            if status_changed:
                updated_count += 1
                if i > 0:  # 第一个SKU应该被更新，其他的不应该
                    unexpected_updates.append(sku.id)

        print(f"\n测试结果:")
        print(f"  总SKU数: {len(skus)}")
        print(f"  更新的SKU数: {updated_count}")

        if updated_count == 1 and not unexpected_updates:
            print("✓ 测试通过: 只有导出的SKU被标记为已推送")
            return True
        else:
            print(f"✗ 测试失败: 期望更新1个SKU，实际更新了{updated_count}个")
            if unexpected_updates:
                print(f"  意外被更新的SKU ID: {unexpected_updates}")
            return False

    async def reset_test_data(self, db: AsyncSession):
        """重置测试数据"""
        print(f"\n=== 步骤5: 重置测试数据 ===")

        if not self.test_order_number:
            return

        # 将所有SKU的已推送状态重置为"否"
        query = select(ProcurementOrder).where(
            ProcurementOrder.原始订单编号 == self.test_order_number
        )

        result = await db.execute(query)
        skus = result.scalars().all()

        for sku in skus:
            sku.已推送 = '否'

        await db.commit()
        print(f"✓ 已重置订单 {self.test_order_number} 的所有SKU状态")

    async def run_test(self):
        """运行完整测试"""
        print("=" * 60)
        print("多SKU订单导出标记测试")
        print("=" * 60)

        async for db in get_db():
            try:
                # 1. 查找测试订单
                order_number = await self.find_multi_sku_order(db)
                if not order_number:
                    print("\n无法进行测试：没有找到合适的测试数据")
                    return

                # 2. 获取SKU详情
                await self.get_order_skus(db)

                # 3. 模拟导出
                await self.simulate_export_single_sku(db)

                # 4. 验证结果
                test_passed = await self.verify_results(db)

                # 5. 重置数据
                await self.reset_test_data(db)

                print("\n" + "=" * 60)
                if test_passed:
                    print("最终结果: ✓ 测试通过")
                    print("结论: 系统正确处理了多SKU订单的导出，只标记了导出的SKU")
                else:
                    print("最终结果: ✗ 测试失败")
                    print("问题: 导出一个SKU时，同一订单的其他SKU也被标记了")
                    print("建议: 需要修复update_push_status_for_exported_products方法")
                print("=" * 60)

            except Exception as e:
                print(f"\n✗ 测试过程中发生错误: {e}")
                import traceback
                traceback.print_exc()
            finally:
                await db.close()
                break


async def main():
    """主函数"""
    tester = MultiSkuExportTester()
    await tester.run_test()


if __name__ == "__main__":
    asyncio.run(main())