"""
SQL executor module.

Provides LLM-based SQL generation, SQL execution, retry with backoff and
error classification.
"""
import logging
import re
import time
import json
from typing import Any, Dict, List, Optional, Tuple

from ..sql_agent.prompts import load_prompt, SQL_AGENT_SYSTEM_PROMPT
from ...llm import get_llm_client
from ...config import get_settings
from ...models import SQLExecutionResult
from langchain_core.messages import SystemMessage, HumanMessage

logger = logging.getLogger(__name__)


class SQLExecutor:
    """SQL executor: generates SQL with an LLM, executes it, and retries on failure."""

    # Lower-case substrings of a DB error message that mean the SQL itself is
    # wrong, so the LLM should rewrite it (retried immediately).
    # Checked BEFORE the operational list so that e.g. "Unknown column
    # 'connection_id'" is not mistaken for a connection problem.
    _SYNTAX_KEYWORDS = (
        "syntax error",
        "you have an error in your sql syntax",
        "unknown column",
        "unknown table",
        "doesn't exist",
        "ambiguous column",
        "invalid",
        "near",
    )

    # Lower-case substrings that indicate an operational problem (timeouts,
    # connection trouble, lock contention) retried with exponential backoff.
    _OPERATIONAL_KEYWORDS = (
        "timeout",
        "timed out",
        "connection",
        "deadlock",
        "lock wait",
        "too many connections",
        "gone away",
        "lost connection",
        "can't connect",
    )

    def __init__(
        self,
        db_connection=None,
        *,
        max_retries: int = 3,
        base_delay: float = 1.0,
    ):
        """
        Args:
            db_connection: Optional pre-opened DB connection. When omitted, a
                connection is obtained lazily via ``services.db.get_connection``.
            max_retries: Maximum number of generate+execute attempts made by
                ``query_with_retry`` (must be >= 1).
            base_delay: Base delay in seconds for exponential backoff
                (delay = base_delay * 2 ** attempt; must be >= 0).

        Raises:
            ValueError: If ``max_retries`` < 1 or ``base_delay`` < 0.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be >= 1")
        if base_delay < 0:
            raise ValueError("base_delay must be >= 0")
        self._settings = get_settings()
        self._conn = db_connection
        self._llm = get_llm_client()
        self._max_retries = max_retries
        self._base_delay = base_delay

    def _get_connection(self):
        """Return the cached DB connection, creating it on first use."""
        if self._conn is None:
            from ...services.db import get_connection
            self._conn = get_connection()
        return self._conn

    def close(self):
        """Close the cached connection (if it has a ``close`` method) and forget it."""
        if self._conn and hasattr(self._conn, 'close'):
            self._conn.close()
        self._conn = None

    def generate_sql(
        self,
        question: str,
        context: Optional[Dict] = None,
        previous_error: Optional[str] = None,
        previous_sql: Optional[str] = None,
    ) -> Tuple[str, str]:
        """
        Generate SQL for a natural-language question using the LLM.

        Args:
            question: Natural-language query requirement.
            context: Extra context, rendered into the prompt as JSON.
            previous_error: Error from the previous execution (used on retry).
            previous_sql: SQL produced by the previous attempt, shown to the
                LLM together with ``previous_error`` so it can correct it.

        Returns:
            (sql, explanation). ``sql`` is "" when nothing usable was produced.
        """
        user_prompt = question
        if context:
            # default=str: this is only prompt text, so stringifying values such
            # as datetimes is preferable to crashing with TypeError.
            context_json = json.dumps(context, ensure_ascii=False, indent=2, default=str)
            user_prompt += f"\n\n上下文信息:\n{context_json}"
        if previous_sql:
            user_prompt += f"\n\n上一次生成的SQL:\n{previous_sql}"
        if previous_error:
            user_prompt += f"\n\n上一次执行错误:\n{previous_error}\n请修正SQL语句。"

        messages = [
            SystemMessage(content=SQL_AGENT_SYSTEM_PROMPT),
            HumanMessage(content=user_prompt),
        ]
        response = self._llm.invoke(messages)
        content = response.content.strip()

        # Strip a Markdown code fence around the payload, if any.
        if "```json" in content:
            content = content.split("```json")[1].split("```")[0].strip()
        elif "```" in content:
            content = content.split("```")[1].split("```")[0].strip()

        try:
            result = json.loads(content)
        except json.JSONDecodeError:
            result = None

        # Only a JSON object carries the expected "sql"/"explanation" keys; any
        # other JSON value (list, string, number) falls through to the fallback.
        if isinstance(result, dict):
            return str(result.get("sql") or ""), str(result.get("explanation") or "")

        logger.warning("无法解析LLM响应为JSON: %s", content[:200])
        sql = self._extract_sql_fallback(content)
        if sql:
            return sql, "直接提取的SQL"
        return "", "解析失败"

    @staticmethod
    def _extract_sql_fallback(text: str) -> str:
        """Best-effort extraction of a SELECT statement from free-form LLM text.

        Takes everything from the first standalone ``SELECT`` keyword
        (case-insensitive) up to a code fence or the first blank line. Multi-line
        statements therefore stay intact, while trailing prose separated by a
        blank line is dropped. Returns "" when no SELECT keyword is present.
        """
        match = re.search(r"\bSELECT\b", text, re.IGNORECASE)
        if match is None:
            return ""
        tail = text[match.start():]
        return re.split(r"```|\n\s*\n", tail, maxsplit=1)[0].strip()

    def execute_sql(self, sql: str) -> Tuple[bool, Any, Optional[str]]:
        """
        Execute a SQL query.

        Returns:
            (success, rows or None, error message or None). Execution and
            connection errors are reported in the tuple, not raised.
        """
        success, data, error, _error_type = self._execute_sql_with_type(sql)
        return success, data, error

    def query_with_retry(
        self,
        question: str,
        context: Optional[Dict] = None,
    ) -> SQLExecutionResult:
        """
        Generate and execute SQL, retrying on failure.

        Retry policy by error type:
        - "syntax": retry immediately; the LLM receives the failing SQL and error.
        - anything else (operational/unknown, e.g. connection timeout or
          deadlock): retry after exponential backoff.
        - empty LLM output: retry after exponential backoff.

        Args:
            question: Natural-language query.
            context: Extra context for the LLM prompt.

        Returns:
            SQLExecutionResult. ``retry_count`` is the number of attempts made
            after the first one.
        """
        start_time = time.monotonic()
        last_error: Optional[str] = None
        last_sql = ""

        for attempt in range(self._max_retries):
            is_last_attempt = attempt == self._max_retries - 1

            # Feed the previous SQL and its error back so the LLM can fix it.
            sql, _explanation = self.generate_sql(
                question, context, last_error, previous_sql=last_sql
            )
            last_sql = sql

            if not sql:
                last_error = "LLM未能生成有效的SQL语句"
                logger.warning("重试 %d/%d: %s", attempt + 1, self._max_retries, last_error)
                if not is_last_attempt:
                    time.sleep(self._backoff_delay(attempt))
                continue

            logger.info("尝试 %d: 执行SQL: %s...", attempt + 1, sql[:100])
            success, data, error, error_type = self._execute_sql_with_type(sql)

            if success:
                return SQLExecutionResult(
                    success=True,
                    sql=sql,
                    data=data,
                    retry_count=attempt,
                    execution_time_ms=self._elapsed_ms(start_time),
                )

            last_error = error
            logger.warning(
                "重试 %d/%d: [%s] %s", attempt + 1, self._max_retries, error_type, error
            )
            if is_last_attempt:
                break

            if error_type == "syntax":
                # Syntax errors: retry right away and let the LLM correct the SQL.
                logger.info("SQL语法错误,立即重试让LLM修正")
            else:
                # Operational/unknown errors: exponential backoff.
                # NOTE(review): for lost-connection errors the cached self._conn
                # is reused as-is; unless the driver reconnects by itself, waiting
                # alone may not help — consider dropping the connection here.
                delay = self._backoff_delay(attempt)
                logger.info("操作错误,等待 %s秒 后重试...", delay)
                time.sleep(delay)

        return SQLExecutionResult(
            success=False,
            sql=last_sql,
            error=last_error,
            # Same meaning as on success: attempts made after the first one.
            retry_count=self._max_retries - 1,
            execution_time_ms=self._elapsed_ms(start_time),
        )

    def _backoff_delay(self, attempt: int) -> float:
        """Exponential backoff delay in seconds for the given 0-based attempt."""
        return self._base_delay * (2 ** attempt)

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Milliseconds elapsed since ``start`` (a ``time.monotonic()`` value)."""
        return int((time.monotonic() - start) * 1000)

    @staticmethod
    def _close_cursor(cursor) -> None:
        """Close a cursor, logging (not raising) failures so they cannot mask the result."""
        try:
            cursor.close()
        except Exception:
            logger.warning("关闭游标失败", exc_info=True)

    def _execute_sql_with_type(self, sql: str) -> Tuple[bool, Any, Optional[str], str]:
        """
        Execute SQL and classify any failure.

        Returns:
            (success, rows or None, error message or None, error type) where the
            error type is "syntax" | "operational" | "unknown", or "" on success.
        """
        if not sql or not sql.strip():
            return False, None, "SQL语句为空", "unknown"

        # Safety check: only statements starting with SELECT are executed.
        # NOTE(review): this is a prefix check only — it does not block e.g.
        # "SELECT ... INTO OUTFILE", and whether "SELECT ...; <other stmt>" can
        # run depends on the driver's multi-statement setting. Confirm the DB
        # user is read-only.
        if not sql.strip().upper().startswith("SELECT"):
            return False, None, "只允许执行SELECT查询", "syntax"

        cursor = None
        try:
            # Connection/cursor acquisition lives inside the try so connection
            # failures are classified (and retried) instead of escaping the caller.
            cursor = self._get_connection().cursor(dictionary=True)
            started = time.monotonic()
            cursor.execute(sql)
            rows = cursor.fetchall()
            execution_time = self._elapsed_ms(started)
        except Exception as e:
            error_msg = str(e)
            error_type = self._classify_error(e)
            logger.error("SQL执行失败 [%s]: %s", error_type, error_msg)
            return False, None, error_msg, error_type
        finally:
            if cursor is not None:
                self._close_cursor(cursor)

        logger.info("SQL执行成功: 返回%d行, 耗时%dms", len(rows), execution_time)
        return True, rows, None, ""

    def _classify_error(self, error: Exception) -> str:
        """
        Classify a SQL error.

        Message keywords are checked first (syntax before operational), then
        the exception class name is used as a fallback.

        Returns:
            "syntax"      - the SQL must be corrected by the LLM
            "operational" - transient problem, retry with exponential backoff
            "unknown"     - anything else
        """
        error_msg = str(error).lower()
        error_class = type(error).__name__

        if any(keyword in error_msg for keyword in self._SYNTAX_KEYWORDS):
            return "syntax"
        if any(keyword in error_msg for keyword in self._OPERATIONAL_KEYWORDS):
            return "operational"

        # Fall back to DB-API exception class names.
        if "ProgrammingError" in error_class or "InterfaceError" in error_class:
            return "syntax"
        if "OperationalError" in error_class:
            return "operational"
        return "unknown"