# single_part_queue.py
"""
单种配件补货持久化队列服务

基于 MySQL 表 + asyncio worker 实现持久化异步队列。
支持并行处理、失败重试、崩溃恢复。
"""

import asyncio
import logging
from datetime import datetime, timedelta
from typing import Optional

from .db import get_connection

logger = logging.getLogger(__name__)

# Queue task states
QUEUE_PENDING = 0
QUEUE_PROCESSING = 1
QUEUE_SUCCESS = 2
QUEUE_FAILED = 3

# Table-creation DDL
CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS ai_single_part_queue (
    id                 BIGINT AUTO_INCREMENT PRIMARY KEY,
    part_code          VARCHAR(64)  NOT NULL COMMENT 'part code',
    group_id           BIGINT       NOT NULL COMMENT 'group ID',
    dealer_grouping_id BIGINT       NOT NULL COMMENT 'dealer grouping ID',
    status             TINYINT      DEFAULT 0 COMMENT '0=pending 1=processing 2=success 3=failed',
    retry_count        INT          DEFAULT 0 COMMENT 'retries attempted',
    max_retries        INT          DEFAULT 3 COMMENT 'maximum retries',
    error_message      TEXT         COMMENT 'error message',
    task_no            VARCHAR(32)  COMMENT 'associated task number',
    request_time       DATETIME     NOT NULL COMMENT 'request time',
    start_time         DATETIME     COMMENT 'processing start time',
    end_time           DATETIME     COMMENT 'processing end time',
    next_retry_time    DATETIME     COMMENT 'next retry time',
    create_time        DATETIME     DEFAULT CURRENT_TIMESTAMP,
    update_time        DATETIME     DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
    INDEX idx_status (status, next_retry_time),
    INDEX idx_part (dealer_grouping_id, part_code)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='single-part replenishment queue';
"""


class SinglePartQueue:
    """MySQL 持久化 + asyncio worker 异步队列"""

    def __init__(self, worker_count: int = 3, poll_interval: float = 2.0):
        self._worker_count = worker_count
        self._poll_interval = poll_interval
        self._workers: list[asyncio.Task] = []
        self._running = False
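        # _event wakes idle workers as soon as submit() enqueues a task;
        # poll_interval then only bounds how late a scheduled retry is picked up.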
        self._event = asyncio.Event()

    def ensure_table(self):
        """确保队列表存在"""
        conn = get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute(CREATE_TABLE_SQL)
            conn.commit()
            logger.info("队列表 ai_single_part_queue 已就绪")
        finally:
            cursor.close()
            conn.close()

    async def start(self):
        """启动 worker"""
        self.ensure_table()

        # Recover interrupted tasks (status: processing -> pending)
        self._recover_interrupted()

        self._running = True
        for i in range(self._worker_count):
            task = asyncio.create_task(self._worker(i))
            self._workers.append(task)
        logger.info(f"单配件补货队列已启动: workers={self._worker_count}")

    async def stop(self):
        """停止 worker"""
        self._running = False
        self._event.set()
        for task in self._workers:
            task.cancel()
        if self._workers:
            await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("单配件补货队列已停止")

    async def submit(
        self,
        part_code: str,
        group_id: int,
        dealer_grouping_id: int,
    ) -> int:
        """
        提交任务到队列

        Returns:
            队列记录ID
        """
        conn = get_connection()
        cursor = conn.cursor()
        try:
            sql = """
                INSERT INTO ai_single_part_queue 
                (part_code, group_id, dealer_grouping_id, status, request_time)
                VALUES (%s, %s, %s, %s, %s)
            """
            now = datetime.now()
            cursor.execute(sql, (part_code, group_id, dealer_grouping_id, QUEUE_PENDING, now))
            conn.commit()
            queue_id = cursor.lastrowid
            logger.info(
                f"Task enqueued: id={queue_id}, part_code={part_code}, "
                f"dealer_grouping_id={dealer_grouping_id}"
            )
            # Notify workers that a new task is available
            self._event.set()
            return queue_id
        finally:
            cursor.close()
            conn.close()

    async def _worker(self, worker_id: int):
        """worker 循环:轮询 → 抢任务 → 处理"""
        logger.info(f"Worker-{worker_id} 已启动")
        while self._running:
            try:
                # Claim in a thread so the blocking MySQL call doesn't stall the event loop
                item = await asyncio.get_running_loop().run_in_executor(
                    None, self._claim_task
                )
                if item is None:
                    # No task available; wait for a notification or time out
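                    # A submit() landing between the claim above and this clear()
                    # can be missed, but the timeout below bounds the extra delay
                    # to at most poll_interval.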
                    self._event.clear()
                    try:
                        await asyncio.wait_for(
                            self._event.wait(), timeout=self._poll_interval
                        )
                    except asyncio.TimeoutError:
                        pass
                    continue

                logger.info(
                    f"Worker-{worker_id} processing: id={item['id']}, "
                    f"part_code={item['part_code']}"
                )
                await self._process(item, worker_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Worker-{worker_id} 异常: {e}", exc_info=True)
                await asyncio.sleep(1)

        logger.info(f"Worker-{worker_id} 已停止")

    def _claim_task(self) -> Optional[dict]:
        """原子抢任务:SELECT FOR UPDATE + 立即更新状态"""
        conn = get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            conn.start_transaction()

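            # FOR UPDATE SKIP LOCKED (MySQL 8.0+) lets concurrent workers skip
            # rows already locked by another claimer instead of blocking on them.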
            sql = """
                SELECT id, part_code, group_id, dealer_grouping_id, retry_count, max_retries
                FROM ai_single_part_queue
                WHERE (status = %s OR (status = %s AND retry_count < max_retries AND next_retry_time <= NOW()))
                ORDER BY id ASC
                LIMIT 1
                FOR UPDATE SKIP LOCKED
            """
            cursor.execute(sql, (QUEUE_PENDING, QUEUE_FAILED))
            row = cursor.fetchone()

            if row is None:
                conn.commit()
                return None

            # Mark as processing
            update_sql = """
                UPDATE ai_single_part_queue 
                SET status = %s, start_time = NOW()
                WHERE id = %s
            """
            cursor.execute(update_sql, (QUEUE_PROCESSING, row["id"]))
            conn.commit()

            return row

        except Exception as e:
            conn.rollback()
            logger.error(f"抢任务失败: {e}")
            return None
        finally:
            cursor.close()
            conn.close()

    async def _process(self, item: dict, worker_id: int):
        """处理单个任务"""
        queue_id = item["id"]
        try:
            # Run the synchronous replenishment logic in a thread pool
            loop = asyncio.get_running_loop()
            task_no = await loop.run_in_executor(
                None, self._execute_replenishment, item
            )
            self._complete_task(queue_id, task_no)
            logger.info(
                f"Worker-{worker_id} finished: id={queue_id}, task_no={task_no}"
            )
        except Exception as e:
            retry_count = item.get("retry_count", 0) + 1
            max_retries = item.get("max_retries", 3)
            self._fail_task(queue_id, str(e), retry_count, max_retries)
            logger.error(
                f"Worker-{worker_id} failed: id={queue_id}, "
                f"retry={retry_count}/{max_retries}, error={e}"
            )

    def _execute_replenishment(self, item: dict) -> str:
        """执行补货逻辑(同步,在线程池中调用)"""
        from ..agent import ReplenishmentAgent

        agent = ReplenishmentAgent()
        try:
            final_state = agent.run_single_part(
                group_id=item["group_id"],
                dealer_grouping_id=item["dealer_grouping_id"],
                part_code=item["part_code"],
            )
            return final_state.get("task_no", "")
        finally:
            agent._result_writer.close()

    def _complete_task(self, queue_id: int, task_no: str):
        """标记任务为成功"""
        conn = get_connection()
        cursor = conn.cursor()
        try:
            sql = """
                UPDATE ai_single_part_queue 
                SET status = %s, task_no = %s, end_time = NOW()
                WHERE id = %s
            """
            cursor.execute(sql, (QUEUE_SUCCESS, task_no, queue_id))
            conn.commit()
        finally:
            cursor.close()
            conn.close()

    def _fail_task(self, queue_id: int, error: str, retry_count: int, max_retries: int):
        """标记任务失败,如可重试则设置下次重试时间"""
        conn = get_connection()
        cursor = conn.cursor()
        try:
            if retry_count < max_retries:
                # Exponential backoff: 2**retry_count seconds
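                # e.g. retry 1 -> 2s, retry 2 -> 4s, retry 3 -> 8s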
                delay = 2 ** retry_count
                next_retry = datetime.now() + timedelta(seconds=delay)
                sql = """
                    UPDATE ai_single_part_queue 
                    SET status = %s, error_message = %s, retry_count = %s, 
                        next_retry_time = %s, end_time = NOW()
                    WHERE id = %s
                """
                cursor.execute(sql, (QUEUE_FAILED, error, retry_count, next_retry, queue_id))
                logger.info(f"任务将在 {delay}s 后重试: id={queue_id}, retry={retry_count}")
            else:
                sql = """
                    UPDATE ai_single_part_queue 
                    SET status = %s, error_message = %s, retry_count = %s, end_time = NOW()
                    WHERE id = %s
                """
                cursor.execute(sql, (QUEUE_FAILED, error, retry_count, queue_id))
                logger.warning(f"Task reached maximum retries: id={queue_id}")
            conn.commit()
        finally:
            cursor.close()
            conn.close()

    def _recover_interrupted(self):
        """恢复被中断的任务(处理中 → 待处理)"""
        conn = get_connection()
        cursor = conn.cursor()
        try:
            sql = """
                UPDATE ai_single_part_queue 
                SET status = %s, start_time = NULL
                WHERE status = %s
            """
            cursor.execute(sql, (QUEUE_PENDING, QUEUE_PROCESSING))
            conn.commit()
            if cursor.rowcount > 0:
                logger.info(f"恢复中断任务: {cursor.rowcount} 条")
        finally:
            cursor.close()
            conn.close()

    def get_status(self) -> dict:
        """获取队列状态"""
        conn = get_connection()
        cursor = conn.cursor(dictionary=True)
        try:
            sql = """
                SELECT 
                    status,
                    COUNT(*) as count
                FROM ai_single_part_queue
                GROUP BY status
            """
            cursor.execute(sql)
            rows = cursor.fetchall()

            status_map = {
                QUEUE_PENDING: "pending",
                QUEUE_PROCESSING: "processing",
                QUEUE_SUCCESS: "success",
                QUEUE_FAILED: "failed",
            }

            result = {
                "workers": self._worker_count,
                "running": self._running,
                "pending": 0,
                "processing": 0,
                "success": 0,
                "failed": 0,
            }
            for row in rows:
                key = status_map.get(row["status"], "unknown")
                result[key] = row["count"]

            return result
        finally:
            cursor.close()
            conn.close()


# Global singleton
_queue_instance: Optional[SinglePartQueue] = None


def get_single_part_queue(worker_count: int = 3) -> SinglePartQueue:
    """获取全局队列实例"""
    global _queue_instance
    if _queue_instance is None:
        _queue_instance = SinglePartQueue(worker_count=worker_count)
    return _queue_instance
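

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service: assumes a reachable MySQL
    # configured in .db.get_connection, and the part code / IDs below are
    # hypothetical placeholders. Run as a module (python -m ...) because of the
    # relative imports above.
    async def _demo():
        queue = get_single_part_queue(worker_count=2)
        await queue.start()
        await queue.submit(part_code="P-0001", group_id=1, dealer_grouping_id=1)
        await asyncio.sleep(5)  # give the workers time to claim and process
        print(queue.get_status())
        await queue.stop()

    asyncio.run(_demo())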