""" OpenAI 兼容模式客户端 支持火山引擎等 OpenAI 兼容接口 """ import logging from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, AIMessage from langchain_openai import ChatOpenAI from openai import OpenAI from .base import BaseLLMClient, LLMResponse, LLMUsage from ..config import get_settings logger = logging.getLogger(__name__) class OpenAICompatClient(BaseLLMClient): """OpenAI 兼容模式客户端(火山引擎等)""" def __init__(self, api_key: str = None, base_url: str = None, model: str = None): settings = get_settings() self._api_key = api_key or settings.openai_compat_api_key self._base_url = base_url or settings.openai_compat_base_url self._model = model or settings.openai_compat_model self._client = OpenAI( api_key=self._api_key, base_url=self._base_url, ) self._chat_model = None @property def provider(self) -> str: return "openai_compat" @property def model_name(self) -> str: return self._model def get_chat_model(self) -> BaseChatModel: """获取 LangChain 聊天模型""" if self._chat_model is None: self._chat_model = ChatOpenAI( api_key=self._api_key, base_url=self._base_url, model=self._model, temperature=0.7, max_tokens=4096, ) return self._chat_model def invoke(self, messages: list[BaseMessage]) -> LLMResponse: """调用 OpenAI 兼容接口(带自定义重试和速率限制处理)""" import time import random max_retries = 5 base_delay = 2.0 max_delay = 30.0 post_request_delay = 1.0 formatted_messages = [] for msg in messages: if hasattr(msg, 'type'): if msg.type == "system": role = "system" elif msg.type == "human": role = "user" else: role = "assistant" else: role = "user" formatted_messages.append({"role": role, "content": msg.content}) last_exception = None for attempt in range(max_retries): try: response = self._client.chat.completions.create( model=self._model, messages=formatted_messages, max_tokens=4096, ) content = response.choices[0].message.content usage = self.create_usage( prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, ) logger.info( f"OpenAI兼容接口调用完成: model={self._model}, " f"tokens={usage.total_tokens}" ) time.sleep(post_request_delay) return LLMResponse(content=content, usage=usage, raw_response=response) except Exception as e: last_exception = e error_str = str(e) is_rate_limit = "429" in error_str or "rate" in error_str.lower() if attempt < max_retries - 1: if is_rate_limit: delay = min(base_delay * (2 ** attempt) + random.uniform(0, 1), max_delay) logger.warning( f"API速率限制 (尝试 {attempt + 1}/{max_retries}), " f"等待 {delay:.1f}秒 后重试..." ) else: delay = base_delay * (attempt + 1) logger.warning( f"API调用失败 (尝试 {attempt + 1}/{max_retries}): {e}, " f"等待 {delay:.1f}秒 后重试..." ) time.sleep(delay) else: logger.error(f"OpenAI兼容接口调用失败(已重试{max_retries}次): {e}") raise