"""
LangGraph Agent 节点实现

重构版本：直接使用 part_ratio 数据 + SQL Agent
"""

import logging
import time
import json
from typing import Dict, List
from decimal import Decimal
from datetime import datetime

from langchain_core.messages import SystemMessage, HumanMessage

from .state import AgentState
from .sql_agent import SQLAgent
from ..models import ReplenishmentSuggestion, PartAnalysisResult
from ..llm import get_llm_client
from ..services import DataService
from ..services.result_writer import ResultWriter
from ..models import ReplenishmentDetail, TaskExecutionLog, LogStatus, ReplenishmentPartSummary

logger = logging.getLogger(__name__)


def _load_prompt(filename: str) -> str:
    """从prompts目录加载提示词文件"""
    import os
    # 从 src/fw_pms_ai/agent/nodes.py 向上4层到达项目根目录
    prompt_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
        "prompts", filename
    )
    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.warning(f"Prompt文件未找到: {prompt_path}")
        return ""



def fetch_part_ratio_node(state: AgentState) -> AgentState:
    """
    Node 1: load part_ratio (stock-to-sales ratio) rows for the dealer grouping.

    Queries ``SQLAgent.fetch_part_ratios`` with the state's ``group_id``,
    ``dealer_grouping_id`` and ``statistics_date``, records one execution-log
    entry and routes to the "sql_agent" node. The SQLAgent is always closed;
    a failing query is logged with its traceback and re-raised.

    Args:
        state: current agent state.

    Returns:
        A new state dict with ``part_ratios`` (empty list when nothing was
        found), ``sql_execution_logs``, ``current_node`` and ``next_node`` set.
    """
    logger.info("[FetchPartRatio] ========== 开始获取数据 ==========")
    logger.info(
        f"[FetchPartRatio] group_id={state['group_id']}, "
        f"dealer_grouping_id={state['dealer_grouping_id']}, "
        f"date={state['statistics_date']}"
    )

    # Capture the wall-clock start *before* the query so the log entry's
    # "start_time" really is the start, not the completion time.
    started_at = datetime.now().isoformat()
    start_time = time.time()

    sql_agent = SQLAgent()
    try:
        # Query part_ratio directly by dealer_grouping_id.
        part_ratios = sql_agent.fetch_part_ratios(
            group_id=state["group_id"],
            dealer_grouping_id=state["dealer_grouping_id"],
            statistics_date=state["statistics_date"],
        ) or []  # guard against a None result so len() below cannot fail
    except Exception as e:
        logger.exception(f"[FetchPartRatio] 获取数据失败: {e}")
        raise
    finally:
        sql_agent.close()

    execution_time = int((time.time() - start_time) * 1000)

    log_entry = {
        "step_name": "fetch_part_ratio",
        "step_order": 1,
        # An empty result is not an error, but the step is marked as skipped.
        "status": LogStatus.SUCCESS if part_ratios else LogStatus.SKIPPED,
        "input_data": json.dumps({
            "dealer_grouping_id": state["dealer_grouping_id"],
            "statistics_date": state["statistics_date"],
        }),
        "output_data": json.dumps({"part_ratios_count": len(part_ratios)}),
        "execution_time_ms": execution_time,
        "start_time": started_at,
    }

    logger.info(
        f"[FetchPartRatio] 数据获取完成: part_ratios={len(part_ratios)}, "
        f"耗时={execution_time}ms"
    )

    return {
        **state,
        "part_ratios": part_ratios,
        "sql_execution_logs": [log_entry],
        "current_node": "fetch_part_ratio",
        "next_node": "sql_agent",
    }



def sql_agent_node(state: AgentState) -> AgentState:
    """
    Node 2: analyse part_ratio rows with the SQL Agent and produce suggestions.

    Rows are grouped by part_code and each part is analysed across its shops
    via ``SQLAgent.analyze_parts_by_group``. The aggregate stock-to-sales ratio
    (sum of ``valid_storage_cnt`` / sum of ``avg_sales_cnt``) is logged, stored
    as ``base_ratio`` and passed to the agent as ``target_ratio``. When that
    ratio is zero, 1.3 is passed instead.

    Retry contract: on failure ``sql_retry_count`` is incremented and
    ``next_node`` is set to "sql_agent" only while the incremented count is
    below 3, which is the same bound ``should_retry_sql`` checks. Once that
    budget is exhausted, the node itself routes to "allocate_budget" with no
    suggestions. The terminal failure state is therefore explicit, rather than
    relying on the router to override a stale ``next_node="sql_agent"``.

    Args:
        state: current agent state; reads ``part_ratios``, ``sql_retry_count``,
            ``dealer_grouping_id``, ``dealer_grouping_name`` and
            ``statistics_date``.

    Returns:
        A new state dict with suggestions/part results, accumulated token
        counters, one ``sql_execution_logs`` entry and routing fields.
    """
    part_ratios = state.get("part_ratios", [])
    logger.info(f"[SQLAgent] 开始分析: part_ratios={len(part_ratios)}")

    start_time = time.time()
    retry_count = state.get("sql_retry_count", 0)

    if not part_ratios:
        logger.warning("[SQLAgent] 无配件数据可分析")

        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.SKIPPED,
            "error_message": "无配件数据",
            "execution_time_ms": int((time.time() - start_time) * 1000),
        }

        return {
            **state,
            "llm_suggestions": [],
            "llm_analysis_summary": "无配件数据可分析",
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
        }

    sql_agent = SQLAgent()

    try:
        # Aggregate stock-to-sales ratio across all rows. It is recorded in the
        # state/log AND used as the analysis target_ratio below.
        total_valid_storage = sum(
            Decimal(str(p.get("valid_storage_cnt", 0) or 0))
            for p in part_ratios
        )
        total_avg_sales = sum(
            Decimal(str(p.get("avg_sales_cnt", 0) or 0))
            for p in part_ratios
        )
        base_ratio = (
            total_valid_storage / total_avg_sales
            if total_avg_sales > 0
            else Decimal("0")
        )

        logger.info(
            f"[SQLAgent] 当前库销比: 总库存={total_valid_storage}, "
            f"总销量={total_avg_sales}, 库销比={base_ratio}"
        )

        # No per-batch persistence callback: results are written once, by
        # allocate_budget_node, after the whole analysis completes.
        suggestions, part_results, llm_stats = sql_agent.analyze_parts_by_group(
            part_ratios=part_ratios,
            dealer_grouping_id=state["dealer_grouping_id"],
            dealer_grouping_name=state["dealer_grouping_name"],
            statistics_date=state["statistics_date"],
            target_ratio=base_ratio if base_ratio > 0 else Decimal("1.3"),
            limit=1000,
            callback=None,
        )

        execution_time = int((time.time() - start_time) * 1000)
        prompt_tokens = llm_stats.get("prompt_tokens", 0)
        completion_tokens = llm_stats.get("completion_tokens", 0)

        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.SUCCESS,
            "input_data": json.dumps({
                "part_ratios_count": len(part_ratios),
            }),
            "output_data": json.dumps({
                "suggestions_count": len(suggestions),
                "part_results_count": len(part_results),
                "base_ratio": float(base_ratio),
            }),
            "llm_tokens": prompt_tokens + completion_tokens,
            "execution_time_ms": execution_time,
            "retry_count": retry_count,
        }

        logger.info(
            f"[SQLAgent] 分析完成: 建议数={len(suggestions)}, "
            f"配件汇总数={len(part_results)}, tokens={llm_stats}, 耗时={execution_time}ms"
        )

        return {
            **state,
            "base_ratio": base_ratio,
            "llm_suggestions": suggestions,
            "part_results": part_results,
            "llm_prompt_tokens": state.get("llm_prompt_tokens", 0) + prompt_tokens,
            "llm_completion_tokens": state.get("llm_completion_tokens", 0) + completion_tokens,
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
        }

    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"[SQLAgent] 执行失败: {e}")

        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.FAILED,
            "error_message": str(e),
            "retry_count": retry_count,
            "execution_time_ms": int((time.time() - start_time) * 1000),
        }

        # Decide on the *incremented* count, matching should_retry_sql's bound.
        # Otherwise the last failure would request a retry the router refuses.
        next_retry_count = retry_count + 1
        if next_retry_count < 3:
            return {
                **state,
                "sql_retry_count": next_retry_count,
                "sql_execution_logs": [log_entry],
                "current_node": "sql_agent",
                "next_node": "sql_agent",  # retry
                "error_message": str(e),
            }

        # Retry budget exhausted: continue with no suggestions.
        return {
            **state,
            "llm_suggestions": [],
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
            "error_message": str(e),
        }

    finally:
        sql_agent.close()


def _save_allocation_results(
    writer: ResultWriter,
    state: AgentState,
    allocated_details: List[ReplenishmentDetail],
) -> None:
    """
    Replace the persisted results of ``state["task_no"]`` with the new ones.

    Deletes the existing detail and part-summary rows for the task, so that
    reruns do not duplicate records. It then saves ``allocated_details`` and
    one ReplenishmentPartSummary per entry of ``state["part_results"]``.
    Exceptions raised by the writer propagate to the caller.
    """
    task_no = state["task_no"]

    # NOTE(review): delete + save are separate writer calls; if a save fails
    # after the deletes, the old rows are already gone. Confirm whether
    # ResultWriter wraps these in a transaction.
    writer.delete_details_by_task(task_no)
    writer.delete_part_summaries_by_task(task_no)
    logger.info(f"[AllocateBudget] 已清理旧数据: task_no={task_no}")

    if allocated_details:
        writer.save_details(allocated_details)
        logger.info(f"[AllocateBudget] 已保存 {len(allocated_details)} 条补货明细")

    part_results = state.get("part_results", [])
    if not part_results:
        return

    part_summaries = [
        ReplenishmentPartSummary(
            task_no=task_no,
            group_id=state["group_id"],
            dealer_grouping_id=state["dealer_grouping_id"],
            part_code=pr.part_code,
            part_name=pr.part_name,
            unit=pr.unit,
            cost_price=pr.cost_price,
            total_storage_cnt=pr.total_storage_cnt,
            total_avg_sales_cnt=pr.total_avg_sales_cnt,
            group_current_ratio=pr.group_current_ratio,
            total_suggest_cnt=pr.total_suggest_cnt,
            total_suggest_amount=pr.total_suggest_amount,
            shop_count=pr.shop_count,
            need_replenishment_shop_count=pr.need_replenishment_shop_count,
            part_decision_reason=pr.part_decision_reason,
            priority=pr.priority,
            llm_confidence=pr.confidence,
            statistics_date=state["statistics_date"],
        )
        for pr in part_results
    ]
    writer.save_part_summaries(part_summaries)
    logger.info(f"[AllocateBudget] 已保存 {len(part_summaries)} 条配件分析汇总")


def allocate_budget_node(state: AgentState) -> AgentState:
    """
    Node 3: convert the LLM suggestions into replenishment details and persist them.

    No budget truncation is applied. Every suggestion, including parts that need
    no replenishment, becomes a ReplenishmentDetail so the full analysis is
    recorded. Suggestions are ordered by (priority, current_ratio), both
    ascending.

    Persistence is best-effort. A save failure is logged with its traceback and
    appended as an extra FAILED log entry, but the node still returns
    ``status="success"`` and routes to "end". The ResultWriter is always
    closed, including when saving fails.

    Args:
        state: current agent state; reads ``llm_suggestions``, ``part_ratios``,
            ``part_results``, ``base_ratio``, ``brand_grouping_id``, ``task_no``,
            ``group_id``, ``dealer_grouping_id`` and ``statistics_date``.

    Returns:
        A new state dict with ``details``, ``sql_execution_logs`` and routing
        fields. When there were suggestions, ``status`` and ``end_time`` are
        also set.
    """
    logger.info("[AllocateBudget] 开始处理LLM建议")

    start_time = time.time()

    llm_suggestions = state.get("llm_suggestions", [])

    if not llm_suggestions:
        logger.warning("[AllocateBudget] 无LLM建议可处理")

        log_entry = {
            "step_name": "allocate_budget",
            "step_order": 3,
            "status": LogStatus.SKIPPED,
            "error_message": "无LLM建议",
            "execution_time_ms": int((time.time() - start_time) * 1000),
        }

        return {
            **state,
            "details": [],
            "sql_execution_logs": [log_entry],
            "current_node": "allocate_budget",
            "next_node": "end",
        }

    # Priority ascending, then current stock-to-sales ratio ascending.
    sorted_suggestions = sorted(
        llm_suggestions,
        key=lambda x: (x.priority, float(x.current_ratio))
    )

    # part_code -> brand_grouping_id, so each detail is attributed to the right
    # brand grouping; the state-level brand_grouping_id is the fallback.
    part_ratios = state.get("part_ratios", [])
    part_brand_map = {
        p.get("part_code"): p.get("brand_grouping_id")
        for p in part_ratios
        if p.get("part_code")
    }
    base_ratio = state.get("base_ratio", Decimal("1.1"))  # same for every detail

    allocated_details = []
    total_amount = Decimal("0")

    for suggestion in sorted_suggestions:
        bg_id = part_brand_map.get(suggestion.part_code)
        if bg_id is None:
            bg_id = state.get("brand_grouping_id")

        detail = ReplenishmentDetail(
            task_no=state["task_no"],
            group_id=state["group_id"],
            dealer_grouping_id=state["dealer_grouping_id"],
            brand_grouping_id=bg_id,
            shop_id=suggestion.shop_id,
            shop_name=suggestion.shop_name,
            part_code=suggestion.part_code,
            part_name=suggestion.part_name,
            unit=suggestion.unit,
            cost_price=suggestion.cost_price,
            base_ratio=base_ratio,
            current_ratio=suggestion.current_ratio,
            valid_storage_cnt=suggestion.current_storage_cnt,
            avg_sales_cnt=suggestion.avg_sales_cnt,
            suggest_cnt=suggestion.suggest_cnt,
            suggest_amount=suggestion.suggest_amount,
            suggestion_reason=suggestion.suggestion_reason,
            priority=suggestion.priority,
            llm_confidence=suggestion.confidence,
            statistics_date=state["statistics_date"],
        )
        # Projected ratio after replenishment. It is 0 when the projected stock
        # or the average sales is not positive (avoids division by zero).
        post_storage = detail.valid_storage_cnt + detail.suggest_cnt
        if post_storage <= 0 or detail.avg_sales_cnt <= 0:
            detail.post_plan_ratio = Decimal("0")
        else:
            detail.post_plan_ratio = post_storage / detail.avg_sales_cnt

        allocated_details.append(detail)
        total_amount += suggestion.suggest_amount

    execution_time = int((time.time() - start_time) * 1000)

    log_entry = {
        "step_name": "allocate_budget",
        "step_order": 3,
        "status": LogStatus.SUCCESS,
        "input_data": json.dumps({
            "suggestions_count": len(llm_suggestions),
        }),
        "output_data": json.dumps({
            "details_count": len(allocated_details),
            "total_amount": float(total_amount),
        }),
        "execution_time_ms": execution_time,
    }

    logger.info(
        f"[AllocateBudget] 分配完成: 配件数={len(allocated_details)}, "
        f"金额={total_amount}"
    )

    sql_execution_logs = [log_entry]
    save_start = time.time()
    writer = None
    try:
        writer = ResultWriter()
        _save_allocation_results(writer, state, allocated_details)
    except Exception as e:
        # Best-effort: record the failure but do not interrupt the flow.
        logger.exception(f"[AllocateBudget] 保存结果失败: {e}")
        sql_execution_logs.append({
            "step_name": "allocate_budget",
            "step_order": 3,
            "status": LogStatus.FAILED,
            "error_message": f"保存结果失败: {str(e)}",
            "execution_time_ms": int((time.time() - save_start) * 1000),
        })
    finally:
        # Previously close() was skipped whenever a save raised, leaking the writer.
        if writer is not None:
            try:
                writer.close()
            except Exception as close_err:
                logger.warning(f"[AllocateBudget] 关闭ResultWriter失败: {close_err}")

    return {
        **state,
        "details": allocated_details,
        "sql_execution_logs": sql_execution_logs,
        "current_node": "allocate_budget",
        "next_node": "end",
        "status": "success",
        "end_time": time.time(),
    }


def should_retry_sql(state: AgentState) -> str:
    """
    Conditional edge after the SQL Agent node.

    Args:
        state: agent state; reads ``next_node`` (default "allocate_budget")
            and ``sql_retry_count`` (default 0).

    Returns:
        "retry" when the node asked to run again (``next_node == "sql_agent"``)
        and ``sql_retry_count`` is still below 3; otherwise "continue".
    """
    wants_rerun = state.get("next_node", "allocate_budget") == "sql_agent"
    attempts = state.get("sql_retry_count", 0)

    # Guard clause: no rerun requested, or the retry budget is used up.
    if not wants_rerun or attempts >= 3:
        return "continue"

    logger.info(f"[Routing] SQL Agent需要重试: retry_count={attempts}")
    return "retry"


def should_continue(state: AgentState) -> str:
    """
    Conditional edge: route to whichever node the previous step requested.

    Returns:
        ``state["next_node"]``, or "end" when that key is absent.
    """
    fallback = "end"
    target = state.get("next_node", fallback)
    return target

