state.py 2.92 KB
"""
LangGraph Agent 状态定义

重构版本:直接使用 part_ratio 数据,支持 SQL Agent
使用 Annotated 类型和 reducer 函数解决并发状态更新问题
"""

from typing import TypedDict, Optional, List, Any, Annotated
from decimal import Decimal
import operator


def merge_lists(left: List[Any], right: List[Any]) -> List[Any]:
    """合并两个列表的 reducer 函数"""
    if left is None:
        left = []
    if right is None:
        right = []
    return left + right


def merge_dicts(left: List[dict], right: List[dict]) -> List[dict]:
    """合并字典列表的 reducer 函数"""
    if left is None:
        left = []
    if right is None:
        right = []
    return left + right


def keep_last(left: Any, right: Any) -> Any:
    """保留最后一个值的 reducer 函数"""
    return right if right is not None else left


def sum_values(left: int, right: int) -> int:
    """累加数值的 reducer 函数"""
    return (left or 0) + (right or 0)


class AgentState(TypedDict, total=False):
    """补货建议 Agent 状态
    
    使用 Annotated 类型定义 reducer 函数,处理并行节点的状态合并
    """
    
    # 任务标识(使用 keep_last,因为这些值在并行执行时相同)
    task_no: Annotated[str, keep_last]
    group_id: Annotated[int, keep_last]
    brand_grouping_id: Annotated[Optional[int], keep_last]
    brand_grouping_name: Annotated[str, keep_last]
    dealer_grouping_id: Annotated[int, keep_last]
    dealer_grouping_name: Annotated[str, keep_last]
    statistics_date: Annotated[str, keep_last]
    
    # part_ratio 原始数据(使用 keep_last,因为只在 fetch_part_ratio 节点写入一次)
    part_ratios: Annotated[List[dict], keep_last]
    
    # SQL Agent 相关
    sql_queries: Annotated[List[str], merge_lists]
    sql_results: Annotated[List[dict], merge_dicts]
    sql_retry_count: Annotated[int, keep_last]
    sql_execution_logs: Annotated[List[dict], merge_dicts]
    
    # 计算结果
    base_ratio: Annotated[Decimal, keep_last]
    allocated_details: Annotated[List[dict], merge_dicts]
    details: Annotated[List[Any], merge_lists]
    
    # LLM 建议明细
    llm_suggestions: Annotated[List[Any], merge_lists]
    
    # 配件汇总结果
    part_results: Annotated[List[Any], merge_lists]
    
    # 分析报告
    analysis_report: Annotated[Optional[dict], keep_last]
    
    # LLM 统计(使用累加,合并多个并行节点的 token 使用量)
    llm_provider: Annotated[str, keep_last]
    llm_model: Annotated[str, keep_last]
    llm_prompt_tokens: Annotated[int, sum_values]
    llm_completion_tokens: Annotated[int, sum_values]
    
    # 执行状态
    status: Annotated[str, keep_last]
    error_message: Annotated[str, keep_last]
    start_time: Annotated[float, keep_last]
    end_time: Annotated[float, keep_last]
    
    # 流程控制
    current_node: Annotated[str, keep_last]
    next_node: Annotated[str, keep_last]