# executor.py
"""
SQL 执行器模块

提供 SQL 执行、重试和错误分类功能
"""

import logging
import time
import json
from typing import Any, Dict, List, Optional, Tuple

from ..sql_agent.prompts import load_prompt, SQL_AGENT_SYSTEM_PROMPT
from ...llm import get_llm_client
from ...config import get_settings
from ...models import SQLExecutionResult

from langchain_core.messages import SystemMessage, HumanMessage

logger = logging.getLogger(__name__)


class SQLExecutor:
    """Generate SQL with an LLM, execute it, and retry on failure.

    Only statements whose text starts with ``SELECT`` are sent to the
    database. When an attempt fails, the error message is passed back to the
    LLM on the next attempt so that it can correct the statement.
    """

    # Lower-case substrings of an error message that indicate the statement
    # itself is wrong and has to be regenerated by the LLM. These are checked
    # BEFORE the operational keywords, so a message matching both lists is
    # classified as "syntax".
    _SYNTAX_KEYWORDS = (
        "syntax error", "you have an error in your sql syntax",
        "unknown column", "unknown table", "doesn't exist",
        "ambiguous column", "invalid", "near",
    )

    # Lower-case substrings of an error message that indicate a transient
    # operational problem, which is retried after an exponential backoff.
    _OPERATIONAL_KEYWORDS = (
        "timeout", "timed out", "connection", "deadlock",
        "lock wait", "too many connections", "gone away",
        "lost connection", "can't connect",
    )

    def __init__(self, db_connection=None, max_retries: int = 3, base_delay: float = 1.0):
        """Create an executor.

        Args:
            db_connection: Optional database connection to use. When omitted,
                a connection is obtained lazily on first use.
            max_retries: Maximum number of generate-and-execute attempts made
                by ``query_with_retry``.
            base_delay: Base delay in seconds for the exponential backoff;
                attempt ``i`` (zero-based) waits ``base_delay * 2 ** i``.
        """
        self._settings = get_settings()
        self._conn = db_connection
        self._llm = get_llm_client()
        self._max_retries = max_retries
        self._base_delay = base_delay

    def _get_connection(self):
        """Return the cached connection, creating it on first use."""
        if self._conn is None:
            # Imported at call time; only needed when no connection was injected.
            from ...services.db import get_connection
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the cached connection, if there is one that can be closed."""
        conn = self._conn
        if conn is not None and hasattr(conn, 'close'):
            try:
                conn.close()
            finally:
                # Drop the reference even when close() raises, so that a later
                # call does not reuse a half-closed connection.
                self._conn = None

    def generate_sql(
        self,
        question: str,
        context: Optional[Dict] = None,
        previous_error: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Ask the LLM to produce a SQL statement for ``question``.

        Args:
            question: Natural-language description of the query.
            context: Optional extra information, appended to the prompt as
                pretty-printed JSON.
            previous_error: Error message from the previous attempt, appended
                to the prompt so the LLM can correct the statement.

        Returns:
            ``(sql, explanation)``. ``sql`` is an empty string when no
            statement could be extracted from the LLM response.
        """
        prompt_parts = [question]

        if context:
            # default=str: the JSON is only prompt text, so values that json
            # cannot encode natively are stringified instead of raising.
            context_json = json.dumps(context, ensure_ascii=False, indent=2, default=str)
            prompt_parts.append(f"\n\n上下文信息:\n{context_json}")

        if previous_error:
            prompt_parts.append(f"\n\n上一次执行错误:\n{previous_error}\n请修正SQL语句。")

        messages = [
            SystemMessage(content=SQL_AGENT_SYSTEM_PROMPT),
            HumanMessage(content="".join(prompt_parts)),
        ]

        response = self._llm.invoke(messages)
        return self._parse_llm_response(response.content.strip())

    @staticmethod
    def _parse_llm_response(content: str) -> Tuple[str, str]:
        """Extract ``(sql, explanation)`` from the raw text of an LLM reply.

        The reply is expected to be a JSON object with ``sql`` and
        ``explanation`` keys, optionally wrapped in a Markdown code fence.
        When it is not a JSON object, the first line containing ``SELECT``
        (case-insensitive) is returned as the statement instead.
        """
        # Unwrap a Markdown code fence, preferring an explicit ```json fence.
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0].strip()
        elif "```" in content:
            content = content.split("```")[1].split("```")[0].strip()

        try:
            payload = json.loads(content)
        except json.JSONDecodeError:
            payload = None

        if isinstance(payload, dict):
            sql = payload.get("sql", "")
            explanation = payload.get("explanation", "")
            # Honour the declared (str, str) contract even when the model
            # emits null or a non-string value for either key.
            if not isinstance(sql, str):
                sql = ""
            if not isinstance(explanation, str):
                explanation = "" if explanation is None else str(explanation)
            return sql, explanation

        # Valid JSON that is not an object (list, string, number) is treated
        # the same way as unparseable text.
        logger.warning(f"无法解析LLM响应为JSON: {content[:200]}")
        for line in content.split("\n"):
            if "SELECT" in line.upper():
                return line.strip(), "直接提取的SQL"
        return "", "解析失败"

    @staticmethod
    def _check_sql(sql: str) -> Optional[Tuple[str, str]]:
        """Validate ``sql`` before execution.

        Returns:
            ``None`` when the statement may be executed, otherwise
            ``(error message, error type)``.
        """
        if not sql or not sql.strip():
            return "SQL语句为空", "unknown"

        # Safety check: only SELECT statements are allowed.
        # NOTE: this is a prefix check on the statement text only.
        if not sql.upper().strip().startswith("SELECT"):
            return "只允许执行SELECT查询", "syntax"

        return None

    @staticmethod
    def _elapsed_ms(started: float) -> int:
        """Return whole milliseconds elapsed since ``started`` (a perf_counter value)."""
        return int((time.perf_counter() - started) * 1000)

    def _fetch_rows(self, sql: str) -> Any:
        """Execute ``sql`` and return all rows.

        The connection and the cursor are acquired inside the guarded region,
        so the cursor is closed whenever it was created, and any exception
        (including one raised while acquiring the connection or cursor)
        propagates to the caller for reporting.
        """
        cursor = None
        try:
            cursor = self._get_connection().cursor(dictionary=True)
            started = time.perf_counter()
            cursor.execute(sql)
            rows = cursor.fetchall()
            execution_time = self._elapsed_ms(started)

            logger.info(f"SQL执行成功: 返回{len(rows)}行, 耗时{execution_time}ms")
            return rows
        finally:
            if cursor is not None:
                cursor.close()

    def execute_sql(self, sql: str) -> Tuple[bool, Any, Optional[str]]:
        """Execute a single SELECT statement.

        Returns:
            ``(success, rows or None, error message or None)``. Failures are
            reported through the return value rather than raised.
        """
        rejection = self._check_sql(sql)
        if rejection is not None:
            return False, None, rejection[0]

        try:
            rows = self._fetch_rows(sql)
        except Exception as exc:
            error_msg = str(exc)
            logger.error(f"SQL执行失败: {error_msg}")
            return False, None, error_msg

        return True, rows, None

    def query_with_retry(
        self,
        question: str,
        context: Optional[Dict] = None,
    ) -> "SQLExecutionResult":
        """Generate and execute SQL for ``question``, retrying on failure.

        Up to ``max_retries`` attempts are made. Between attempts:

        - a "syntax" error is retried immediately, so the LLM can correct
          the statement using the error message;
        - any other error, and a failure to generate SQL at all, is retried
          after an exponential backoff of ``base_delay * 2 ** attempt``
          seconds.

        Args:
            question: Natural-language description of the query.
            context: Optional extra information forwarded to ``generate_sql``.

        Returns:
            SQLExecutionResult describing the successful attempt, or the last
            failed one.
        """
        started = time.perf_counter()
        last_error = None
        last_sql = ""
        final_attempt = self._max_retries - 1

        for attempt in range(self._max_retries):
            # Generate SQL; the previous error (if any) is fed back to the LLM.
            sql, _explanation = self.generate_sql(question, context, last_error)
            last_sql = sql

            if not sql:
                last_error = "LLM未能生成有效的SQL语句"
                logger.warning(f"重试 {attempt + 1}/{self._max_retries}: {last_error}")

                if attempt < final_attempt:
                    time.sleep(self._base_delay * (2 ** attempt))
                continue

            logger.info(f"尝试 {attempt + 1}: 执行SQL: {sql[:100]}...")

            # Execute the generated statement.
            success, data, error, error_type = self._execute_sql_with_type(sql)

            if success:
                return SQLExecutionResult(
                    success=True,
                    sql=sql,
                    data=data,
                    retry_count=attempt,
                    execution_time_ms=self._elapsed_ms(started),
                )

            last_error = error
            logger.warning(f"重试 {attempt + 1}/{self._max_retries}: [{error_type}] {error}")

            if attempt < final_attempt:
                if error_type == "syntax":
                    # Syntax error: retry immediately without waiting.
                    logger.info("SQL语法错误,立即重试让LLM修正")
                else:
                    # Connection, deadlock and unclassified errors: back off.
                    delay = self._base_delay * (2 ** attempt)
                    logger.info(f"操作错误,等待 {delay}秒 后重试...")
                    time.sleep(delay)

        # NOTE: the failure path reports the configured attempt limit, whereas
        # the success path reports the zero-based index of the attempt.
        return SQLExecutionResult(
            success=False,
            sql=last_sql,
            error=last_error,
            retry_count=self._max_retries,
            execution_time_ms=self._elapsed_ms(started),
        )

    def _execute_sql_with_type(self, sql: str) -> Tuple[bool, Any, Optional[str], str]:
        """Execute a single SELECT statement and classify any failure.

        Returns:
            ``(success, rows or None, error message or None, error type)``.
            The error type is ``"syntax"``, ``"operational"`` or
            ``"unknown"`` on failure and an empty string on success.
        """
        rejection = self._check_sql(sql)
        if rejection is not None:
            message, error_type = rejection
            return False, None, message, error_type

        try:
            rows = self._fetch_rows(sql)
        except Exception as exc:
            error_msg = str(exc)
            error_type = self._classify_error(exc)
            logger.error(f"SQL执行失败 [{error_type}]: {error_msg}")
            return False, None, error_msg, error_type

        return True, rows, None, ""

    def _classify_error(self, error: Exception) -> str:
        """Classify an execution error to choose the retry strategy.

        The lower-cased message is matched against the syntax keywords first,
        then against the operational keywords; if neither matches, the
        exception class name decides.

        Returns:
            ``"syntax"`` when the LLM has to correct the statement,
            ``"operational"`` when the attempt should be retried after a
            backoff, or ``"unknown"`` otherwise.
        """
        message = str(error).lower()

        if any(keyword in message for keyword in self._SYNTAX_KEYWORDS):
            return "syntax"
        if any(keyword in message for keyword in self._OPERATIONAL_KEYWORDS):
            return "operational"

        # Fall back to the exception class name.
        class_name = type(error).__name__
        if "ProgrammingError" in class_name or "InterfaceError" in class_name:
            return "syntax"
        if "OperationalError" in class_name:
            return "operational"

        return "unknown"