agent.py 6.2 KB
"""
SQL Agent 主类模块

组合 Executor 和 Analyzer 提供完整的 SQL Agent 功能
"""

import logging
from typing import Any, Dict, List, Optional, Tuple
from decimal import Decimal

from .executor import SQLExecutor
from .analyzer import PartAnalyzer
from .prompts import load_prompt
from ...models import SQLExecutionResult, ReplenishmentSuggestion, PartAnalysisResult

logger = logging.getLogger(__name__)


class SQLAgent:
    """SQL Agent that combines SQL execution and spare-part analysis.

    Query-related calls are delegated to an internal ``SQLExecutor`` and
    analysis-related calls to an internal ``PartAnalyzer``. The agent can also
    be used as a context manager (``with SQLAgent(...) as agent:``), which
    calls :meth:`close` on exit.
    """

    # SQL expressions shared by the SELECT list and the ORDER BY of
    # fetch_part_ratios. Kept in one place so the returned columns and the
    # sort order can never drift apart.
    #
    # Sum of the sales-related quantity columns.
    _SALES_SUM_SQL = (
        "(out_stock_cnt + storage_locked_cnt + out_stock_ongoing_cnt + buy_cnt)"
    )
    # Aliased as avg_sales_cnt. NOTE(review): divided by 3, presumably a
    # 3-period (e.g. 3-month) window -- TODO confirm against the table docs.
    _AVG_SALES_SQL = f"({_SALES_SUM_SQL} / 3)"
    # Aliased as valid_storage_cnt: sum of the listed stock quantity columns.
    _VALID_STORAGE_SQL = (
        "(in_stock_unlocked_cnt + on_the_way_cnt + has_plan_cnt"
        " + transfer_cnt + gen_transfer_cnt)"
    )

    def __init__(self, db_connection=None):
        """Create the agent.

        Args:
            db_connection: Optional connection handed straight to
                ``SQLExecutor``; its handling is up to the executor.
        """
        self._executor = SQLExecutor(db_connection)
        self._analyzer = PartAnalyzer()

    def __enter__(self) -> "SQLAgent":
        """Return the agent itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the agent; exceptions from the ``with`` body are not suppressed."""
        self.close()

    def close(self):
        """Close the connection by delegating to the executor."""
        self._executor.close()

    # === Methods delegated to SQLExecutor ===

    def generate_sql(
        self,
        question: str,
        context: Optional[Dict] = None,
        previous_error: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Generate SQL with the LLM (delegates to ``SQLExecutor.generate_sql``)."""
        return self._executor.generate_sql(question, context, previous_error)

    def execute_sql(self, sql: str) -> Tuple[bool, Any, Optional[str]]:
        """Execute a SQL query (delegates to ``SQLExecutor.execute_sql``)."""
        return self._executor.execute_sql(sql)

    def query_with_retry(
        self,
        question: str,
        context: Optional[Dict] = None,
    ) -> SQLExecutionResult:
        """Run a query with retries (delegates to ``SQLExecutor.query_with_retry``)."""
        return self._executor.query_with_retry(question, context)

    # === Methods delegated to PartAnalyzer ===

    def group_parts_by_code(self, part_ratios: List[Dict]) -> Dict[str, List[Dict]]:
        """Group parts by part code (delegates to ``PartAnalyzer``)."""
        return self._analyzer.group_parts_by_code(part_ratios)

    def generate_suggestions(
        self,
        part_data: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
    ) -> Tuple[List[ReplenishmentSuggestion], Dict]:
        """Generate replenishment suggestions (delegates to ``PartAnalyzer``)."""
        return self._analyzer.generate_suggestions(
            part_data, dealer_grouping_id, dealer_grouping_name, statistics_date
        )

    def analyze_parts_by_group(
        self,
        part_ratios: List[Dict],
        dealer_grouping_id: int,
        dealer_grouping_name: str,
        statistics_date: str,
        target_ratio: Decimal = Decimal("1.3"),
        limit: Optional[int] = None,
        callback: Optional[Any] = None,
    ) -> Tuple[List[ReplenishmentSuggestion], List[PartAnalysisResult], Dict]:
        """Analyze replenishment suggestions per part group (delegates to ``PartAnalyzer``)."""
        return self._analyzer.analyze_parts_by_group(
            part_ratios,
            dealer_grouping_id,
            dealer_grouping_name,
            statistics_date,
            target_ratio,
            limit,
            callback,
        )

    # === Data query methods ===

    def fetch_part_ratios(
        self,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
    ) -> List[Dict]:
        """Query ``part_ratio`` rows for one dealer grouping on one date.

        If ``artificial_region_dealer`` lists brand-grouping ids for this
        dealer grouping, rows are restricted to those brand groupings. If that
        lookup fails, a warning is logged and the query runs without the brand
        filter.

        Rows come back ordered with positive-average-sales parts first, those
        by ascending ``valid_storage_cnt / avg_sales_cnt``, then by descending
        ``avg_sales_cnt``.

        Args:
            group_id: Group ID.
            dealer_grouping_id: Dealer grouping ID; also the ``id`` looked up
                in ``artificial_region_dealer``.
            statistics_date: Statistics date, passed to the query unchanged.

        Returns:
            The fetched rows, including the computed ``valid_storage_cnt`` and
            ``avg_sales_cnt`` columns.
        """
        # NOTE(review): the connection belongs to the executor and is not
        # closed here -- presumably the executor manages its lifetime.
        conn = self._executor._get_connection()
        cursor = conn.cursor(dictionary=True)

        try:
            brand_grouping_ids = self._fetch_brand_grouping_ids(
                cursor, dealer_grouping_id
            )
            sql, params = self._build_part_ratio_query(
                group_id, dealer_grouping_id, statistics_date, brand_grouping_ids
            )

            cursor.execute(sql, params)
            rows = cursor.fetchall()

            logger.info(
                "获取part_ratio数据: dealer_grouping_id=%s, count=%s",
                dealer_grouping_id,
                len(rows),
            )
            return rows

        finally:
            cursor.close()

    @staticmethod
    def _fetch_brand_grouping_ids(cursor: Any, dealer_grouping_id: int) -> List[Any]:
        """Return the brand-grouping ids configured for a dealer grouping.

        Falsy ids are skipped and duplicates dropped (first occurrence kept).
        Best-effort: on any error a warning is logged and ``[]`` is returned,
        so the caller applies no brand filter rather than a partial one.
        """
        try:
            cursor.execute(
                "SELECT part_purchase_brand_assemble_id FROM artificial_region_dealer WHERE id = %s",
                (dealer_grouping_id,),
            )
            candidates = [
                row.get("part_purchase_brand_assemble_id")
                for row in cursor.fetchall()
            ]
            # dict.fromkeys removes duplicates while preserving order.
            brand_grouping_ids = list(dict.fromkeys(bid for bid in candidates if bid))
        except Exception as e:  # deliberate best-effort: fall back to no filter
            logger.warning("查询商家组合配置失败: %s", e)
            return []

        if brand_grouping_ids:
            logger.info(
                "商家组合关联品牌组合: dealer_grouping_id=%s -> brand_grouping_ids=%s",
                dealer_grouping_id,
                brand_grouping_ids,
            )
        return brand_grouping_ids

    @classmethod
    def _build_part_ratio_query(
        cls,
        group_id: int,
        dealer_grouping_id: int,
        statistics_date: str,
        brand_grouping_ids: List[Any],
    ) -> Tuple[str, List[Any]]:
        """Build the parameterized ``part_ratio`` SELECT and its parameter list."""
        # NOTE(review): part_biz_type = 1 is a fixed filter whose meaning is
        # not visible in this module -- TODO document the business type.
        sql = f"""
            SELECT
                id, group_id, brand_id, brand_grouping_id,
                dealer_grouping_id,
                supplier_id, supplier_name, area_id, area_name,
                shop_id, shop_name, part_id, part_code, part_name,
                unit, cost_price,
                in_stock_unlocked_cnt, has_plan_cnt, on_the_way_cnt,
                out_stock_cnt, buy_cnt, storage_locked_cnt,
                out_stock_ongoing_cnt, stock_age, out_times, out_duration,
                transfer_cnt, gen_transfer_cnt,
                part_biz_type, statistics_date,
                {cls._VALID_STORAGE_SQL} AS valid_storage_cnt,
                {cls._AVG_SALES_SQL} AS avg_sales_cnt
            FROM part_ratio
            WHERE group_id = %s
              AND dealer_grouping_id = %s
              AND statistics_date = %s
              AND part_biz_type = 1
        """
        params: List[Any] = [group_id, dealer_grouping_id, statistics_date]

        if brand_grouping_ids:
            # One placeholder per id keeps the IN list fully parameterized.
            placeholders = ", ".join(["%s"] * len(brand_grouping_ids))
            sql += f" AND brand_grouping_id IN ({placeholders})"
            params.extend(brand_grouping_ids)

        # Parts with sales first; among them lowest stock-to-sales ratio first
        # (NULLIF avoids division by zero), then highest average sales.
        sql += f"""
            ORDER BY
                CASE WHEN {cls._AVG_SALES_SQL} > 0 THEN 0 ELSE 1 END,
                {cls._VALID_STORAGE_SQL} / NULLIF({cls._AVG_SALES_SQL}, 0) ASC,
                {cls._AVG_SALES_SQL} DESC
        """
        return sql, params