"""
任务数据访问层

提供 ai_replenishment_task 表的 CRUD 操作
"""

import logging
from datetime import datetime
from typing import Optional

from ..db import get_connection
from ...models import ReplenishmentTask

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class TaskRepository:
    """Data access object for the ``ai_replenishment_task`` table.

    Holds at most one database connection. It is obtained lazily from
    ``get_connection()`` and replaced whenever the held one reports
    ``is_connected() == False``. A connection may also be injected through
    the constructor.

    The repository can be used as a context manager; the connection is
    closed on exit::

        with TaskRepository() as repo:
            task = repo.find_by_task_no(task_no)
    """

    def __init__(self, connection=None):
        """
        Args:
            connection: Optional pre-opened connection. If ``None`` (or if it
                is later found disconnected), a new one is opened via
                ``get_connection()`` on first use.
        """
        self._conn = connection

    def __enter__(self) -> "TaskRepository":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _get_connection(self):
        """Return a live connection, (re)opening one if necessary."""
        if self._conn is None or not self._conn.is_connected():
            self._conn = get_connection()
        return self._conn

    def close(self) -> None:
        """Close the held connection if it is open, and drop the reference.

        Safe to call repeatedly.
        """
        if self._conn is not None and self._conn.is_connected():
            self._conn.close()
        # Drop the reference even when it was already disconnected so the
        # next _get_connection() call starts from a clean slate.
        self._conn = None

    @staticmethod
    def _optional_float(value):
        """Convert ``value`` to ``float``, mapping only ``None`` to ``None``.

        A truthiness check would turn a legitimate zero (``0`` or
        ``Decimal("0")``) into NULL; this keeps it as ``0.0``.
        """
        return None if value is None else float(value)

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Roll back the current transaction, logging instead of raising on failure.

        Used on error paths so a secondary failure (e.g. a dropped
        connection) does not mask the original exception.
        """
        try:
            conn.rollback()
        except Exception:
            logger.warning("回滚事务失败", exc_info=True)

    def create(self, task: ReplenishmentTask) -> int:
        """Insert ``task`` as a new row and commit.

        ``start_time`` defaults to the current local time when the task has
        none; ``create_time`` is filled in by the database with ``NOW()``.

        Returns:
            The id of the inserted row, as reported by ``cursor.lastrowid``.

        Raises:
            Whatever the driver raises from execute/commit (or a conversion
            error from the task's fields); the transaction is rolled back
            before the exception propagates.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        try:
            sql = """
                INSERT INTO ai_replenishment_task (
                    task_no, group_id, dealer_grouping_id, dealer_grouping_name,
                    brand_grouping_id, plan_amount, actual_amount, part_count,
                    base_ratio, status, error_message, llm_provider, llm_model,
                    llm_total_tokens, statistics_date, start_time, end_time, create_time
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW()
                )
            """

            values = (
                task.task_no,
                task.group_id,
                task.dealer_grouping_id,
                task.dealer_grouping_name,
                task.brand_grouping_id,
                float(task.plan_amount),
                float(task.actual_amount),
                task.part_count,
                self._optional_float(task.base_ratio),
                int(task.status),
                task.error_message,
                task.llm_provider,
                task.llm_model,
                task.llm_total_tokens,
                task.statistics_date,
                datetime.now() if task.start_time is None else task.start_time,
                task.end_time,
            )

            cursor.execute(sql, values)
            conn.commit()
            # Read before the cursor is closed in `finally`.
            task_id = cursor.lastrowid
        except Exception:
            # Do not leave a failed transaction open on a possibly shared
            # connection.
            self._rollback_quietly(conn)
            raise
        finally:
            cursor.close()

        logger.info("创建任务记录: task_no=%s, id=%s", task.task_no, task_id)
        return task_id

    def update(self, task: ReplenishmentTask) -> int:
        """Update the result fields of the row whose ``task_no`` matches, then commit.

        Only these columns are written: actual_amount, part_count,
        base_ratio, status, error_message, llm_provider, llm_model,
        llm_total_tokens and end_time. ``end_time`` defaults to the current
        local time when the task has none.

        Returns:
            ``cursor.rowcount`` for the UPDATE, as reported by the driver.

        Raises:
            Whatever the driver raises from execute/commit (or a conversion
            error from the task's fields); the transaction is rolled back
            before the exception propagates.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        try:
            sql = """
                UPDATE ai_replenishment_task
                SET actual_amount = %s,
                    part_count = %s,
                    base_ratio = %s,
                    status = %s,
                    error_message = %s,
                    llm_provider = %s,
                    llm_model = %s,
                    llm_total_tokens = %s,
                    end_time = %s
                WHERE task_no = %s
            """

            values = (
                float(task.actual_amount),
                task.part_count,
                self._optional_float(task.base_ratio),
                int(task.status),
                task.error_message,
                task.llm_provider,
                task.llm_model,
                task.llm_total_tokens,
                datetime.now() if task.end_time is None else task.end_time,
                task.task_no,
            )

            cursor.execute(sql, values)
            conn.commit()
            # Read before the cursor is closed in `finally`.
            row_count = cursor.rowcount
        except Exception:
            # Do not leave a failed transaction open on a possibly shared
            # connection.
            self._rollback_quietly(conn)
            raise
        finally:
            cursor.close()

        logger.info("更新任务记录: task_no=%s, rows=%s", task.task_no, row_count)
        return row_count

    def find_by_task_no(self, task_no: str) -> Optional[ReplenishmentTask]:
        """Return the task with the given ``task_no``, or ``None`` if there is none.

        NOTE(review): the row comes from ``SELECT *`` and is passed as keyword
        arguments to ``ReplenishmentTask``. This assumes the model accepts every
        column of the table (e.g. ``id``, ``create_time``); confirm this against
        the model definition.
        NOTE(review): this read is not followed by a commit. If the connection
        from ``get_connection()`` is not in autocommit mode, repeated reads on
        a long-lived connection may see an old snapshot. Verify the
        connection settings.
        """
        conn = self._get_connection()
        # dictionary=True yields rows as {column_name: value} mappings.
        cursor = conn.cursor(dictionary=True)

        try:
            cursor.execute(
                "SELECT * FROM ai_replenishment_task WHERE task_no = %s",
                (task_no,),
            )
            row = cursor.fetchone()
        finally:
            cursor.close()

        return ReplenishmentTask(**row) if row else None
