anthropic_compat.py 6.18 KB
"""
Anthropic 兼容模式客户端
支持智谱AI的Anthropic兼容接口
"""

import logging
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
import anthropic

from .base import BaseLLMClient, LLMResponse, LLMUsage
from ..config import get_settings

logger = logging.getLogger(__name__)


class ChatAnthropicCompat(BaseChatModel):
    """LangChain chat-model wrapper for an Anthropic-compatible Messages endpoint.

    LangChain messages are converted to the Anthropic Messages format and sent
    with a single synchronous ``messages.create`` call. The reply text and the
    token usage (under ``llm_output["token_usage"]``) are returned as a
    ``ChatResult``.
    """

    # SDK client; constructed in __init__ from api_key/base_url.
    client: anthropic.Anthropic = None
    # Model name sent with every request.
    model: str = "glm-4.7"
    # Sampling temperature sent with every request.
    temperature: float = 0.7
    # Maximum number of tokens to generate per request.
    max_tokens: int = 4096

    def __init__(self, api_key: str, base_url: str, model: str = "glm-4.7", **kwargs):
        """Create the wrapper and its underlying SDK client.

        Args:
            api_key: API key for the compatible endpoint.
            base_url: Base URL of the Anthropic-compatible endpoint.
            model: Model name to request.
            **kwargs: Other model fields (e.g. ``temperature``, ``max_tokens``).
        """
        super().__init__(**kwargs)
        self.client = anthropic.Anthropic(
            api_key=api_key,
            base_url=base_url,
        )
        self.model = model

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this model type."""
        return "anthropic_compat"

    @staticmethod
    def _to_anthropic_messages(messages):
        """Split LangChain messages into ``(system_content, anthropic_messages)``.

        System messages are removed from the list and returned separately; if
        several are present, the last one wins. ``human`` messages (and objects
        without a ``type``) become ``user`` turns; every other type becomes an
        ``assistant`` turn.
        """
        system_content = None
        formatted = []
        for msg in messages:
            if hasattr(msg, "type"):
                if msg.type == "system":
                    system_content = msg.content
                    continue
                role = "user" if msg.type == "human" else "assistant"
            else:
                role = "user"
            formatted.append({"role": role, "content": msg.content})
        return system_content, formatted

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Call the endpoint once and wrap the reply in a ``ChatResult``.

        Args:
            messages: LangChain messages making up the conversation.
            stop: Optional stop strings, forwarded as ``stop_sequences``.
            run_manager: Unused; accepted for LangChain compatibility.
            **kwargs: Unused.

        Returns:
            A ``ChatResult`` with one generation holding the concatenated text
            of the response and the token usage in ``llm_output``.
        """
        from langchain_core.outputs import ChatGeneration, ChatResult

        system_content, formatted_messages = self._to_anthropic_messages(messages)

        create_kwargs = {
            "model": self.model,
            "messages": formatted_messages,
            "max_tokens": self.max_tokens,
            # Honour the configured temperature field.
            "temperature": self.temperature,
        }
        if system_content:
            create_kwargs["system"] = system_content
        if stop:
            # LangChain's stop words correspond to Anthropic's stop_sequences.
            create_kwargs["stop_sequences"] = list(stop)

        response = self.client.messages.create(**create_kwargs)

        # ``response.content`` is a list of typed blocks. Only "text" blocks
        # carry ``.text``, and other block types may come first, so collect
        # every text block rather than indexing [0].
        content = "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )
        generation = ChatGeneration(message=AIMessage(content=content))

        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        return ChatResult(
            generations=[generation],
            llm_output={
                "token_usage": {
                    "prompt_tokens": input_tokens,
                    "completion_tokens": output_tokens,
                    "total_tokens": input_tokens + output_tokens,
                }
            },
        )


class AnthropicCompatClient(BaseLLMClient):
    """Client for Anthropic-compatible chat endpoints (e.g. Zhipu AI).

    Connection parameters not passed explicitly fall back to the values from
    ``get_settings()``. ``invoke`` calls the Messages API directly with its own
    retry and rate-limit handling. ``get_chat_model`` returns a LangChain
    wrapper for use in chains and agents.
    """

    def __init__(self, api_key: str = None, base_url: str = None, model: str = None):
        """Create the client.

        Args:
            api_key: API key; defaults to ``settings.anthropic_api_key``.
            base_url: Endpoint base URL; defaults to ``settings.anthropic_base_url``.
            model: Model name; defaults to ``settings.anthropic_model``.
        """
        settings = get_settings()
        self._api_key = api_key or settings.anthropic_api_key
        self._base_url = base_url or settings.anthropic_base_url
        self._model = model or settings.anthropic_model
        self._client = anthropic.Anthropic(
            api_key=self._api_key,
            base_url=self._base_url,
        )
        # Created lazily by get_chat_model() and reused afterwards.
        self._chat_model = None

    @property
    def provider(self) -> str:
        return "anthropic_compat"

    @property
    def model_name(self) -> str:
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """Return the LangChain chat model, creating it on first use."""
        if self._chat_model is None:
            self._chat_model = ChatAnthropicCompat(
                api_key=self._api_key,
                base_url=self._base_url,
                model=self._model,
            )
        return self._chat_model

    @staticmethod
    def _to_anthropic_messages(messages):
        """Split LangChain messages into ``(system_content, anthropic_messages)``.

        System messages are removed from the list and returned separately; if
        several are present, the last one wins. ``human`` messages (and objects
        without a ``type``) become ``user`` turns; every other type becomes an
        ``assistant`` turn.
        """
        system_content = None
        formatted = []
        for msg in messages:
            if hasattr(msg, "type"):
                if msg.type == "system":
                    system_content = msg.content
                    continue
                role = "user" if msg.type == "human" else "assistant"
            else:
                role = "user"
            formatted.append({"role": role, "content": msg.content})
        return system_content, formatted

    @staticmethod
    def _response_text(response) -> str:
        """Concatenate the text of all ``type == "text"`` content blocks."""
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    @staticmethod
    def _classify_api_error(exc: Exception) -> tuple[bool, bool]:
        """Return ``(retryable, is_rate_limit)`` for an exception from the SDK.

        Rate limits are HTTP 429 or the Zhipu business code ``1302``.
        Server errors (5xx), 408/409, and connection/timeout errors are
        retryable. Every other status error (400, 401, 403, 404, ...) and any
        non-SDK exception is treated as permanent.
        """
        if isinstance(exc, anthropic.RateLimitError):
            return True, True
        if isinstance(exc, anthropic.APIStatusError):
            if exc.status_code == 429 or "1302" in str(exc):
                return True, True
            return exc.status_code >= 500 or exc.status_code in (408, 409), False
        if isinstance(exc, anthropic.APIConnectionError):
            # Also covers anthropic.APITimeoutError (a subclass).
            return True, False
        return False, False

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Call the Anthropic-compatible endpoint with retry and rate-limit handling.

        Transient failures are retried up to ``max_retries`` attempts in total:
        - Rate limits back off exponentially with jitter, capped at ``max_delay``.
        - Other transient errors back off linearly.
        - Non-retryable errors are re-raised immediately.

        After every successful call the method sleeps ``post_request_delay``
        seconds to throttle consecutive requests.

        Args:
            messages: LangChain messages making up the conversation.

        Returns:
            An ``LLMResponse`` with the concatenated reply text, token usage,
            and the raw SDK response.

        Raises:
            Exception: The last SDK exception, when the error is not retryable
                or all attempts are exhausted.
        """
        import random
        import time

        max_retries = 5
        base_delay = 2.0
        max_delay = 30.0
        post_request_delay = 1.0

        system_content, formatted_messages = self._to_anthropic_messages(messages)

        create_kwargs = {
            "model": self._model,
            "messages": formatted_messages,
            "max_tokens": 4096,
        }
        if system_content:
            create_kwargs["system"] = system_content

        for attempt in range(max_retries):
            # Only the network call is retried; parsing a successful response
            # must never cause the request to be re-sent.
            try:
                response = self._client.messages.create(**create_kwargs)
            except Exception as e:
                retryable, is_rate_limit = self._classify_api_error(e)
                if not retryable:
                    logger.error(f"Anthropic兼容接口调用失败(不可重试): {e}")
                    raise
                if attempt >= max_retries - 1:
                    logger.error(f"Anthropic兼容接口调用失败(已重试{max_retries}次): {e}")
                    raise

                if is_rate_limit:
                    delay = min(base_delay * (2 ** attempt) + random.uniform(0, 1), max_delay)
                    logger.warning(
                        f"API速率限制 (尝试 {attempt + 1}/{max_retries}), "
                        f"等待 {delay:.1f}秒 后重试..."
                    )
                else:
                    delay = base_delay * (attempt + 1)
                    logger.warning(
                        f"API调用失败 (尝试 {attempt + 1}/{max_retries}): {e}, "
                        f"等待 {delay:.1f}秒 后重试..."
                    )
                time.sleep(delay)
                continue

            content = self._response_text(response)
            usage = self.create_usage(
                prompt_tokens=response.usage.input_tokens,
                completion_tokens=response.usage.output_tokens,
            )

            logger.info(
                f"Anthropic兼容接口调用完成: model={self._model}, "
                f"tokens={usage.total_tokens}"
            )

            # Throttle back-to-back calls to stay under provider rate limits.
            time.sleep(post_request_delay)

            return LLMResponse(content=content, usage=usage, raw_response=response)