# state.py
"""
LangGraph Agent 状态定义
重构版本:直接使用 part_ratio 数据,支持 SQL Agent
使用 Annotated 类型和 reducer 函数解决并发状态更新问题
"""
from typing import TypedDict, Optional, List, Any, Annotated
from decimal import Decimal
import operator
def merge_lists(
    left: Optional[List[Any]], right: Optional[List[Any]]
) -> List[Any]:
    """Reducer that concatenates two lists.

    ``None`` on either side is treated as an empty list, so a ``None``
    update leaves the accumulated contents unchanged. A new list is always
    returned; neither argument is mutated.

    Args:
        left: The currently accumulated value (may be ``None``).
        right: The incoming update (may be ``None``).

    Returns:
        ``left + right`` with any ``None`` replaced by ``[]``.
    """
    if left is None:
        left = []
    if right is None:
        right = []
    # ``+`` builds a fresh list, so callers' lists are never modified in place.
    return left + right
def merge_dicts(
    left: Optional[List[dict]], right: Optional[List[dict]]
) -> List[dict]:
    """Reducer that concatenates two lists of dicts.

    Despite the name, the dicts themselves are NOT merged key-by-key: the
    result is simply the items of ``left`` followed by the items of
    ``right``. ``None`` on either side is treated as an empty list, and a
    new list is returned without mutating either argument.

    Args:
        left: The currently accumulated list of dicts (may be ``None``).
        right: The incoming list of dicts (may be ``None``).

    Returns:
        ``left + right`` with any ``None`` replaced by ``[]``.
    """
    if left is None:
        left = []
    if right is None:
        right = []
    return left + right
def keep_last(left: Any, right: Any) -> Any:
    """Reducer that prefers the newest non-``None`` value.

    Args:
        left: The current value.
        right: The incoming update.

    Returns:
        ``right``, unless it is ``None``, in which case ``left`` is kept.
        Other falsy updates (``0``, ``""``, ``[]``) still overwrite.
    """
    if right is None:
        return left
    return right
def sum_values(left: Optional[int], right: Optional[int]) -> int:
    """Reducer that adds two counters together.

    ``None`` (or any other falsy value) on either side counts as ``0``, so
    a missing value does not break the accumulation.

    Args:
        left: The currently accumulated count (may be ``None``).
        right: The incoming increment (may be ``None``).

    Returns:
        The sum of both operands, treating ``None`` as ``0``.
    """
    return (left or 0) + (right or 0)
class AgentState(TypedDict, total=False):
    """State of the replenishment-suggestion agent.

    Every field is declared as ``Annotated[<type>, <reducer>]`` so that
    updates coming from parallel nodes are combined by the named reducer
    (``keep_last``, ``merge_lists``, ``merge_dicts`` or ``sum_values``)
    instead of conflicting. ``total=False`` makes every key optional.

    NOTE(review): ``keep_last`` ignores ``None`` updates, so a field using
    it cannot be reset to ``None`` once it holds a value.

    NOTE(review): the exact annotation forms (e.g. ``typing.List`` vs the
    builtin ``list``) and field order may be introspected by the graph
    framework when it builds state channels (presumably including each
    channel's initial value) — confirm before changing them.
    """

    # Task identifiers. keep_last is used because, per the original
    # comment, these values are the same across parallel branches.
    task_no: Annotated[str, keep_last]
    group_id: Annotated[int, keep_last]
    # Declared Optional, but keep_last drops None updates, so once set this
    # cannot be cleared back to None through an update.
    brand_grouping_id: Annotated[Optional[int], keep_last]
    brand_grouping_name: Annotated[str, keep_last]
    dealer_grouping_id: Annotated[int, keep_last]
    dealer_grouping_name: Annotated[str, keep_last]
    statistics_date: Annotated[str, keep_last]

    # Raw part_ratio data; lists from separate updates are concatenated.
    part_ratios: Annotated[List[dict], merge_dicts]

    # SQL Agent fields. Queries, results and logs accumulate (concatenate)
    # across updates. The retry count uses keep_last, so each update
    # replaces it: a node must write the absolute count, not an increment.
    sql_queries: Annotated[List[str], merge_lists]
    sql_results: Annotated[List[dict], merge_dicts]
    sql_retry_count: Annotated[int, keep_last]
    sql_execution_logs: Annotated[List[dict], merge_dicts]

    # Computation results.
    base_ratio: Annotated[Decimal, keep_last]
    allocated_details: Annotated[List[dict], merge_dicts]
    details: Annotated[List[Any], merge_lists]

    # LLM suggestion details (concatenated across updates).
    llm_suggestions: Annotated[List[Any], merge_lists]

    # Per-part summary results (concatenated across updates).
    part_results: Annotated[List[Any], merge_lists]

    # LLM statistics. Token counts use sum_values, so usage reported by
    # several parallel nodes is added together; each update should carry
    # only that node's own usage, not a running total, or it is
    # double-counted.
    llm_provider: Annotated[str, keep_last]
    llm_model: Annotated[str, keep_last]
    llm_prompt_tokens: Annotated[int, sum_values]
    llm_completion_tokens: Annotated[int, sum_values]

    # Execution status.
    status: Annotated[str, keep_last]
    # keep_last: an error message cannot be cleared by writing None.
    error_message: Annotated[str, keep_last]
    # presumably epoch-second timestamps — TODO confirm with the writers
    start_time: Annotated[float, keep_last]
    end_time: Annotated[float, keep_last]

    # Flow control.
    current_node: Annotated[str, keep_last]
    next_node: Annotated[str, keep_last]