"""
Excel file reading utilities
"""

import re
from pathlib import Path
from typing import Dict, Any, List, Optional
from datetime import datetime
import pandas as pd
import numpy as np
import warnings
import logging

warnings.filterwarnings('ignore', category=UserWarning, module='openpyxl')

logger = logging.getLogger(__name__)


class ExcelReader:
    """Excel 文件读取器"""
    
    def __init__(self):
        # Common date formats, tried in order
        self.date_formats = [
            '%Y-%m-%d %H:%M:%S',
            '%Y/%m/%d %H:%M:%S',
            '%Y-%m-%d %H:%M',
            '%Y/%m/%d %H:%M',
            '%Y-%m-%d',
            '%Y/%m/%d',
            '%d/%m/%Y %H:%M:%S',
            '%d-%m-%Y %H:%M:%S',
            '%m/%d/%Y %H:%M:%S',
        ]
    
    def read_file(self, file_path: Path, sheet_name: Optional[str] = None) -> pd.DataFrame:
        """
        读取 Excel 或 CSV 文件
        
        Args:
            file_path: 文件路径（支持 .xlsx, .xls, .csv）
            sheet_name: 工作表名称（仅对 Excel 文件有效）
            
        Returns:
            DataFrame
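        
        Example (illustrative; 'orders.xlsx' is a hypothetical file):
            >>> from pathlib import Path
            >>> reader = ExcelReader()
            >>> df = reader.read_file(Path('orders.xlsx'))  # doctest: +SKIP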
        """
        file_extension = file_path.suffix.lower()
        
        try:
            if file_extension == '.csv':
                # Read the CSV file, trying UTF-8 first
                df = pd.read_csv(file_path, encoding='utf-8')
                logger.info(f"Successfully read CSV {file_path.name}: {len(df)} rows, {len(df.columns)} columns")
                return df
            elif file_extension in ['.xlsx', '.xls']:
                # Excel file: first try to detect the header row automatically
                sheet_param = sheet_name if sheet_name is not None else 0
                
                # Read the first few rows without a header to inspect the structure
                if file_extension == '.xlsx':
                    df_sample = pd.read_excel(file_path, engine='openpyxl', sheet_name=sheet_param, nrows=10, header=None)
                else:
                    df_sample = pd.read_excel(file_path, engine='xlrd', sheet_name=sheet_param, nrows=10, header=None)
                
                # Look for the row containing '原始订单编号' (original order
                # number) and treat it as the header row
                header_row = None
                for idx, row in df_sample.iterrows():
                    if any(pd.notna(cell) and '原始订单编号' in str(cell) for cell in row):
                        header_row = idx
                        logger.info(f"Found header row at index: {header_row}")
                        break
                
                # Re-read the file using the detected header row
                if header_row is not None:
                    if file_extension == '.xlsx':
                        df = pd.read_excel(file_path, engine='openpyxl', sheet_name=sheet_param, header=header_row)
                    else:
                        df = pd.read_excel(file_path, engine='xlrd', sheet_name=sheet_param, header=header_row)
                else:
                    # No standard header row found; fall back to the default
                    if file_extension == '.xlsx':
                        df = pd.read_excel(file_path, engine='openpyxl', sheet_name=sheet_param)
                    else:
                        df = pd.read_excel(file_path, engine='xlrd', sheet_name=sheet_param)
                
                logger.info(f"Successfully read Excel {file_path.name}: {len(df)} rows, {len(df.columns)} columns")
                return df
            else:
                raise ValueError(f"Unsupported file format: {file_extension}")
            
        except UnicodeDecodeError:
            # CSV encoding issue: try common Chinese encodings
            if file_extension == '.csv':
                for encoding in ['gbk', 'gb2312', 'big5']:
                    try:
                        df = pd.read_csv(file_path, encoding=encoding)
                        logger.info(f"Successfully read CSV with {encoding} encoding: {len(df)} rows, {len(df.columns)} columns")
                        return df
                    except UnicodeDecodeError:
                        continue
                raise ValueError("Failed to read CSV file with any supported encoding")
            else:
                raise
        except Exception as e:
            logger.error(f"Failed to read {file_path}: {e}")
            
            # For Excel files, retry with pandas' default engine
            if file_extension in ['.xlsx', '.xls']:
                try:
                    sheet_param = sheet_name if sheet_name is not None else 0
                    df = pd.read_excel(file_path, sheet_name=sheet_param)
                    logger.info(f"Read with default engine: {len(df)} rows, {len(df.columns)} columns")
                    return df
                except Exception as e2:
                    logger.error(f"Failed with default engine: {e2}")
                    raise
            else:
                raise
    
    def normalize_column_names(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        规范化列名：移除空格、换行符等，处理重复列名
        对于关键的产品字段，优先使用后面出现的（有数据的）列
        
        Args:
            df: 原始 DataFrame
            
        Returns:
            列名规范化后的 DataFrame
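        
        Example (illustrative; the column name is hypothetical):
            >>> df = pd.DataFrame({' 数量 ': [1, 2]})
            >>> ExcelReader().normalize_column_names(df).columns.tolist()
            ['数量']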
        """
        df = df.copy()
        
        # Key fields: when duplicated, prefer the later column (usually the one with data)
        key_fields = ['线上宝贝名称', '线上销售属性', '线上商家编码', '商品编号', 
                      'SKU编号', '图片', '数量', '订单单价', '订单金额']
        
        # Normalize column names
        normalized_columns = {}
        duplicate_counts = {}  # Occurrence count per column name
        
        for col_index, col in enumerate(df.columns):
            if pd.isna(col) or col == '':
                # Handle empty column names
                normalized_col = f'unnamed_column_{col_index}'
            else:
                # Strip whitespace and newline characters
                base_col = str(col).strip().replace('\n', '').replace('\r', '')
                
                # For key fields, check which duplicate column holds more data
                if base_col in key_fields and base_col in normalized_columns.values():
                    # Find the previously seen column with the same name
                    prev_col_idx = list(normalized_columns.values()).index(base_col)
                    prev_col_name = list(normalized_columns.keys())[prev_col_idx]
                    
                    # Compare the non-null counts of the two columns
                    prev_non_null = df[prev_col_name].notna().sum()
                    curr_non_null = df[col].notna().sum()
                    
                    if curr_non_null > prev_non_null:
                        # The current column has more data: rename the earlier
                        # column and give the base name to the current one
                        old_normalized = normalized_columns[prev_col_name]
                        normalized_columns[prev_col_name] = f"{old_normalized}_old"
                        normalized_col = base_col
                        logger.debug(f"Column '{base_col}': using column at index {col_index} with {curr_non_null} non-null values (previous had {prev_non_null})")
                    else:
                        # The earlier column has more data: rename the current one
                        normalized_col = f"{base_col}_dup"
                        logger.debug(f"Column '{base_col}': keeping previous column with {prev_non_null} non-null values (current has {curr_non_null})")
                else:
                    # Not a key field, or first occurrence
                    if base_col in normalized_columns.values():
                        # Add a numeric suffix to disambiguate duplicates
                        counter = duplicate_counts.get(base_col, 1)
                        duplicate_counts[base_col] = counter + 1
                        normalized_col = f"{base_col}_{counter}"
                    else:
                        normalized_col = base_col
            
            normalized_columns[col] = normalized_col
        
        df = df.rename(columns=normalized_columns)
        return df
    
    def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        清理数据：处理空值、数据类型等
        
        Args:
            df: 原始 DataFrame
            
        Returns:
            清理后的 DataFrame
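        
        Example (illustrative; values are hypothetical):
            >>> df = pd.DataFrame({'数量': ['3', 'NULL']})
            >>> ExcelReader().clean_data(df)['数量'].tolist()
            [3]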
        """
        df = df.copy()
        
        # Normalize common null representations
        df = df.replace({
            'nan': np.nan,
            'NaN': np.nan,
            'NULL': np.nan,
            'null': np.nan,
            '': np.nan,
            ' ': np.nan,
        })
        
        # Drop rows that are entirely empty
        df = df.dropna(how='all')
        
        # Extra filter: drop rows where every key field is empty
        key_fields = ['原始订单编号', '网店名称', '线上宝贝名称', '数量', '订单单价']
        existing_key_fields = [f for f in key_fields if f in df.columns]
        
        if existing_key_fields:
            # Keep rows that have at least one non-null key field
            mask = df[existing_key_fields].notna().any(axis=1)
            df = df[mask]
            logger.info(f"After filtering rows with empty key fields: {len(df)} rows remain")
        
        # Coerce string representations of numbers in object columns
        for col in df.columns:
            if df[col].dtype == 'object':
                # Attempt numeric conversion for quantity/amount/price columns
                if '数量' in col or '金额' in col or '单价' in col:
                    # errors='ignore' is deprecated in pandas; leave the column
                    # unchanged if it cannot be fully converted
                    try:
                        df[col] = pd.to_numeric(df[col])
                    except (ValueError, TypeError):
                        pass
        
        return df
    
    def parse_dates(self, df: pd.DataFrame, date_columns: Optional[List[str]] = None) -> pd.DataFrame:
        """
        Parse date columns.
        
        Args:
            df: DataFrame
            date_columns: List of date column names (None = auto-detect)
            
        Returns:
            DataFrame with parsed dates
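        
        Example (illustrative; the column name is hypothetical):
            >>> df = pd.DataFrame({'付款时间': ['2024-01-02 10:30:00']})
            >>> ExcelReader().parse_dates(df)['付款时间'].dtype
            dtype('<M8[ns]')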
        """
        df = df.copy()
        
        # Auto-detect date columns by keyword
        if date_columns is None:
            date_columns = [col for col in df.columns 
                          if any(keyword in str(col) for keyword in ['时间', '日期', 'time', 'date'])]
        
        for col in date_columns:
            if col not in df.columns:
                continue
                
            df[col] = self._parse_date_column(df[col])
        
        return df
    
    def _parse_date_column(self, series: pd.Series) -> pd.Series:
        """
        解析单个日期列
        
        Args:
            series: 原始日期序列
            
        Returns:
            解析后的日期序列
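        
        Example (illustrative; values are hypothetical):
            >>> s = pd.Series(['2024/01/02 10:30:00', '2024/01/03 11:00:00'])
            >>> bool(ExcelReader()._parse_date_column(s).notna().all())
            True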
        """
        if pd.api.types.is_datetime64_any_dtype(series):
            return series
        
        # Try each known format in turn
        for fmt in self.date_formats:
            try:
                parsed = pd.to_datetime(series, format=fmt, errors='coerce')
                success_rate = 1 - parsed.isna().mean()
                if success_rate > 0.7:  # accept if more than 70% of values parsed
                    logger.debug(f"Date parsing success with format '{fmt}': {success_rate:.2%}")
                    return parsed
            except (ValueError, TypeError):
                continue
        
        # If every explicit format failed, let pandas infer the format
        try:
            return pd.to_datetime(series, errors='coerce')
        except (ValueError, TypeError):
            logger.warning(f"Failed to parse dates in column '{series.name}'")
            return series
    
    def convert_to_dict_records(self, df: pd.DataFrame, add_row_index: bool = True) -> List[Dict[str, Any]]:
        """
        将 DataFrame 转换为字典列表
        
        Args:
            df: DataFrame
            add_row_index: 是否添加行索引
            
        Returns:
            字典列表
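        
        Example (illustrative; values are hypothetical):
            >>> df = pd.DataFrame({'数量': [2]})
            >>> ExcelReader().convert_to_dict_records(df)
            [{'row_idx': 2, '数量': 2}]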
        """
        records = []
        
        for idx, row in df.iterrows():
            record = {}
            
            # Map to the 1-based Excel row number: data starts at row 2
            # (row 1 is the header; assumes the header is the first sheet row)
            if add_row_index:
                record['row_idx'] = int(idx) + 2
            
            # Convert each field
            for col, value in row.items():
                # NaN values become None
                if pd.isna(value):
                    record[col] = None
                # Datetime values
                elif isinstance(value, (pd.Timestamp, datetime)):
                    # Convert pandas.Timestamp to a plain Python datetime
                    if isinstance(value, pd.Timestamp):
                        record[col] = value.to_pydatetime()
                    else:
                        record[col] = value
                # Numeric values
                elif isinstance(value, (np.integer, np.floating)):
                    if np.isnan(value):
                        record[col] = None
                    else:
                        record[col] = float(value) if isinstance(value, np.floating) else int(value)
                else:
                    # Strings and other types
                    record[col] = str(value) if value is not None else None
            
            records.append(record)
        
        return records
    
    def process_excel_file(self, file_path: Path) -> List[Dict[str, Any]]:
        """
        处理 Excel 文件的完整流程
        
        Args:
            file_path: Excel 文件路径
            
        Returns:
            处理后的记录列表
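        
        Example (illustrative; 'orders.xlsx' is a hypothetical file):
            >>> from pathlib import Path
            >>> records = ExcelReader().process_excel_file(Path('orders.xlsx'))  # doctest: +SKIP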
        """
        logger.info(f"Starting to process {file_path.name}")
        
        # Read the file
        df = self.read_file(file_path)
        logger.info(f"Raw data: {len(df)} rows, columns: {list(df.columns)}")
        
        # Normalize column names
        df = self.normalize_column_names(df)
        logger.info(f"After column normalization: {len(df)} rows, columns: {list(df.columns)}")
        
        # Clean the data
        df = self.clean_data(df)
        logger.info(f"After data cleaning: {len(df)} rows")
        
        # Parse dates
        df = self.parse_dates(df)
        logger.info(f"After date parsing: {len(df)} rows")
        
        # Convert to records
        records = self.convert_to_dict_records(df)
        
        # Count records with at least one non-null key field
        key_fields = ['原始订单编号', '网店名称', '线上宝贝名称', '数量', '订单单价']
        valid_records = 0
        for record in records:
            if any(record.get(field) is not None for field in key_fields):
                valid_records += 1
        
        logger.info(f"Processed {file_path.name}: {len(records)} total records, {valid_records} with valid data")
        
        # If very few valid records were found, log sample data for debugging
        if valid_records < 10:
            logger.info("Sample of first 5 records:")
            for i, record in enumerate(records[:5]):
                logger.info(f"  Record {i+1}: {record}")
        
        return records
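

if __name__ == '__main__':
    # Minimal usage sketch (assumes a local 'orders.xlsx'; the filename and
    # log level here are illustrative, not part of the original module)
    import sys
    
    logging.basicConfig(level=logging.INFO)
    target = Path(sys.argv[1]) if len(sys.argv) > 1 else Path('orders.xlsx')
    loaded = ExcelReader().process_excel_file(target)
    print(f"Loaded {len(loaded)} records from {target.name}")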