replenishment.py 4.79 KB
"""
补货建议触发接口
替代原定时任务,提供接口触发补货建议生成
"""

import logging
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from ...agent import ReplenishmentAgent
from ...services import DataService

logger = logging.getLogger(__name__)
router = APIRouter()


class InitRequest(BaseModel):
    """Request body for POST /replenishment/init."""
    # Group whose dealer groupings are (re)generated.
    group_id: int
    # Optional single dealer grouping to process. When omitted (None), every
    # dealer grouping under group_id is processed.
    dealer_grouping_id: Optional[int] = None


class InitResponse(BaseModel):
    """Response body for POST /replenishment/init."""
    # Overall success flag. The init endpoint always sets True here; its
    # failures are reported via HTTPException instead.
    success: bool
    # Human-readable summary of the run.
    message: str
    # Number of dealer groupings whose generation finished without error.
    task_count: int = 0


@router.post("/replenishment/init", response_model=InitResponse)
async def init_replenishment(req: InitRequest):
    """
    Trigger full replenishment-suggestion generation.

    - If ``dealer_grouping_id`` is given (truthy), only that dealer grouping
      is processed. Responds 404 if it is not among the groupings of
      ``group_id``.
    - Otherwise every dealer grouping under ``group_id`` is processed. A
      failure in one grouping is logged and does not stop the others.

    ``DataService`` and ``ReplenishmentAgent.run`` are synchronous (called
    without ``await``). The work therefore runs in FastAPI's threadpool, so
    the event loop keeps serving other requests while generation runs.

    Raises:
        HTTPException: 404 for an unknown dealer grouping, 500 for any other
            unexpected error.
    """
    # Local import, matching this file's existing lazy-import style.
    from fastapi.concurrency import run_in_threadpool

    try:
        return await run_in_threadpool(_run_init_replenishment, req)
    except HTTPException:
        # Already carries the intended status code (e.g. the 404).
        raise
    except Exception as e:
        logger.error(f"补货建议初始化失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


def _load_dealer_groupings(group_id: int):
    """Fetch the dealer groupings of ``group_id``, always closing the DataService."""
    data_service = DataService()
    try:
        return data_service.get_dealer_groupings(group_id)
    finally:
        data_service.close()


def _run_init_replenishment(req: InitRequest) -> InitResponse:
    """Blocking body of ``init_replenishment``; intended to run in a worker thread."""
    agent = ReplenishmentAgent()
    # Load the groupings up front so the DataService is closed before the
    # (potentially long) agent runs, instead of being held open throughout.
    groupings = _load_dealer_groupings(req.group_id)

    # Truthiness check kept deliberately: None and 0 both mean "all groupings".
    if req.dealer_grouping_id:
        grouping = next(
            (g for g in groupings if g["id"] == req.dealer_grouping_id),
            None,
        )
        if grouping is None:
            raise HTTPException(
                status_code=404,
                detail=f"未找到商家组合: {req.dealer_grouping_id}",
            )
        agent.run(
            group_id=req.group_id,
            dealer_grouping_id=grouping["id"],
            dealer_grouping_name=grouping["name"],
        )
        return InitResponse(
            success=True,
            message=f"商家组合 [{grouping['name']}] 补货建议生成完成",
            task_count=1,
        )

    task_count = 0
    for grouping in groupings:
        try:
            agent.run(
                group_id=req.group_id,
                dealer_grouping_id=grouping["id"],
                dealer_grouping_name=grouping["name"],
            )
        except Exception as e:
            # Best effort: one failing grouping must not abort the whole batch.
            logger.error(
                f"商家组合执行失败: {grouping['name']}, error={e}",
                exc_info=True,
            )
            continue
        task_count += 1

    return InitResponse(
        success=True,
        message=f"补货建议生成完成,共处理 {task_count}/{len(groupings)} 个商家组合",
        task_count=task_count,
    )


# ============ 单种配件补货接口 ============

class SinglePartRequest(BaseModel):
    """Request body for POST /replenishment/single-part."""
    # Code of the single part to generate a replenishment suggestion for.
    part_code: str
    # Group the part belongs to; forwarded unchanged to the queue.
    group_id: int
    # Dealer grouping to run under (required here, unlike InitRequest).
    dealer_grouping_id: int


class SinglePartResponse(BaseModel):
    """Response body for POST /replenishment/single-part."""
    # Submission success flag (failures are raised as HTTPException instead).
    success: bool
    # Human-readable confirmation message.
    message: str
    # Id returned by the single-part queue's submit(); None when not set.
    queue_id: Optional[int] = None


@router.post("/replenishment/single-part", response_model=SinglePartResponse)
async def submit_single_part(req: SinglePartRequest):
    """
    Enqueue a replenishment job for one part.

    The job is handed to the single-part queue and processed asynchronously;
    this call returns right away with the queue id. Poll
    /replenishment/single-part/status to follow the queue.

    Raises:
        HTTPException: 500 if the queue cannot be reached or rejects the job.
    """
    try:
        # Imported lazily so the queue module is only loaded when first used.
        from ...services.single_part_queue import get_single_part_queue

        job_kwargs = dict(
            part_code=req.part_code,
            group_id=req.group_id,
            dealer_grouping_id=req.dealer_grouping_id,
        )
        submitted_id = await get_single_part_queue().submit(**job_kwargs)

        confirmation = f"已提交单种配件补货任务: {req.part_code}"
        return SinglePartResponse(
            success=True,
            message=confirmation,
            queue_id=submitted_id,
        )
    except Exception as e:
        logger.error(f"单配件补货任务提交失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/replenishment/single-part/status")
async def get_single_part_status():
    """
    Return whatever the single-part replenishment queue reports as its status.

    Raises:
        HTTPException: 500 if the status cannot be obtained.
    """
    try:
        # Lazy import, consistent with submit_single_part.
        from ...services.single_part_queue import get_single_part_queue

        snapshot = get_single_part_queue().get_status()
    except Exception as e:
        logger.error(f"查询队列状态失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return snapshot