# log_repo.py (5.14 KB)
"""
日志和汇总数据访问层

提供 ai_task_execution_log 和 ai_replenishment_part_summary 表的 CRUD 操作
"""

import logging
from typing import List

from ..db import get_connection
from ...models import TaskExecutionLog, ReplenishmentPartSummary

logger = logging.getLogger(__name__)


class LogRepository:
    """Data access for task execution logs (table ``ai_task_execution_log``)."""

    def __init__(self, connection=None):
        # A caller-supplied connection is used as-is; otherwise one is opened
        # lazily by _get_connection() on first use.
        self._conn = connection

    def _get_connection(self):
        """Return a live connection, (re)opening it when missing or disconnected."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the held connection (if still connected) and drop the reference.

        The reference is cleared even when the connection has already gone
        away, so no stale handle is kept around.

        NOTE(review): this also closes a connection that was passed into
        ``__init__`` by the caller — confirm that ownership transfer is intended.
        """
        conn, self._conn = self._conn, None
        if conn is not None and conn.is_connected():
            conn.close()

    def create(self, log: TaskExecutionLog) -> int:
        """Insert one execution-log row and commit it.

        ``create_time`` is filled in by the database via ``NOW()``.

        Args:
            log: The log entry to persist; ``log.status`` is stored as ``int(log.status)``.

        Returns:
            The auto-generated ID of the inserted row (``cursor.lastrowid``).

        Raises:
            Whatever the DB driver raises; the transaction is rolled back before
            the error is re-raised.
        """
        sql = """
                INSERT INTO ai_task_execution_log (
                    task_no, group_id, dealer_grouping_id, brand_grouping_id,
                    brand_grouping_name, dealer_grouping_name,
                    step_name, step_order, status, input_data, output_data,
                    error_message, retry_count, sql_query, llm_prompt, llm_response,
                    llm_tokens, execution_time_ms, start_time, end_time, create_time
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
                )
            """

        # Build parameters before touching the connection so a bad field
        # (e.g. a non-integer status) fails without opening anything.
        values = (
            log.task_no, log.group_id, log.dealer_grouping_id,
            log.brand_grouping_id, log.brand_grouping_name,
            log.dealer_grouping_name,
            log.step_name, log.step_order, int(log.status),
            log.input_data, log.output_data, log.error_message,
            log.retry_count, log.sql_query, log.llm_prompt, log.llm_response,
            log.llm_tokens, log.execution_time_ms,
            log.start_time, log.end_time,
        )

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(sql, values)
            conn.commit()
            return cursor.lastrowid
        except Exception:
            # The connection is cached and reused; don't leave a failed
            # transaction open on it.
            conn.rollback()
            raise
        finally:
            cursor.close()


class SummaryRepository:
    """Data access for per-part replenishment summaries (table ``ai_replenishment_part_summary``)."""

    def __init__(self, connection=None):
        # A caller-supplied connection is used as-is; otherwise one is opened
        # lazily by _get_connection() on first use.
        self._conn = connection

    def _get_connection(self):
        """Return a live connection, (re)opening it when missing or disconnected."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the held connection (if still connected) and drop the reference.

        The reference is cleared even when the connection has already gone
        away, so no stale handle is kept around.

        NOTE(review): this also closes a connection that was passed into
        ``__init__`` by the caller — confirm that ownership transfer is intended.
        """
        conn, self._conn = self._conn, None
        if conn is not None and conn.is_connected():
            conn.close()

    @staticmethod
    def _to_row(s):
        """Convert one summary object into the INSERT parameter tuple (18 values)."""
        # Use an explicit None check: a ratio of 0 is a real value and must
        # not be stored as NULL.
        ratio = (
            float(s.group_current_ratio)
            if s.group_current_ratio is not None
            else None
        )
        return (
            s.task_no, s.group_id, s.dealer_grouping_id, s.part_code, s.part_name,
            s.unit, float(s.cost_price), float(s.total_storage_cnt),
            float(s.total_avg_sales_cnt),
            ratio,
            s.total_suggest_cnt, float(s.total_suggest_amount),
            s.shop_count, s.need_replenishment_shop_count, s.part_decision_reason,
            s.priority, s.llm_confidence, s.statistics_date,
        )

    def save_batch(self, summaries: List[ReplenishmentPartSummary]) -> int:
        """Insert all summaries in one ``executemany`` call and commit.

        ``create_time`` is filled in by the database via ``NOW()``.

        Args:
            summaries: The summaries to insert. An empty list is a no-op and does
                not open a connection.

        Returns:
            The row count reported by the cursor after the insert.

        Raises:
            Whatever the DB driver raises; the transaction is rolled back before
            the error is re-raised.
        """
        if not summaries:
            return 0

        sql = """
                INSERT INTO ai_replenishment_part_summary (
                    task_no, group_id, dealer_grouping_id, part_code, part_name,
                    unit, cost_price, total_storage_cnt, total_avg_sales_cnt,
                    group_current_ratio, total_suggest_cnt, total_suggest_amount,
                    shop_count, need_replenishment_shop_count, part_decision_reason,
                    priority, llm_confidence, statistics_date, create_time
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                    %s, %s, %s, %s, %s, %s, %s, %s, NOW()
                )
            """
        # Convert before acquiring the connection so malformed data fails early.
        rows = [self._to_row(s) for s in summaries]

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.executemany(sql, rows)
            conn.commit()
            # Capture the count while the cursor is still open.
            inserted = cursor.rowcount
        except Exception:
            # The connection is cached and reused; don't leave a failed
            # transaction open on it.
            conn.rollback()
            raise
        finally:
            cursor.close()

        logger.info(f"保存配件汇总: {inserted}条")
        return inserted

    def delete_by_task_no(self, task_no: str) -> int:
        """Delete every summary row belonging to ``task_no`` and commit.

        Returns:
            The row count reported by the cursor after the delete.

        Raises:
            Whatever the DB driver raises; the transaction is rolled back before
            the error is re-raised.
        """
        sql = "DELETE FROM ai_replenishment_part_summary WHERE task_no = %s"

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(sql, (task_no,))
            conn.commit()
            # Capture the count while the cursor is still open.
            deleted = cursor.rowcount
        except Exception:
            conn.rollback()
            raise
        finally:
            cursor.close()

        logger.info(f"删除配件汇总: task_no={task_no}, rows={deleted}")
        return deleted