task_repo.py
4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
"""
任务数据访问层
提供 ai_replenishment_task 表的 CRUD 操作
"""
import logging
from datetime import datetime
from typing import Optional
from ..db import get_connection
from ...models import ReplenishmentTask
# Module-level logger, named after this module, used for task create/update messages.
logger = logging.getLogger(__name__)
class TaskRepository:
    """Data access for the ``ai_replenishment_task`` table.

    Holds at most one database connection. The connection is obtained lazily
    from ``get_connection()`` and re-obtained whenever the cached one reports
    that it is no longer connected. Instances can be used as a context
    manager so the connection is closed on exit.
    """

    def __init__(self, connection=None):
        """
        Args:
            connection: Optional already-open connection to reuse. When
                omitted, one is obtained from ``get_connection()`` on first use.
        """
        self._conn = connection

    def __enter__(self) -> "TaskRepository":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _get_connection(self):
        """Return the cached connection, (re)connecting if it is missing or closed."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the cached connection if it is still open, then forget it."""
        if self._conn is not None and self._conn.is_connected():
            self._conn.close()
        self._conn = None

    @staticmethod
    def _optional_float(value) -> Optional[float]:
        """Convert ``value`` to ``float``, passing ``None`` through unchanged.

        An explicit ``None`` check (rather than truthiness) ensures a real
        zero, e.g. a base ratio of 0, is written as 0.0 instead of NULL.
        """
        return None if value is None else float(value)

    def _execute_write(self, sql: str, values: tuple):
        """Execute a single write statement and commit it.

        On any failure the transaction is rolled back before re-raising, so
        the cached connection is not left holding an open, failed transaction.

        Returns:
            ``(lastrowid, rowcount)`` of the executed statement.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(sql, values)
            conn.commit()
            return cursor.lastrowid, cursor.rowcount
        except Exception:
            conn.rollback()
            raise
        finally:
            cursor.close()

    def create(self, task: "ReplenishmentTask") -> int:
        """Insert ``task`` as a new row.

        ``start_time`` defaults to the current local time when the task has
        none; ``create_time`` is filled by the database via ``NOW()``.

        Returns:
            The ID of the inserted row (``cursor.lastrowid``).
        """
        sql = """
            INSERT INTO ai_replenishment_task (
                task_no, group_id, dealer_grouping_id, dealer_grouping_name,
                brand_grouping_id, plan_amount, actual_amount, part_count,
                base_ratio, status, error_message, llm_provider, llm_model,
                llm_total_tokens, statistics_date, start_time, end_time, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """
        values = (
            task.task_no,
            task.group_id,
            task.dealer_grouping_id,
            task.dealer_grouping_name,
            task.brand_grouping_id,
            float(task.plan_amount),
            float(task.actual_amount),
            task.part_count,
            self._optional_float(task.base_ratio),
            int(task.status),
            task.error_message,
            task.llm_provider,
            task.llm_model,
            task.llm_total_tokens,
            task.statistics_date,
            datetime.now() if task.start_time is None else task.start_time,
            task.end_time,
        )
        task_id, _ = self._execute_write(sql, values)
        logger.info("创建任务记录: task_no=%s, id=%s", task.task_no, task_id)
        return task_id

    def update(self, task: "ReplenishmentTask") -> int:
        """Update the result fields of the row whose ``task_no`` matches ``task``.

        ``end_time`` defaults to the current local time when the task has none.

        Returns:
            The number of rows affected, as reported by the cursor.
        """
        sql = """
            UPDATE ai_replenishment_task
            SET actual_amount = %s,
                part_count = %s,
                base_ratio = %s,
                status = %s,
                error_message = %s,
                llm_provider = %s,
                llm_model = %s,
                llm_total_tokens = %s,
                end_time = %s
            WHERE task_no = %s
        """
        values = (
            float(task.actual_amount),
            task.part_count,
            self._optional_float(task.base_ratio),
            int(task.status),
            task.error_message,
            task.llm_provider,
            task.llm_model,
            task.llm_total_tokens,
            datetime.now() if task.end_time is None else task.end_time,
            task.task_no,
        )
        _, rows = self._execute_write(sql, values)
        logger.info("更新任务记录: task_no=%s, rows=%s", task.task_no, rows)
        return rows

    def find_by_task_no(self, task_no: str) -> Optional["ReplenishmentTask"]:
        """Fetch the task with the given ``task_no``.

        Every selected column is passed as a keyword argument to
        ``ReplenishmentTask``.

        Returns:
            The matching task, or ``None`` if no row matches.
        """
        conn = self._get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            sql = "SELECT * FROM ai_replenishment_task WHERE task_no = %s"
            cursor.execute(sql, (task_no,))
            row = cursor.fetchone()
            return ReplenishmentTask(**row) if row else None
        finally:
            cursor.close()