# analysis_report_node.py
"""
分析报告生成节点

在补货建议工作流的最后一个节点执行,生成结构化分析报告
"""

import logging
import time
import json
import os
from typing import Dict, Any
from decimal import Decimal
from datetime import datetime

from langchain_core.messages import HumanMessage

from ..llm import get_llm_client
from ..models import AnalysisReport
from ..services.result_writer import ResultWriter

logger = logging.getLogger(__name__)


def _load_prompt(filename: str) -> str:
    """Read a prompt template from the project-level ``prompts`` directory.

    The directory is located by applying ``os.path.dirname`` four times to this
    module's file path and then appending ``prompts``.

    Args:
        filename: Name of the prompt file inside the ``prompts`` directory.

    Returns:
        The complete file contents, decoded as UTF-8.

    Raises:
        FileNotFoundError: if nothing exists at the resolved path.
    """
    project_root = __file__
    for _ in range(4):
        project_root = os.path.dirname(project_root)
    prompt_path = os.path.join(project_root, "prompts", filename)

    if os.path.exists(prompt_path):
        with open(prompt_path, "r", encoding="utf-8") as handle:
            return handle.read()
    raise FileNotFoundError(f"Prompt文件未找到: {prompt_path}")


def _calculate_risk_stats(part_ratios: list) -> dict:
    """Classify parts into risk buckets and total their counts and amounts.

    Each row is read with ``dict.get``; missing or falsy numeric fields count
    as 0. A row lands in at most one bucket:

    * stagnant - stock on hand (``valid_storage_cnt > 0``) with
      ``out_stock_cnt == 0``; amount is stock x ``cost_price``.
    * low-frequency - no stock and ``avg_sales_cnt < 1``; no amount is
      accumulated, so ``low_freq_amount`` is always 0.
    * shortage - no stock and ``avg_sales_cnt >= 1``; amount is estimated as
      ``avg_sales_cnt`` x ``cost_price``.

    Rows matching none of these (e.g. stock on hand that has moved) are ignored.

    Returns:
        Dict with ``shortage_cnt`` / ``stagnant_cnt`` / ``low_freq_cnt`` (int)
        and the matching ``*_amount`` totals (``Decimal``).
    """
    def _as_decimal(row: dict, field: str) -> Decimal:
        # str() first so floats convert to their short repr, not binary noise.
        return Decimal(str(row.get(field, 0) or 0))

    shortage_cnt = stagnant_cnt = low_freq_cnt = 0
    shortage_amount = Decimal("0")
    stagnant_amount = Decimal("0")

    for row in part_ratios:
        on_hand = _as_decimal(row, "valid_storage_cnt")
        monthly_sales = _as_decimal(row, "avg_sales_cnt")
        outbound = _as_decimal(row, "out_stock_cnt")
        unit_cost = _as_decimal(row, "cost_price")

        if on_hand > 0 and outbound == 0:
            # Stagnant: stock is sitting with no outbound movement (described
            # upstream as "no outbound in 90 days" - window set by the data source).
            stagnant_cnt += 1
            stagnant_amount += on_hand * unit_cost
            continue
        if on_hand != 0:
            # Stock that has moved (or negative stock): not a risk bucket here.
            continue
        if monthly_sales < 1:
            low_freq_cnt += 1
        else:
            # Estimated shortage loss: monthly sales x cost price.
            shortage_cnt += 1
            shortage_amount += monthly_sales * unit_cost

    return {
        "shortage_cnt": shortage_cnt,
        "shortage_amount": shortage_amount,
        "stagnant_cnt": stagnant_cnt,
        "stagnant_amount": stagnant_amount,
        "low_freq_cnt": low_freq_cnt,
        "low_freq_amount": Decimal("0"),
    }


def _summary_value(item: Any, key: str, default: Any = None) -> Any:
    """Read ``key`` from a dict row, or the attribute of that name from a model object."""
    if isinstance(item, dict):
        return item.get(key, default)
    return getattr(item, key, default)


def _build_suggestion_summary(part_results: list, allocated_details: list, top_n: int = 10) -> str:
    """Build the replenishment-suggestion summary text that is fed into the report prompt.

    Args:
        part_results: Part-level aggregates, either model objects exposing
            ``part_code`` / ``part_name`` / ``total_suggest_cnt`` /
            ``total_suggest_amount`` / ``priority`` or dicts with those keys.
            Preferred source whenever it is non-empty.
        allocated_details: Detail rows exposing ``suggest_cnt`` /
            ``suggest_amount`` (object or dict). Used only to compute the
            totals when ``part_results`` is empty.
        top_n: Maximum number of part lines listed (default 10). The totals
            always cover every row, not just the listed ones.

    Returns:
        ``"暂无补货建议"`` when both inputs are empty. Otherwise a bold totals
        line (ending in a newline), followed by one ``- code name: ...`` line
        per listed part.
    """
    if not part_results and not allocated_details:
        return "暂无补货建议"

    lines = []
    total_cnt = 0
    total_amount = Decimal("0")

    if part_results:
        # Prefer part-level aggregates. List only the first ``top_n`` parts,
        # but total over all of them so the headline figure is not truncated.
        for pr in part_results:
            if not isinstance(pr, dict) and not hasattr(pr, "part_code"):
                continue  # unrecognised row shape; skipped, as before
            cnt = _summary_value(pr, "total_suggest_cnt", 0) or 0
            # Normalise via str() so float/str amounts add cleanly to a Decimal.
            amount = Decimal(str(_summary_value(pr, "total_suggest_amount", 0) or 0))
            total_cnt += cnt
            total_amount += amount
            if len(lines) >= top_n:
                continue
            line = (
                f"- {_summary_value(pr, 'part_code', '')} {_summary_value(pr, 'part_name', '')}: "
                f"建议{cnt}件, "
                f"金额{amount:.2f}元"
            )
            priority = _summary_value(pr, "priority")
            if priority is not None:
                line += f", 优先级{priority}"
            lines.append(line)
    else:
        # No part-level aggregates: derive the totals from the allocated details.
        for detail in allocated_details:
            total_cnt += _summary_value(detail, "suggest_cnt", 0) or 0
            total_amount += Decimal(str(_summary_value(detail, "suggest_amount", 0) or 0))

    lines.insert(0, f"**总计**: {total_cnt}件配件, 金额{total_amount:.2f}元\n")
    return "\n".join(lines)


def _parse_report_json(response_text: str) -> dict:
    """Parse the LLM reply into a JSON object, tolerating a Markdown code fence.

    If the reply starts with a fence, the opening fence line (with or without a
    language tag such as ``json``) is dropped. Everything from the last closing
    fence onwards is also discarded, so trailing chatter after the fence does
    not break parsing.

    Raises:
        json.JSONDecodeError: if the remaining text is not valid JSON.
        ValueError: if the JSON is valid but is not an object (dict).
    """
    text = response_text.strip()
    if text.startswith("```"):
        _, _, text = text.partition("\n")
        body, fence, _ = text.rpartition("```")
        if fence:
            text = body
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(f"LLM返回的报告不是JSON对象: {type(data).__name__}")
    return data


def _sum_allocated_details(allocated_details: list) -> tuple:
    """Return ``(total_suggest_cnt, total_suggest_amount)`` over allocated detail rows.

    Rows may be model objects (attribute access) or dicts (key access). Amounts
    are normalised to ``Decimal`` so float/str values do not break the sum, and
    the amount total is a ``Decimal`` even when there are no rows.
    """
    total_cnt = 0
    total_amount = Decimal("0")
    for d in allocated_details:
        cnt = d.suggest_cnt if hasattr(d, "suggest_cnt") else d.get("suggest_cnt", 0)
        amount = d.suggest_amount if hasattr(d, "suggest_amount") else d.get("suggest_amount", 0)
        total_cnt += cnt or 0
        total_amount += Decimal(str(amount or 0))
    return total_cnt, total_amount


def generate_analysis_report_node(state: dict) -> dict:
    """Generate the replenishment analysis report (the workflow's final node).

    Reads from ``state``: task_no, group_id, dealer_grouping_id,
    dealer_grouping_name, brand_grouping_id, statistics_date, part_ratios,
    part_results and allocated_details.

    Steps:
        1. Compute risk statistics and build the suggestion summary.
        2. Fill the ``analysis_report.md`` prompt and call the LLM.
        3. Parse the JSON reply into an ``AnalysisReport``.
        4. Persist the report via ``ResultWriter``.

    Returns:
        ``{"analysis_report": <report dict>, "end_time": <epoch seconds>}``.
        On any failure inside the pipeline the error is logged and
        ``analysis_report`` becomes ``{"error": str(e), "task_no": task_no}``.
        This is a deliberate best-effort so a failed report does not abort the
        whole workflow.
    """
    # Monotonic clock for the duration; ``end_time`` below stays wall-clock.
    start = time.monotonic()

    task_no = state.get("task_no", "")
    group_id = state.get("group_id", 0)
    dealer_grouping_id = state.get("dealer_grouping_id", 0)
    dealer_grouping_name = state.get("dealer_grouping_name", "")
    brand_grouping_id = state.get("brand_grouping_id")
    statistics_date = state.get("statistics_date", "")

    # ``or []`` also covers keys that are present but explicitly None.
    part_ratios = state.get("part_ratios") or []
    part_results = state.get("part_results") or []
    allocated_details = state.get("allocated_details") or []

    logger.info(f"[{task_no}] 开始生成分析报告: dealer={dealer_grouping_name}")

    try:
        risk_stats = _calculate_risk_stats(part_ratios)
        suggestion_summary = _build_suggestion_summary(part_results, allocated_details)

        # NOTE(review): str.format treats every single brace in the template as
        # a placeholder. Any literal JSON example in analysis_report.md must be
        # escaped as {{ }} - confirm against the prompt file.
        prompt_template = _load_prompt("analysis_report.md")
        prompt = prompt_template.format(
            dealer_grouping_id=dealer_grouping_id,
            dealer_grouping_name=dealer_grouping_name,
            statistics_date=statistics_date,
            suggestion_summary=suggestion_summary,
            shortage_cnt=risk_stats["shortage_cnt"],
            shortage_amount=f"{risk_stats['shortage_amount']:.2f}",
            stagnant_cnt=risk_stats["stagnant_cnt"],
            stagnant_amount=f"{risk_stats['stagnant_amount']:.2f}",
            low_freq_cnt=risk_stats["low_freq_cnt"],
            low_freq_amount="0.00",  # low-frequency parts hold no stock, so no amount
        )

        llm_client = get_llm_client()
        response = llm_client.invoke(
            messages=[HumanMessage(content=prompt)],
        )
        report_data = _parse_report_json(response.content)

        # Token usage is metadata only: a reply without usage info should not
        # discard an otherwise valid report.
        usage = getattr(response, "usage", None)
        llm_tokens = getattr(usage, "total_tokens", 0) or 0

        total_suggest_cnt, total_suggest_amount = _sum_allocated_details(allocated_details)

        execution_time_ms = int((time.monotonic() - start) * 1000)

        # The prompt's output fields are mapped onto the existing DB columns:
        #   overall_assessment   -> replenishment_insights
        #   risk_alerts          -> urgency_assessment
        #   procurement_strategy -> strategy_recommendations
        #   expected_impact      -> expected_outcomes
        #   execution_guide was dropped from the prompt and is stored as None.
        report = AnalysisReport(
            task_no=task_no,
            group_id=group_id,
            dealer_grouping_id=dealer_grouping_id,
            dealer_grouping_name=dealer_grouping_name,
            brand_grouping_id=brand_grouping_id,
            report_type="replenishment",
            replenishment_insights=report_data.get("overall_assessment"),
            urgency_assessment=report_data.get("risk_alerts"),
            strategy_recommendations=report_data.get("procurement_strategy"),
            execution_guide=None,
            expected_outcomes=report_data.get("expected_impact"),
            total_suggest_cnt=total_suggest_cnt,
            total_suggest_amount=total_suggest_amount,
            shortage_risk_cnt=risk_stats["shortage_cnt"],
            # NOTE(review): excess risk currently reuses the stagnant count - confirm intended.
            excess_risk_cnt=risk_stats["stagnant_cnt"],
            stagnant_cnt=risk_stats["stagnant_cnt"],
            low_freq_cnt=risk_stats["low_freq_cnt"],
            llm_provider=getattr(llm_client, "provider", ""),
            llm_model=getattr(llm_client, "model", ""),
            llm_tokens=llm_tokens,
            execution_time_ms=execution_time_ms,
            statistics_date=statistics_date,
        )

        result_writer = ResultWriter()
        try:
            result_writer.save_analysis_report(report)
        finally:
            result_writer.close()

        logger.info(
            f"[{task_no}] 分析报告生成完成: "
            f"shortage={risk_stats['shortage_cnt']}, "
            f"stagnant={risk_stats['stagnant_cnt']}, "
            f"time={execution_time_ms}ms"
        )

        return {
            "analysis_report": report.to_dict(),
            "end_time": time.time(),
        }

    except Exception as e:
        # Top-level boundary of this node: log with traceback and return an
        # error placeholder instead of interrupting the whole workflow.
        logger.error(f"[{task_no}] 分析报告生成失败: {e}", exc_info=True)
        return {
            "analysis_report": {
                "error": str(e),
                "task_no": task_no,
            },
            "end_time": time.time(),
        }