agent.py
6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
SQL Agent main class module.

Combines the SQL executor and the part analyzer to provide the complete
SQL Agent functionality.
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from decimal import Decimal

from .executor import SQLExecutor
from .analyzer import PartAnalyzer
# NOTE(review): load_prompt is not referenced by the code visible in this module — confirm before removing.
from .prompts import load_prompt
from ...models import SQLExecutionResult, ReplenishmentSuggestion, PartAnalysisResult

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class SQLAgent:
    """SQL Agent that composes SQL execution and part-analysis features.

    SQL generation/execution is delegated to ``SQLExecutor`` and part
    grouping / replenishment analysis to ``PartAnalyzer``. The agent also
    loads ``part_ratio`` rows directly via ``fetch_part_ratios``.

    The agent can be used as a context manager; ``close()`` is called on
    exit (it delegates to ``SQLExecutor.close``)::

        with SQLAgent(conn) as agent:
            rows = agent.fetch_part_ratios(group_id, dealer_grouping_id, date)
    """

    # SQL expression exposed as the computed column ``valid_storage_cnt``
    # (unlocked on-hand + on-the-way + planned counts).
    _VALID_STORAGE_EXPR = "(in_stock_unlocked_cnt + on_the_way_cnt + has_plan_cnt)"
    # SQL expression exposed as the computed column ``avg_sales_cnt``.
    # NOTE(review): the "/ 3" presumably averages over three periods — confirm.
    # Kept in one place so the SELECT list and ORDER BY cannot drift apart.
    _AVG_SALES_EXPR = (
        "((out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt) / 3)"
    )

    def __init__(self, db_connection=None):
        """Create the agent.

        Args:
            db_connection: Optional DB connection forwarded to ``SQLExecutor``.
        """
        self._executor = SQLExecutor(db_connection)
        self._analyzer = PartAnalyzer()

    def __enter__(self) -> "SQLAgent":
        """Return the agent itself so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the agent on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def close(self):
        """Close the connection (delegates to ``SQLExecutor.close``)."""
        self._executor.close()

    # === Methods delegated to SQLExecutor ===

    def generate_sql(
        self,
        question: str,
        context: Optional[Dict] = None,
        previous_error: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Generate SQL with the LLM (delegates to ``SQLExecutor.generate_sql``).

        Args:
            question: Natural-language question to turn into SQL.
            context: Optional extra context forwarded to the executor.
            previous_error: Optional error from a previous attempt, forwarded
                to the executor.

        Returns:
            The ``(str, str)`` tuple produced by the executor.
        """
        return self._executor.generate_sql(question, context, previous_error)

    def execute_sql(self, sql: str) -> Tuple[bool, Any, Optional[str]]:
        """Execute a SQL query (delegates to ``SQLExecutor.execute_sql``).

        Returns:
            The executor's ``(bool, Any, Optional[str])`` tuple — presumably
            (success flag, result, error message); see ``SQLExecutor``.
        """
        return self._executor.execute_sql(sql)

    def query_with_retry(
        self,
        question: str,
        context: Optional[Dict] = None,
    ) -> SQLExecutionResult:
        """Query with retries (delegates to ``SQLExecutor.query_with_retry``)."""
        return self._executor.query_with_retry(question, context)

    # === Methods delegated to PartAnalyzer ===

    def group_parts_by_code(self, part_ratios: List[Dict]) -> Dict[str, List[Dict]]:
        """Group part rows by part code (delegates to ``PartAnalyzer``)."""
        return self._analyzer.group_parts_by_code(part_ratios)

    def generate_suggestions(
        self,
        part_data: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
    ) -> Tuple[List[ReplenishmentSuggestion], Dict]:
        """Generate replenishment suggestions (delegates to ``PartAnalyzer``)."""
        return self._analyzer.generate_suggestions(
            part_data, dealer_grouping_id, dealer_grouping_name, statistics_date
        )

    def analyze_parts_by_group(
        self,
        part_ratios: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
        target_ratio: Decimal = Decimal("1.3"),
        limit: Optional[int] = None,
        callback: Optional[Any] = None,
    ) -> Tuple[List[ReplenishmentSuggestion], List[PartAnalysisResult], Dict]:
        """Analyze replenishment suggestions grouped by part.

        All arguments are forwarded unchanged to
        ``PartAnalyzer.analyze_parts_by_group``.
        """
        return self._analyzer.analyze_parts_by_group(
            part_ratios,
            dealer_grouping_id,
            dealer_grouping_name,
            statistics_date,
            target_ratio,
            limit,
            callback,
        )

    # === Data query methods ===

    def fetch_part_ratios(
        self,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
    ) -> List[Dict]:
        """Query ``part_ratio`` rows for one dealer grouping on one date.

        Only rows with ``part_biz_type = 1`` are returned. If the dealer
        grouping (row ``id`` in ``artificial_region_dealer``) is linked to
        brand groupings, results are further restricted to those
        ``brand_grouping_id`` values; when the lookup finds none or fails,
        no brand filter is applied.

        Rows are ordered with positive-average-sales parts first, then by
        ascending ``valid_storage_cnt / avg_sales_cnt``, then by descending
        ``avg_sales_cnt``.

        Args:
            group_id: Group ID.
            dealer_grouping_id: Dealer grouping ID.
            statistics_date: Statistics date, matched by equality.

        Returns:
            Row dicts (the cursor is opened with ``dictionary=True``),
            including the computed ``valid_storage_cnt`` and ``avg_sales_cnt``.
        """
        # NOTE(review): reaches into SQLExecutor's private _get_connection;
        # consider exposing a public accessor on the executor.
        conn = self._executor._get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            brand_grouping_ids = self._fetch_brand_grouping_ids(cursor, dealer_grouping_id)
            sql, params = self._build_part_ratio_query(
                group_id, dealer_grouping_id, statistics_date, brand_grouping_ids
            )
            cursor.execute(sql, params)
            rows = cursor.fetchall()
            logger.info(
                "获取part_ratio数据: dealer_grouping_id=%s, count=%s",
                dealer_grouping_id,
                len(rows),
            )
            return rows
        finally:
            cursor.close()

    @staticmethod
    def _fetch_brand_grouping_ids(cursor, dealer_grouping_id: int) -> List[Any]:
        """Best-effort lookup of brand-grouping IDs linked to a dealer grouping.

        Reads ``part_purchase_brand_assemble_id`` from ``artificial_region_dealer``
        and keeps only truthy values. Any error is logged as a warning and an
        empty list is returned, so the caller simply skips the brand filter.
        """
        try:
            cursor.execute(
                "SELECT part_purchase_brand_assemble_id FROM artificial_region_dealer WHERE id = %s",
                (dealer_grouping_id,),
            )
            fetched = cursor.fetchall()
        except Exception as e:
            # Deliberate best-effort: a failed lookup must not block the main query.
            logger.warning("查询商家组合配置失败: %s", e)
            return []

        brand_ids = [
            bid
            for bid in (row.get("part_purchase_brand_assemble_id") for row in fetched)
            if bid
        ]
        if brand_ids:
            logger.info(
                "商家组合关联品牌组合: dealer_grouping_id=%s -> brand_grouping_ids=%s",
                dealer_grouping_id,
                brand_ids,
            )
        return brand_ids

    @classmethod
    def _build_part_ratio_query(
        cls,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
        brand_grouping_ids: List[Any],
    ) -> Tuple[str, List[Any]]:
        """Build the parameterized ``part_ratio`` SELECT and its parameter list.

        Only the fixed SQL expression constants are interpolated into the text;
        every caller-supplied value is bound through ``%s`` placeholders.
        """
        valid_storage = cls._VALID_STORAGE_EXPR
        avg_sales = cls._AVG_SALES_EXPR
        sql = f"""
            SELECT
                id, group_id, brand_id, brand_grouping_id,
                dealer_grouping_id,
                supplier_id, supplier_name, area_id, area_name,
                shop_id, shop_name, part_id, part_code, part_name,
                unit, cost_price,
                in_stock_unlocked_cnt, has_plan_cnt, on_the_way_cnt,
                out_stock_cnt, buy_cnt, storage_locked_cnt,
                out_stock_ongoing_cnt, stock_age, out_times, out_duration,
                transfer_cnt, gen_transfer_cnt,
                part_biz_type, statistics_date,
                {valid_storage} as valid_storage_cnt,
                {avg_sales} as avg_sales_cnt
            FROM part_ratio
            WHERE group_id = %s
              AND dealer_grouping_id = %s
              AND statistics_date = %s
              AND part_biz_type = 1
        """
        params: List[Any] = [group_id, dealer_grouping_id, statistics_date]

        # Restrict to the configured brand groupings, if any, via an IN list.
        if brand_grouping_ids:
            placeholders = ", ".join(["%s"] * len(brand_grouping_ids))
            sql += f" AND brand_grouping_id IN ({placeholders})"
            params.extend(brand_grouping_ids)

        # Parts with sales first; NULLIF avoids division by zero in the ratio.
        sql += f"""
            ORDER BY
                CASE WHEN {avg_sales} > 0 THEN 0 ELSE 1 END,
                {valid_storage} / NULLIF({avg_sales}, 0) ASC,
                {avg_sales} DESC
        """
        return sql, params