detail_repo.py 3.65 KB
"""
补货明细数据访问层

提供 ai_replenishment_detail 表的 CRUD 操作
"""

import logging
from typing import List

from ..db import get_connection
from ...models import ReplenishmentDetail

logger = logging.getLogger(__name__)


class DetailRepository:
    """Data access for the ``ai_replenishment_detail`` table (replenishment details).

    A connection may be injected through the constructor; otherwise one is
    obtained lazily from ``get_connection()`` and re-obtained whenever the held
    connection reports that it is no longer connected. Write methods commit on
    success and roll back (then re-raise) on failure.
    """

    def __init__(self, connection=None):
        # None means "open a connection lazily on first use".
        self._conn = connection

    def _get_connection(self):
        """Return a live connection, (re)opening one if absent or disconnected."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the held connection (if still connected) and drop the reference."""
        if self._conn is not None:
            if self._conn.is_connected():
                self._conn.close()
            # Also drop an already-disconnected handle so the repository does
            # not keep referencing a dead connection.
            self._conn = None

    @staticmethod
    def _float_or_none(value):
        """Convert *value* to ``float``, mapping only ``None`` to ``None``.

        An explicit ``is None`` check is used (rather than truthiness) so that
        a legitimate zero such as ``0`` or ``Decimal("0")`` is stored as
        ``0.0`` instead of NULL.
        """
        return None if value is None else float(value)

    def save_batch(self, details: List[ReplenishmentDetail]) -> int:
        """
        Insert a batch of replenishment detail rows.

        ``create_time`` is filled by the database via ``NOW()``. The ratio
        columns (current / base / post-plan) are written as NULL only when the
        model attribute is ``None``; a zero ratio is stored as 0.

        Args:
            details: Detail records to insert; an empty list is a no-op.

        Returns:
            Number of inserted rows, as reported by ``cursor.rowcount``.

        Raises:
            Any database error from ``executemany``/``commit`` is re-raised
            after the transaction has been rolled back.
        """
        if not details:
            return 0

        sql = """
            INSERT INTO ai_replenishment_detail (
                task_no, group_id, dealer_grouping_id, brand_grouping_id,
                shop_id, shop_name, part_code, part_name, unit, cost_price,
                current_ratio, base_ratio, post_plan_ratio,
                valid_storage_cnt, avg_sales_cnt, suggest_cnt, suggest_amount,
                suggestion_reason, priority, llm_confidence, statistics_date, create_time
            ) VALUES (
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
            )
        """

        to_float = self._float_or_none
        # Build the parameters before touching the connection, so a conversion
        # error (e.g. a None cost_price) never opens a cursor or transaction.
        values = [
            (
                d.task_no, d.group_id, d.dealer_grouping_id, d.brand_grouping_id,
                d.shop_id, d.shop_name, d.part_code, d.part_name, d.unit,
                float(d.cost_price),
                to_float(d.current_ratio),
                to_float(d.base_ratio),
                to_float(d.post_plan_ratio),
                float(d.valid_storage_cnt), float(d.avg_sales_cnt),
                d.suggest_cnt, float(d.suggest_amount),
                d.suggestion_reason, d.priority, d.llm_confidence, d.statistics_date,
            )
            for d in details
        ]

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.executemany(sql, values)
            conn.commit()
            inserted = cursor.rowcount
        except Exception:
            # Don't leave a failed batch as a pending transaction on a
            # connection this repository keeps reusing.
            conn.rollback()
            raise
        finally:
            cursor.close()

        logger.info("保存补货明细: %s条", inserted)
        return inserted

    def delete_by_task_no(self, task_no: str) -> int:
        """
        Delete all replenishment detail rows belonging to a task.

        Args:
            task_no: Task number whose detail rows are removed.

        Returns:
            Number of deleted rows, as reported by ``cursor.rowcount``.

        Raises:
            Any database error from ``execute``/``commit`` is re-raised after
            the transaction has been rolled back.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            sql = "DELETE FROM ai_replenishment_detail WHERE task_no = %s"
            cursor.execute(sql, (task_no,))
            conn.commit()
            deleted = cursor.rowcount
        except Exception:
            # Same policy as save_batch: never leave a pending transaction.
            conn.rollback()
            raise
        finally:
            cursor.close()

        logger.info("删除补货明细: task_no=%s, rows=%s", task_no, deleted)
        return deleted

    def find_by_task_no(self, task_no: str) -> List[ReplenishmentDetail]:
        """
        Fetch all replenishment detail rows for a task.

        Rows are read with a dictionary cursor and passed to
        ``ReplenishmentDetail`` as keyword arguments.

        NOTE(review): ``SELECT *`` passes every table column as a keyword, so
        every column name must be accepted by ``ReplenishmentDetail`` — confirm
        this against the model definition.

        Args:
            task_no: Task number to look up.

        Returns:
            One ``ReplenishmentDetail`` per matching row (empty list if none).
        """
        conn = self._get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            sql = "SELECT * FROM ai_replenishment_detail WHERE task_no = %s"
            cursor.execute(sql, (task_no,))
            rows = cursor.fetchall()
            return [ReplenishmentDetail(**row) for row in rows]
        finally:
            cursor.close()