""" 任务数据访问层 提供 ai_replenishment_task 表的 CRUD 操作 """ import logging from datetime import datetime from typing import Optional from ..db import get_connection from ...models import ReplenishmentTask logger = logging.getLogger(__name__) class TaskRepository: """任务数据访问""" def __init__(self, connection=None): self._conn = connection def _get_connection(self): """获取数据库连接""" if self._conn is None or not self._conn.is_connected(): self._conn = get_connection() return self._conn def close(self): """关闭连接""" if self._conn and self._conn.is_connected(): self._conn.close() self._conn = None def create(self, task: ReplenishmentTask) -> int: """ 创建任务记录 Returns: 插入的任务ID """ conn = self._get_connection() cursor = conn.cursor() try: sql = """ INSERT INTO ai_replenishment_task ( task_no, group_id, dealer_grouping_id, dealer_grouping_name, brand_grouping_id, plan_amount, actual_amount, part_count, base_ratio, status, error_message, llm_provider, llm_model, llm_total_tokens, statistics_date, start_time, end_time, create_time ) VALUES ( %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW() ) """ values = ( task.task_no, task.group_id, task.dealer_grouping_id, task.dealer_grouping_name, task.brand_grouping_id, float(task.plan_amount), float(task.actual_amount), task.part_count, float(task.base_ratio) if task.base_ratio else None, int(task.status), task.error_message, task.llm_provider, task.llm_model, task.llm_total_tokens, task.statistics_date, datetime.now() if task.start_time is None else task.start_time, task.end_time, ) cursor.execute(sql, values) conn.commit() task_id = cursor.lastrowid logger.info(f"创建任务记录: task_no={task.task_no}, id={task_id}") return task_id finally: cursor.close() def update(self, task: ReplenishmentTask) -> int: """ 更新任务记录 Returns: 更新的行数 """ conn = self._get_connection() cursor = conn.cursor() try: sql = """ UPDATE ai_replenishment_task SET actual_amount = %s, part_count = %s, base_ratio = %s, status = %s, error_message = %s, llm_provider = %s, llm_model = %s, llm_total_tokens = %s, end_time = %s WHERE task_no = %s """ values = ( float(task.actual_amount), task.part_count, float(task.base_ratio) if task.base_ratio else None, int(task.status), task.error_message, task.llm_provider, task.llm_model, task.llm_total_tokens, datetime.now() if task.end_time is None else task.end_time, task.task_no, ) cursor.execute(sql, values) conn.commit() logger.info(f"更新任务记录: task_no={task.task_no}, rows={cursor.rowcount}") return cursor.rowcount finally: cursor.close() def find_by_task_no(self, task_no: str) -> Optional[ReplenishmentTask]: """根据 task_no 查询任务""" conn = self._get_connection() cursor = conn.cursor(dictionary=True) try: sql = "SELECT * FROM ai_replenishment_task WHERE task_no = %s" cursor.execute(sql, (task_no,)) row = cursor.fetchone() if row: return ReplenishmentTask(**row) return None finally: cursor.close()