"""
LangGraph Agent node implementations.

Refactored version: works directly on part_ratio data together with the SQL Agent.
"""

import logging
import os
import time
import json
from typing import Dict, List
from decimal import Decimal
from datetime import datetime

from langchain_core.messages import SystemMessage, HumanMessage

from .state import AgentState
from .sql_agent import SQLAgent
from ..models import ReplenishmentSuggestion, PartAnalysisResult
from ..llm import get_llm_client
from ..services import DataService
from ..services.result_writer import ResultWriter
from ..models import ReplenishmentDetail, TaskExecutionLog, LogStatus, ReplenishmentPartSummary

logger = logging.getLogger(__name__)

# Upper bound on the retry counter shared by sql_agent_node (which schedules a
# retry) and should_retry_sql (which routes on it). Keep the two in sync.
_MAX_SQL_RETRIES = 3


def _elapsed_ms(start_time: float) -> int:
    """Return the whole milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
    return int((time.time() - start_time) * 1000)


def _load_prompt(filename: str) -> str:
    """Load a prompt file from the ``prompts`` directory.

    Args:
        filename: Name of the file inside the ``prompts`` directory.

    Returns:
        The file content decoded as UTF-8, or an empty string (after logging a
        warning) when the file does not exist.
    """
    # Walk up 4 directory levels from this file
    # (src/fw_pms_ai/agent/nodes.py) to reach the project root.
    prompt_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))),
        "prompts",
        filename,
    )
    try:
        with open(prompt_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        logger.warning(f"Prompt文件未找到: {prompt_path}")
        return ""


def fetch_part_ratio_node(state: AgentState) -> AgentState:
    """Node 1: fetch part_ratio data.

    Fetches the part stock-to-sales ratio rows through
    ``SQLAgent.fetch_part_ratios`` using the ``group_id``,
    ``dealer_grouping_id`` and ``statistics_date`` found in ``state``.

    Args:
        state: Current agent state; must contain ``group_id``,
            ``dealer_grouping_id`` and ``statistics_date``.

    Returns:
        A copy of ``state`` extended with ``part_ratios``, a one-element
        ``sql_execution_logs`` list, and the routing fields ``current_node`` /
        ``next_node`` (always ``"sql_agent"``).

    Raises:
        Exception: Anything raised by the fetch propagates to the caller; the
            SQL agent is closed first.
    """
    logger.info("[FetchPartRatio] ========== 开始获取数据 ==========")
    logger.info(
        f"[FetchPartRatio] group_id={state['group_id']}, "
        f"dealer_grouping_id={state['dealer_grouping_id']}, "
        f"date={state['statistics_date']}"
    )

    start_time = time.time()
    # Wall-clock start captured *before* the query so the "start_time" log
    # field really is the start (it used to be sampled after the fetch).
    started_at = datetime.now()

    sql_agent = SQLAgent()
    try:
        # Fetch part_ratio rows directly by dealer_grouping_id.
        part_ratios = sql_agent.fetch_part_ratios(
            group_id=state["group_id"],
            dealer_grouping_id=state["dealer_grouping_id"],
            statistics_date=state["statistics_date"],
        )

        execution_time = _elapsed_ms(start_time)

        # Execution log entry; an empty result is recorded as SKIPPED.
        log_entry = {
            "step_name": "fetch_part_ratio",
            "step_order": 1,
            "status": LogStatus.SUCCESS if part_ratios else LogStatus.SKIPPED,
            "input_data": json.dumps({
                "dealer_grouping_id": state["dealer_grouping_id"],
                "statistics_date": state["statistics_date"],
            }),
            "output_data": json.dumps({"part_ratios_count": len(part_ratios)}),
            "execution_time_ms": execution_time,
            "start_time": started_at.isoformat(),
        }

        logger.info(
            f"[FetchPartRatio] 数据获取完成: part_ratios={len(part_ratios)}, "
            f"耗时={execution_time}ms"
        )

        return {
            **state,
            "part_ratios": part_ratios,
            "sql_execution_logs": [log_entry],
            "current_node": "fetch_part_ratio",
            "next_node": "sql_agent",
        }
    finally:
        sql_agent.close()


def sql_agent_node(state: AgentState) -> AgentState:
    """Node 2: let the SQL Agent analyse the parts and produce suggestions.

    The analysis is delegated to ``SQLAgent.analyze_parts_by_group`` (grouped
    by part_code, analysing each part's per-shop replenishment need).

    Args:
        state: Current agent state. Reads ``part_ratios``, ``sql_retry_count``,
            ``dealer_grouping_id``, ``dealer_grouping_name``,
            ``statistics_date`` and the running LLM token counters.

    Returns:
        A copy of ``state`` with one of:

        * no part data: empty ``llm_suggestions``, a SKIPPED log entry,
          ``next_node="allocate_budget"``;
        * success: ``base_ratio``, ``llm_suggestions``, ``part_results``,
          updated token counters, ``next_node="allocate_budget"``;
        * failure with retries left: incremented ``sql_retry_count``,
          ``next_node="sql_agent"`` and ``error_message``;
        * failure with retries exhausted: empty ``llm_suggestions``,
          ``next_node="allocate_budget"`` and ``error_message``.
    """
    part_ratios = state.get("part_ratios", [])

    logger.info(f"[SQLAgent] 开始分析: part_ratios={len(part_ratios)}")

    start_time = time.time()
    retry_count = state.get("sql_retry_count", 0)

    if not part_ratios:
        logger.warning("[SQLAgent] 无配件数据可分析")

        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.SKIPPED,
            "error_message": "无配件数据",
            "execution_time_ms": _elapsed_ms(start_time),
        }

        return {
            **state,
            "llm_suggestions": [],
            "llm_analysis_summary": "无配件数据可分析",
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
        }

    sql_agent = SQLAgent()
    try:
        # Baseline stock-to-sales ratio across all rows. Missing/None counts
        # are treated as 0; values go through str() before Decimal().
        total_valid_storage = sum(
            Decimal(str(p.get("valid_storage_cnt", 0) or 0))
            for p in part_ratios
        )
        total_avg_sales = sum(
            Decimal(str(p.get("avg_sales_cnt", 0) or 0))
            for p in part_ratios
        )

        if total_avg_sales > 0:
            base_ratio = total_valid_storage / total_avg_sales
        else:
            base_ratio = Decimal("0")

        logger.info(
            f"[SQLAgent] 当前库销比: 总库存={total_valid_storage}, "
            f"总销量={total_avg_sales}, 库销比={base_ratio}"
        )

        # No batch callback and no writer here: results are no longer written
        # while the analysis runs; allocate_budget_node clears old rows and
        # persists everything in one go. (A ResultWriter used to be created
        # here and was never used or closed.)
        suggestions, part_results, llm_stats = sql_agent.analyze_parts_by_group(
            part_ratios=part_ratios,
            dealer_grouping_id=state["dealer_grouping_id"],
            dealer_grouping_name=state["dealer_grouping_name"],
            statistics_date=state["statistics_date"],
            # Fall back to 1.3 when no baseline could be computed.
            target_ratio=base_ratio if base_ratio > 0 else Decimal("1.3"),
            limit=1000,
            callback=None,
        )

        execution_time = _elapsed_ms(start_time)

        # Execution log entry.
        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.SUCCESS,
            "input_data": json.dumps({
                "part_ratios_count": len(part_ratios),
            }),
            "output_data": json.dumps({
                "suggestions_count": len(suggestions),
                "part_results_count": len(part_results),
                "base_ratio": float(base_ratio),
            }),
            "llm_tokens": llm_stats.get("prompt_tokens", 0) + llm_stats.get("completion_tokens", 0),
            "execution_time_ms": execution_time,
            "retry_count": retry_count,
        }

        logger.info(
            f"[SQLAgent] 分析完成: 建议数={len(suggestions)}, "
            f"配件汇总数={len(part_results)}, tokens={llm_stats}, 耗时={execution_time}ms"
        )

        return {
            **state,
            "base_ratio": base_ratio,
            "llm_suggestions": suggestions,
            "part_results": part_results,
            "llm_prompt_tokens": state.get("llm_prompt_tokens", 0) + llm_stats.get("prompt_tokens", 0),
            "llm_completion_tokens": state.get("llm_completion_tokens", 0) + llm_stats.get("completion_tokens", 0),
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
        }

    except Exception as e:
        # Node boundary: turn any failure into a FAILED log entry plus a
        # retry/continue decision. logger.exception keeps the traceback.
        logger.exception(f"[SQLAgent] 执行失败: {e}")

        log_entry = {
            "step_name": "sql_agent",
            "step_order": 2,
            "status": LogStatus.FAILED,
            "error_message": str(e),
            "retry_count": retry_count,
            "execution_time_ms": _elapsed_ms(start_time),
        }

        # Ask for a retry while the counter is below the limit.
        if retry_count < _MAX_SQL_RETRIES:
            return {
                **state,
                "sql_retry_count": retry_count + 1,
                "sql_execution_logs": [log_entry],
                "current_node": "sql_agent",
                "next_node": "sql_agent",  # retry
                "error_message": str(e),
            }

        return {
            **state,
            "llm_suggestions": [],
            "sql_execution_logs": [log_entry],
            "current_node": "sql_agent",
            "next_node": "allocate_budget",
            "error_message": str(e),
        }

    finally:
        sql_agent.close()


def _persist_results(state: AgentState, details: List[ReplenishmentDetail]) -> None:
    """Replace the stored details and part summaries of the current task.

    Deletes the rows previously written for ``state["task_no"]``, then saves
    ``details`` and the summaries built from ``state["part_results"]``. The
    writer is closed even when a step fails; any exception (including one from
    ``close``) propagates to the caller.
    """
    writer = ResultWriter()
    try:
        # 0. Clear old rows first so retries / re-runs do not duplicate data.
        writer.delete_details_by_task(state["task_no"])
        writer.delete_part_summaries_by_task(state["task_no"])
        logger.info(f"[AllocateBudget] 已清理旧数据: task_no={state['task_no']}")

        # 1. Save the replenishment details.
        if details:
            writer.save_details(details)
            logger.info(f"[AllocateBudget] 已保存 {len(details)} 条补货明细")

        # 2. Save the per-part summaries.
        part_results = state.get("part_results", [])
        if part_results:
            part_summaries = [
                ReplenishmentPartSummary(
                    task_no=state["task_no"],
                    group_id=state["group_id"],
                    dealer_grouping_id=state["dealer_grouping_id"],
                    part_code=pr.part_code,
                    part_name=pr.part_name,
                    unit=pr.unit,
                    cost_price=pr.cost_price,
                    total_storage_cnt=pr.total_storage_cnt,
                    total_avg_sales_cnt=pr.total_avg_sales_cnt,
                    group_current_ratio=pr.group_current_ratio,
                    total_suggest_cnt=pr.total_suggest_cnt,
                    total_suggest_amount=pr.total_suggest_amount,
                    shop_count=pr.shop_count,
                    need_replenishment_shop_count=pr.need_replenishment_shop_count,
                    part_decision_reason=pr.part_decision_reason,
                    priority=pr.priority,
                    llm_confidence=pr.confidence,
                    statistics_date=state["statistics_date"],
                )
                for pr in part_results
            ]
            writer.save_part_summaries(part_summaries)
            logger.info(f"[AllocateBudget] 已保存 {len(part_summaries)} 条配件分析汇总")
    finally:
        # Previously close() was only reached on success, leaking the writer
        # whenever a delete/save call raised.
        writer.close()


def allocate_budget_node(state: AgentState) -> AgentState:
    """Node 3: convert the LLM suggestions into replenishment details.

    Note: no budget truncation is applied; every suggestion is emitted.

    Args:
        state: Current agent state. Reads ``llm_suggestions``, ``part_ratios``,
            ``part_results``, ``base_ratio`` and the task/group identifiers.

    Returns:
        A copy of ``state`` with ``details``, ``sql_execution_logs`` and
        ``next_node="end"``. When there were suggestions to process, ``status``
        (``"success"``) and ``end_time`` are set as well. A persistence failure
        does not abort the flow: it only appends a FAILED entry to the logs.
    """
    logger.info("[AllocateBudget] 开始处理LLM建议")

    start_time = time.time()

    llm_suggestions = state.get("llm_suggestions", [])

    if not llm_suggestions:
        logger.warning("[AllocateBudget] 无LLM建议可处理")

        log_entry = {
            "step_name": "allocate_budget",
            "step_order": 3,
            "status": LogStatus.SKIPPED,
            "error_message": "无LLM建议",
            "execution_time_ms": _elapsed_ms(start_time),
        }

        return {
            **state,
            "details": [],
            "sql_execution_logs": [log_entry],
            "current_node": "allocate_budget",
            "next_node": "end",
        }

    # Order by priority, then by current ratio (both ascending).
    sorted_suggestions = sorted(
        llm_suggestions,
        key=lambda x: (x.priority, float(x.current_ratio))
    )

    # part_code -> brand_grouping_id, so each detail is attributed to the
    # brand grouping of its part.
    part_ratios = state.get("part_ratios", [])
    part_brand_map = {
        p.get("part_code"): p.get("brand_grouping_id")
        for p in part_ratios
        if p.get("part_code")
    }

    allocated_details = []
    total_amount = Decimal("0")

    # Convert every suggestion (including parts that need no replenishment) so
    # the complete analysis result is recorded.
    for suggestion in sorted_suggestions:
        # Brand grouping of this part; fall back to the one on the state.
        bg_id = part_brand_map.get(suggestion.part_code)
        if bg_id is None:
            bg_id = state.get("brand_grouping_id")

        detail = ReplenishmentDetail(
            task_no=state["task_no"],
            group_id=state["group_id"],
            dealer_grouping_id=state["dealer_grouping_id"],
            brand_grouping_id=bg_id,
            shop_id=suggestion.shop_id,
            shop_name=suggestion.shop_name,
            part_code=suggestion.part_code,
            part_name=suggestion.part_name,
            unit=suggestion.unit,
            cost_price=suggestion.cost_price,
            # NOTE(review): this 1.1 fallback differs from the 1.3 fallback
            # target used in sql_agent_node — confirm which one is intended.
            base_ratio=state.get("base_ratio", Decimal("1.1")),
            current_ratio=suggestion.current_ratio,
            valid_storage_cnt=suggestion.current_storage_cnt,
            avg_sales_cnt=suggestion.avg_sales_cnt,
            suggest_cnt=suggestion.suggest_cnt,
            suggest_amount=suggestion.suggest_amount,
            suggestion_reason=suggestion.suggestion_reason,
            priority=suggestion.priority,
            llm_confidence=suggestion.confidence,
            statistics_date=state["statistics_date"],
        )

        # Projected stock-to-sales ratio after the suggested replenishment.
        post_storage = detail.valid_storage_cnt + detail.suggest_cnt
        if post_storage <= 0 or detail.avg_sales_cnt <= 0:
            # No stock or no sales: report the ratio as 0 instead of dividing.
            detail.post_plan_ratio = Decimal("0")
        else:
            detail.post_plan_ratio = post_storage / detail.avg_sales_cnt

        allocated_details.append(detail)
        total_amount += suggestion.suggest_amount

    execution_time = _elapsed_ms(start_time)

    # Execution log entry.
    log_entry = {
        "step_name": "allocate_budget",
        "step_order": 3,
        "status": LogStatus.SUCCESS,
        "input_data": json.dumps({
            "suggestions_count": len(llm_suggestions),
        }),
        "output_data": json.dumps({
            "details_count": len(allocated_details),
            "total_amount": float(total_amount),
        }),
        "execution_time_ms": execution_time,
    }

    logger.info(
        f"[AllocateBudget] 分配完成: 配件数={len(allocated_details)}, "
        f"金额={total_amount}"
    )

    # Persist the results. A failure is recorded but does not interrupt the
    # flow. NOTE(review): "status" stays "success" even then — confirm intended.
    execution_logs = [log_entry]
    try:
        _persist_results(state, allocated_details)
    except Exception as e:
        logger.exception(f"[AllocateBudget] 保存结果失败: {e}")
        execution_logs.append({
            "step_name": "allocate_budget",
            "step_order": 3,
            "status": LogStatus.FAILED,
            "error_message": f"保存结果失败: {str(e)}",
            "execution_time_ms": 0,
        })

    return {
        **state,
        "details": allocated_details,
        "sql_execution_logs": execution_logs,
        "current_node": "allocate_budget",
        "next_node": "end",
        "status": "success",
        "end_time": time.time(),
    }


def should_retry_sql(state: AgentState) -> str:
    """Conditional edge: decide whether the SQL Agent node must run again.

    Returns:
        ``"retry"`` when ``next_node`` is ``"sql_agent"`` and
        ``sql_retry_count`` is still below the retry limit, otherwise
        ``"continue"``.
    """
    next_node = state.get("next_node", "allocate_budget")
    retry_count = state.get("sql_retry_count", 0)

    if next_node == "sql_agent" and retry_count < _MAX_SQL_RETRIES:
        logger.info(f"[Routing] SQL Agent需要重试: retry_count={retry_count}")
        return "retry"

    return "continue"


def should_continue(state: AgentState) -> str:
    """Conditional edge: return the ``next_node`` stored in ``state`` (default ``"end"``)."""
    return state.get("next_node", "end")