# result_writer.py
"""
结果写入服务
负责将补货建议结果写入数据库
"""

import logging
import json
from typing import List, Optional
from datetime import datetime

from .db import get_connection
from ..models import (
    ReplenishmentTask,
    ReplenishmentDetail,
    TaskExecutionLog,
    ReplenishmentPartSummary,
    AnalysisReport,
)

logger = logging.getLogger(__name__)


class ResultWriter:
    """Persist replenishment-suggestion results to the database.

    A single connection is opened lazily via ``get_connection`` and reused
    across calls; it is reopened whenever ``is_connected()`` reports it as
    dropped. Every public write method commits its own statement (or batch).
    If the statement or the commit fails, the transaction is rolled back and
    the original exception is re-raised.

    May be used as a context manager so the connection is always closed::

        with ResultWriter() as writer:
            writer.save_task(task)
    """

    def __init__(self):
        # Opened lazily by _get_connection(); None until the first write.
        self._conn = None

    def __enter__(self) -> "ResultWriter":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always release the connection; exceptions are not suppressed.
        self.close()

    def _get_connection(self):
        """Return the cached connection, (re)opening it if absent or disconnected."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the cached connection, if still connected, and drop the reference."""
        conn, self._conn = self._conn, None
        if conn is not None and conn.is_connected():
            conn.close()

    @staticmethod
    def _optional_float(value) -> Optional[float]:
        """Convert a nullable numeric value to float, keeping None as None.

        A genuine zero (e.g. ``Decimal("0")``) is preserved as ``0.0``; a
        truthiness test would wrongly store it as NULL.
        """
        return None if value is None else float(value)

    def _execute_write(self, sql: str, params, *, many: bool = False) -> tuple:
        """Execute one write statement (or a batch) and commit it.

        Args:
            sql: Parameterised SQL using ``%s`` placeholders.
            params: A parameter tuple, or a list of tuples when ``many`` is True.
            many: Use ``cursor.executemany`` instead of ``cursor.execute``.

        Returns:
            ``(rowcount, lastrowid)`` as reported by the cursor after commit.

        Raises:
            Whatever the driver raises, after the transaction has been rolled back.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            if many:
                cursor.executemany(sql, params)
            else:
                cursor.execute(sql, params)
            conn.commit()
            return cursor.rowcount, cursor.lastrowid
        except Exception:
            # The connection is cached and reused: without a rollback, a
            # failed/partial write could be committed by the next method call.
            try:
                conn.rollback()
            except Exception:
                logger.exception("事务回滚失败")
            raise
        finally:
            cursor.close()

    def save_task(self, task: ReplenishmentTask) -> int:
        """Insert a task record into ``ai_replenishment_task``.

        ``start_time`` defaults to the current time when the task has none;
        ``create_time`` is filled in by the database with ``NOW()``.

        Returns:
            The ID of the inserted row (``cursor.lastrowid``).
        """
        sql = """
            INSERT INTO ai_replenishment_task (
                task_no, group_id, dealer_grouping_id, dealer_grouping_name,
                brand_grouping_id, plan_amount, actual_amount, part_count,
                base_ratio, status, error_message, llm_provider, llm_model,
                llm_total_tokens, statistics_date, start_time, end_time, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """

        values = (
            task.task_no,
            task.group_id,
            task.dealer_grouping_id,
            task.dealer_grouping_name,
            task.brand_grouping_id,
            float(task.plan_amount),
            float(task.actual_amount),
            task.part_count,
            self._optional_float(task.base_ratio),
            int(task.status),
            task.error_message,
            task.llm_provider,
            task.llm_model,
            task.llm_total_tokens,
            task.statistics_date,
            datetime.now() if task.start_time is None else task.start_time,
            task.end_time,
        )

        _, task_id = self._execute_write(sql, values)
        logger.info("保存任务记录: task_no=%s, id=%s", task.task_no, task_id)
        return task_id

    def update_task(self, task: ReplenishmentTask) -> int:
        """Update the result/status columns of the task identified by ``task_no``.

        ``end_time`` defaults to the current time when the task has none.

        Returns:
            The number of rows affected (``cursor.rowcount``).
        """
        sql = """
            UPDATE ai_replenishment_task
            SET actual_amount = %s,
                part_count = %s,
                base_ratio = %s,
                status = %s,
                error_message = %s,
                llm_provider = %s,
                llm_model = %s,
                llm_total_tokens = %s,
                end_time = %s
            WHERE task_no = %s
        """

        values = (
            float(task.actual_amount),
            task.part_count,
            self._optional_float(task.base_ratio),
            int(task.status),
            task.error_message,
            task.llm_provider,
            task.llm_model,
            task.llm_total_tokens,
            datetime.now() if task.end_time is None else task.end_time,
            task.task_no,
        )

        rowcount, _ = self._execute_write(sql, values)
        logger.info("更新任务记录: task_no=%s, rows=%s", task.task_no, rowcount)
        return rowcount

    def save_details(self, details: List[ReplenishmentDetail]) -> int:
        """Batch-insert per-shop replenishment details into ``ai_replenishment_detail``.

        Returns:
            The number of rows inserted; 0 (without touching the database)
            when ``details`` is empty.
        """
        if not details:
            return 0

        sql = """
            INSERT INTO ai_replenishment_detail (
                task_no, group_id, dealer_grouping_id, brand_grouping_id,
                shop_id, shop_name, part_code, part_name, unit, cost_price,
                current_ratio, base_ratio, post_plan_ratio,
                valid_storage_cnt, avg_sales_cnt, suggest_cnt, suggest_amount,
                suggestion_reason, priority, llm_confidence, statistics_date, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """

        values = [
            (
                d.task_no, d.group_id, d.dealer_grouping_id, d.brand_grouping_id,
                d.shop_id, d.shop_name, d.part_code, d.part_name, d.unit,
                float(d.cost_price),
                self._optional_float(d.current_ratio),
                self._optional_float(d.base_ratio),
                self._optional_float(d.post_plan_ratio),
                float(d.valid_storage_cnt), float(d.avg_sales_cnt),
                d.suggest_cnt, float(d.suggest_amount),
                d.suggestion_reason, d.priority, d.llm_confidence, d.statistics_date,
            )
            for d in details
        ]

        rowcount, _ = self._execute_write(sql, values, many=True)
        logger.info("保存补货明细: %s条", rowcount)
        return rowcount

    def save_execution_log(self, log: TaskExecutionLog) -> int:
        """Insert one step-execution record into ``ai_task_execution_log``.

        Returns:
            The ID of the inserted row (``cursor.lastrowid``).
        """
        sql = """
            INSERT INTO ai_task_execution_log (
                task_no, group_id, dealer_grouping_id, brand_grouping_id,
                brand_grouping_name, dealer_grouping_name,
                step_name, step_order, status, input_data, output_data,
                error_message, retry_count, sql_query, llm_prompt, llm_response,
                llm_tokens, execution_time_ms, start_time, end_time, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """

        values = (
            log.task_no, log.group_id, log.dealer_grouping_id,
            log.brand_grouping_id, log.brand_grouping_name,
            log.dealer_grouping_name,
            log.step_name, log.step_order, int(log.status),
            log.input_data, log.output_data, log.error_message,
            log.retry_count, log.sql_query, log.llm_prompt, log.llm_response,
            log.llm_tokens, log.execution_time_ms,
            log.start_time, log.end_time,
        )

        _, log_id = self._execute_write(sql, values)
        return log_id

    def save_part_summaries(self, summaries: List[ReplenishmentPartSummary]) -> int:
        """Batch-insert part-level summaries into ``ai_replenishment_part_summary``.

        Returns:
            The number of rows inserted; 0 (without touching the database)
            when ``summaries`` is empty.
        """
        if not summaries:
            return 0

        sql = """
            INSERT INTO ai_replenishment_part_summary (
                task_no, group_id, dealer_grouping_id, part_code, part_name,
                unit, cost_price, total_storage_cnt, total_avg_sales_cnt,
                group_current_ratio, total_suggest_cnt, total_suggest_amount,
                shop_count, need_replenishment_shop_count, part_decision_reason,
                priority, llm_confidence, statistics_date, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """

        values = [
            (
                s.task_no, s.group_id, s.dealer_grouping_id, s.part_code, s.part_name,
                s.unit, float(s.cost_price), float(s.total_storage_cnt),
                float(s.total_avg_sales_cnt),
                self._optional_float(s.group_current_ratio),
                s.total_suggest_cnt, float(s.total_suggest_amount),
                s.shop_count, s.need_replenishment_shop_count, s.part_decision_reason,
                s.priority, s.llm_confidence, s.statistics_date,
            )
            for s in summaries
        ]

        rowcount, _ = self._execute_write(sql, values, many=True)
        logger.info("保存配件汇总: %s条", rowcount)
        return rowcount

    def delete_part_summaries_by_task(self, task_no: str) -> int:
        """Delete all part summaries belonging to ``task_no``.

        Returns:
            The number of rows deleted.
        """
        sql = "DELETE FROM ai_replenishment_part_summary WHERE task_no = %s"
        rowcount, _ = self._execute_write(sql, (task_no,))
        logger.info("删除配件汇总: task_no=%s, rows=%s", task_no, rowcount)
        return rowcount

    def delete_details_by_task(self, task_no: str) -> int:
        """Delete all replenishment details belonging to ``task_no``.

        Returns:
            The number of rows deleted.
        """
        sql = "DELETE FROM ai_replenishment_detail WHERE task_no = %s"
        rowcount, _ = self._execute_write(sql, (task_no,))
        logger.info("删除补货明细: task_no=%s, rows=%s", task_no, rowcount)
        return rowcount

    @staticmethod
    def _dump_json(value) -> Optional[str]:
        """Serialise a report section to JSON (non-ASCII kept); falsy values become NULL."""
        return json.dumps(value, ensure_ascii=False) if value else None

    def save_analysis_report(self, report: AnalysisReport) -> int:
        """Insert an analysis report (four JSON sections) into ``ai_analysis_report``.

        Each of ``inventory_overview``, ``sales_analysis``, ``inventory_health``
        and ``replenishment_summary`` is stored as a JSON string, or NULL when
        the section is empty/None.

        Returns:
            The ID of the inserted row (``cursor.lastrowid``).
        """
        sql = """
            INSERT INTO ai_analysis_report (
                task_no, group_id, dealer_grouping_id, dealer_grouping_name,
                brand_grouping_id, report_type,
                inventory_overview, sales_analysis,
                inventory_health, replenishment_summary,
                llm_provider, llm_model, llm_tokens, execution_time_ms,
                statistics_date, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s,
                %s, %s, %s, %s,
                %s, NOW()
            )
        """

        values = (
            report.task_no,
            report.group_id,
            report.dealer_grouping_id,
            report.dealer_grouping_name,
            report.brand_grouping_id,
            report.report_type,
            self._dump_json(report.inventory_overview),
            self._dump_json(report.sales_analysis),
            self._dump_json(report.inventory_health),
            self._dump_json(report.replenishment_summary),
            report.llm_provider,
            report.llm_model,
            report.llm_tokens,
            report.execution_time_ms,
            report.statistics_date,
        )

        _, report_id = self._execute_write(sql, values)
        logger.info("保存分析报告: task_no=%s, id=%s", report.task_no, report_id)
        return report_id