agent.py
9.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
SQL Agent 主类模块
组合 Executor 和 Analyzer 提供完整的 SQL Agent 功能
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from decimal import Decimal
from .executor import SQLExecutor
from .analyzer import PartAnalyzer
from .prompts import load_prompt
from ...models import SQLExecutionResult, ReplenishmentSuggestion, PartAnalysisResult
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class SQLAgent:
    """SQL Agent combining SQL execution and part-analysis features.

    SQL generation/execution is delegated to ``SQLExecutor`` and
    replenishment analysis to ``PartAnalyzer``. The ``fetch_*`` methods
    query the ``part_ratio`` table directly through the executor's
    connection.

    Can be used as a context manager; ``close()`` is called on exit.
    """

    # Base SELECT shared by every part_ratio query. Besides the raw columns
    # it derives two metrics:
    #   valid_storage_cnt = in_stock_unlocked_cnt + on_the_way_cnt + has_plan_cnt
    #   avg_sales_cnt     = (out_stock_cnt + storage_locked_cnt
    #                        + out_stock_ongoing_cnt + buy_cnt) / 3
    # Only rows with part_biz_type = 1 are selected.
    # Placeholders, in order: group_id, dealer_grouping_id, statistics_date.
    _PART_RATIO_BASE_SQL = """
        SELECT
            id, group_id, brand_id, brand_grouping_id,
            dealer_grouping_id,
            supplier_id, supplier_name, area_id, area_name,
            shop_id, shop_name, part_id, part_code, part_name,
            unit, cost_price,
            in_stock_unlocked_cnt, has_plan_cnt, on_the_way_cnt,
            out_stock_cnt, buy_cnt, storage_locked_cnt,
            out_stock_ongoing_cnt, stock_age, out_times, out_duration,
            transfer_cnt, gen_transfer_cnt,
            part_biz_type, statistics_date,
            (in_stock_unlocked_cnt + on_the_way_cnt + has_plan_cnt) as valid_storage_cnt,
            ((out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt) / 3) as avg_sales_cnt
        FROM part_ratio
        WHERE group_id = %s
          AND dealer_grouping_id = %s
          AND statistics_date = %s
          AND part_biz_type = 1
    """

    # Parts with sales (avg_sales_cnt > 0) come first, then ascending
    # valid_storage_cnt / avg_sales_cnt (NULLIF guards against division by
    # zero), then avg_sales_cnt descending. The expressions are spelled out
    # rather than using the SELECT aliases.
    _PART_RATIO_ORDER_BY = """ ORDER BY
        CASE WHEN ((out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt) / 3) > 0 THEN 0 ELSE 1 END,
        (in_stock_unlocked_cnt + on_the_way_cnt + has_plan_cnt) / NULLIF((out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt) / 3, 0) ASC,
        ((out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt) / 3) DESC
    """

    def __init__(self, db_connection=None):
        """Create the executor (with the optional DB connection) and the analyzer."""
        self._executor = SQLExecutor(db_connection)
        self._analyzer = PartAnalyzer()

    def __enter__(self) -> "SQLAgent":
        """Return self so the agent can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the agent; exceptions from the ``with`` body are not suppressed."""
        self.close()

    def close(self):
        """Close the connection held by the SQL executor."""
        self._executor.close()

    # === Methods delegated to SQLExecutor ===

    def generate_sql(
        self,
        question: str,
        context: Optional[Dict] = None,
        previous_error: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Generate SQL with the LLM (delegates to ``SQLExecutor.generate_sql``).

        Args:
            question: Natural-language question to translate into SQL.
            context: Optional extra context passed through to the executor.
            previous_error: Optional error from a previous attempt, passed through.

        Returns:
            The executor's ``(str, str)`` tuple; see ``SQLExecutor.generate_sql``
            for the meaning of each element.
        """
        return self._executor.generate_sql(question, context, previous_error)

    def execute_sql(self, sql: str) -> Tuple[bool, Any, Optional[str]]:
        """Execute a SQL query (delegates to ``SQLExecutor.execute_sql``).

        Returns:
            The executor's tuple; presumably ``(success, result, error_message)``
            — confirm against ``SQLExecutor.execute_sql``.
        """
        return self._executor.execute_sql(sql)

    def query_with_retry(
        self,
        question: str,
        context: Optional[Dict] = None,
    ) -> SQLExecutionResult:
        """Run a query with retries (delegates to ``SQLExecutor.query_with_retry``)."""
        return self._executor.query_with_retry(question, context)

    # === Methods delegated to PartAnalyzer ===

    def group_parts_by_code(self, part_ratios: List[Dict]) -> Dict[str, List[Dict]]:
        """Group part rows by part code (delegates to ``PartAnalyzer``)."""
        return self._analyzer.group_parts_by_code(part_ratios)

    def generate_suggestions(
        self,
        part_data: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
    ) -> Tuple[List[ReplenishmentSuggestion], Dict]:
        """Generate replenishment suggestions (delegates to ``PartAnalyzer``)."""
        return self._analyzer.generate_suggestions(
            part_data, dealer_grouping_id, dealer_grouping_name, statistics_date
        )

    def analyze_parts_by_group(
        self,
        part_ratios: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
        target_ratio: Decimal = Decimal("1.3"),
        limit: Optional[int] = None,
        callback: Optional[Any] = None,
    ) -> Tuple[List[ReplenishmentSuggestion], List[PartAnalysisResult], Dict]:
        """Analyze replenishment suggestions grouped by part (delegates to ``PartAnalyzer``).

        All arguments are passed through unchanged to
        ``PartAnalyzer.analyze_parts_by_group``.
        """
        return self._analyzer.analyze_parts_by_group(
            part_ratios,
            dealer_grouping_id,
            dealer_grouping_name,
            statistics_date,
            target_ratio,
            limit,
            callback,
        )

    # === Data query methods ===

    def fetch_part_ratios(
        self,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
    ) -> List[Dict]:
        """Fetch part_ratio rows for a dealer grouping on a statistics date.

        If brand groupings are configured for the dealer grouping, only rows
        in those brand groupings are returned. Rows are ordered with
        selling parts first, then by ascending stock-to-sales ratio.

        Args:
            group_id: Group (corporate) ID.
            dealer_grouping_id: Dealer grouping ID.
            statistics_date: Statistics date.

        Returns:
            Part stock-to-sales rows as dicts, including the derived
            ``valid_storage_cnt`` and ``avg_sales_cnt`` columns.
        """
        rows = self._query_part_ratios(group_id, dealer_grouping_id, statistics_date)
        logger.info(
            "获取part_ratio数据: dealer_grouping_id=%s, count=%s",
            dealer_grouping_id,
            len(rows),
        )
        return rows

    def fetch_part_ratios_by_part_code(
        self,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
        part_code: str,
    ) -> List[Dict]:
        """Fetch part_ratio rows for one part code (single-part replenishment).

        Same filtering and ordering as ``fetch_part_ratios``, additionally
        restricted to ``part_code``.

        Args:
            group_id: Group (corporate) ID.
            dealer_grouping_id: Dealer grouping ID.
            statistics_date: Statistics date.
            part_code: Part code to restrict to.

        Returns:
            Stock-to-sales rows (dicts) for the given part.
        """
        rows = self._query_part_ratios(
            group_id,
            dealer_grouping_id,
            statistics_date,
            extra_where=" AND part_code = %s",
            extra_params=(part_code,),
        )
        logger.info(
            "获取单配件part_ratio数据: dealer_grouping_id=%s, part_code=%s, count=%s",
            dealer_grouping_id,
            part_code,
            len(rows),
        )
        return rows

    def _fetch_brand_grouping_ids(self, cursor: Any, dealer_grouping_id: int) -> List[Any]:
        """Return the brand-grouping IDs configured for a dealer grouping.

        Reads ``part_purchase_brand_assemble_id`` from
        ``artificial_region_dealer`` where ``id = dealer_grouping_id`` and
        drops falsy values. The lookup is best-effort: if it raises, a
        warning is logged and an empty list is returned. The caller then
        applies no brand-grouping filter.
        """
        try:
            cursor.execute(
                "SELECT part_purchase_brand_assemble_id FROM artificial_region_dealer WHERE id = %s",
                (dealer_grouping_id,),
            )
            brand_grouping_ids = [
                bid
                for bid in (
                    row.get("part_purchase_brand_assemble_id") for row in cursor.fetchall()
                )
                if bid
            ]
        except Exception as e:
            # NOTE(review): deliberate best-effort; a failed lookup widens the
            # query to all brand groupings rather than aborting it.
            logger.warning("查询商家组合配置失败: %s", e)
            return []
        if brand_grouping_ids:
            logger.info(
                "商家组合关联品牌组合: dealer_grouping_id=%s -> brand_grouping_ids=%s",
                dealer_grouping_id,
                brand_grouping_ids,
            )
        return brand_grouping_ids

    def _query_part_ratios(
        self,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
        extra_where: str = "",
        extra_params: Tuple[Any, ...] = (),
    ) -> List[Dict]:
        """Run the shared part_ratio query and return all rows as dicts.

        ``extra_where`` is appended verbatim after the base WHERE clause. It
        must contain only ``%s`` placeholders (never interpolated data), and
        ``extra_params`` supplies their values in order. The cursor is always
        closed, even if the query raises.
        """
        conn = self._executor._get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            brand_grouping_ids = self._fetch_brand_grouping_ids(cursor, dealer_grouping_id)
            sql = self._PART_RATIO_BASE_SQL + extra_where
            params: List[Any] = [group_id, dealer_grouping_id, statistics_date, *extra_params]
            if brand_grouping_ids:
                # One placeholder per ID keeps the IN list fully parameterized.
                placeholders = ", ".join(["%s"] * len(brand_grouping_ids))
                sql += f" AND brand_grouping_id IN ({placeholders})"
                params.extend(brand_grouping_ids)
            sql += self._PART_RATIO_ORDER_BY
            cursor.execute(sql, params)
            return cursor.fetchall()
        finally:
            cursor.close()