"""
OpenAI 兼容模式客户端
支持火山引擎等 OpenAI 兼容接口
"""

import logging
import random
import time

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI

from .base import BaseLLMClient, LLMResponse, LLMUsage
from ..config import get_settings

logger = logging.getLogger(__name__)


class OpenAICompatClient(BaseLLMClient):
    """OpenAI 兼容模式客户端(火山引擎等)"""

    def __init__(self, api_key: str | None = None, base_url: str | None = None, model: str | None = None):
        settings = get_settings()
        self._api_key = api_key or settings.openai_compat_api_key
        self._base_url = base_url or settings.openai_compat_base_url
        self._model = model or settings.openai_compat_model
        self._client = OpenAI(
            api_key=self._api_key,
            base_url=self._base_url,
        )
        self._chat_model = None

    @property
    def provider(self) -> str:
        return "openai_compat"

    @property
    def model_name(self) -> str:
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """获取 LangChain 聊天模型"""
        if self._chat_model is None:
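            # Lazily construct and cache the ChatOpenAI wrapper; the defaults here
            # (temperature 0.7, 4096 max tokens) apply to all LangChain usage.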
            self._chat_model = ChatOpenAI(
                api_key=self._api_key,
                base_url=self._base_url,
                model=self._model,
                temperature=0.7,
                max_tokens=4096,
            )
        return self._chat_model

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Call the OpenAI-compatible endpoint with custom retry and rate-limit handling."""
        max_retries = 5
        base_delay = 2.0
        max_delay = 30.0
        post_request_delay = 1.0
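        # Rate-limit (429) errors retry with exponential backoff plus jitter, capped at
        # max_delay; other failures retry with linear backoff. post_request_delay spaces
        # out successive requests to reduce the chance of hitting rate limits again.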
        
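        # Convert LangChain messages into the role/content dicts the OpenAI SDK expects.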
        formatted_messages = []
        
        for msg in messages:
            if hasattr(msg, 'type'):
                if msg.type == "system":
                    role = "system"
                elif msg.type == "human":
                    role = "user"
                else:
                    role = "assistant"
            else:
                role = "user"
            formatted_messages.append({"role": role, "content": msg.content})
        
        last_exception = None
        
        for attempt in range(max_retries):
            try:
                response = self._client.chat.completions.create(
                    model=self._model,
                    messages=formatted_messages,
                    max_tokens=4096,
                )

                content = response.choices[0].message.content
                usage = self.create_usage(
                    prompt_tokens=response.usage.prompt_tokens,
                    completion_tokens=response.usage.completion_tokens,
                )

                logger.info(
                    f"OpenAI-compatible API call completed: model={self._model}, "
                    f"tokens={usage.total_tokens}"
                )
                
                time.sleep(post_request_delay)

                return LLMResponse(content=content, usage=usage, raw_response=response)

            except Exception as e:
                last_exception = e
                error_str = str(e)
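                # Heuristic: treat any error mentioning "429" or "rate" as a rate limit.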
                is_rate_limit = "429" in error_str or "rate" in error_str.lower()
                
                if attempt < max_retries - 1:
                    if is_rate_limit:
                        delay = min(base_delay * (2 ** attempt) + random.uniform(0, 1), max_delay)
                        logger.warning(
                            f"API rate limited (attempt {attempt + 1}/{max_retries}), "
                            f"retrying in {delay:.1f}s..."
                        )
                    else:
                        delay = base_delay * (attempt + 1)
                        logger.warning(
                            f"API call failed (attempt {attempt + 1}/{max_retries}): {e}, "
                            f"retrying in {delay:.1f}s..."
                        )
                    time.sleep(delay)
                else:
                    logger.error(f"OpenAI-compatible API call failed after {max_retries} attempts: {e}")
                    raise
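

# A minimal usage sketch, not part of the module's public surface. It assumes the
# settings returned by get_settings() provide openai_compat_api_key,
# openai_compat_base_url, and openai_compat_model, as the constructor above expects.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage, SystemMessage

    client = OpenAICompatClient()
    reply = client.invoke([
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Say hello in one sentence."),
    ])
    print(reply.content)
    print(f"tokens used: {reply.usage.total_tokens}")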