"""
GLM-4.7 (智谱AI) 集成
"""

import logging
from typing import Any

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
from zhipuai import ZhipuAI

from .base import BaseLLMClient, LLMResponse, LLMUsage
from ..config import get_settings

logger = logging.getLogger(__name__)


class ChatZhipuAI(BaseChatModel):
    """智谱AI 聊天模型 LangChain 包装器"""
    
    # Declared as Any so pydantic does not need to validate the third-party
    # ZhipuAI type; the real client instance is assigned in __init__.
    client: Any = None
    model: str = "glm-4"
    temperature: float = 0.7
    max_tokens: int = 4096
    
    def __init__(self, api_key: str, model: str = "glm-4", **kwargs):
        super().__init__(**kwargs)
        self.client = ZhipuAI(api_key=api_key)
        self.model = model
    
    @property
    def _llm_type(self) -> str:
        return "zhipuai"
    
    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        from langchain_core.outputs import ChatGeneration, ChatResult

        # Map LangChain message types onto the OpenAI-style roles expected by
        # the ZhipuAI API; unrecognized or untyped messages fall back to "user".
        # `stop` and `run_manager` are accepted for interface compatibility but
        # are not forwarded to the API.
        role_map = {"human": "user", "ai": "assistant", "system": "system"}
        formatted_messages = [
            {"role": role_map.get(getattr(msg, "type", None), "user"), "content": msg.content}
            for msg in messages
        ]
        
        response = self.client.chat.completions.create(
            model=self.model,
            messages=formatted_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )
        
        content = response.choices[0].message.content
        generation = ChatGeneration(message=AIMessage(content=content))
        
        return ChatResult(
            generations=[generation],
            llm_output={
                "token_usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                }
            }
        )
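
# A minimal usage sketch for the wrapper above. The ZHIPUAI_API_KEY environment
# variable and the prompt text are illustrative assumptions, not part of this
# module:
#
#     import os
#     from langchain_core.messages import HumanMessage
#
#     chat = ChatZhipuAI(api_key=os.environ["ZHIPUAI_API_KEY"], model="glm-4")
#     reply = chat.invoke([HumanMessage(content="Hello, GLM")])
#     print(reply.content)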


class GLMClient(BaseLLMClient):
    """GLM-4.7 客户端"""

    def __init__(self, api_key: str | None = None, model: str | None = None):
        settings = get_settings()
        self._api_key = api_key or settings.glm_api_key
        self._model = model or settings.glm_model
        self._client = ZhipuAI(api_key=self._api_key)
        self._chat_model = None

    @property
    def provider(self) -> str:
        return "glm"

    @property
    def model_name(self) -> str:
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """获取 LangChain 聊天模型"""
        if self._chat_model is None:
            self._chat_model = ChatZhipuAI(api_key=self._api_key, model=self._model)
        return self._chat_model
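
    # Sketch: wiring the chat model into a LangChain pipeline. The prompt
    # template below is an illustrative assumption:
    #
    #     from langchain_core.prompts import ChatPromptTemplate
    #
    #     prompt = ChatPromptTemplate.from_messages([("human", "{question}")])
    #     chain = prompt | GLMClient().get_chat_model()
    #     answer = chain.invoke({"question": "What is GLM-4?"})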

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Call GLM synchronously and return a normalized LLMResponse."""
        try:
            # Same role mapping as ChatZhipuAI._generate above.
            role_map = {"human": "user", "ai": "assistant", "system": "system"}
            formatted_messages = [
                {"role": role_map.get(getattr(msg, "type", None), "user"), "content": msg.content}
                for msg in messages
            ]

            # Generation parameters are fixed here; use get_chat_model() when
            # configurable settings are needed.
            response = self._client.chat.completions.create(
                model=self._model,
                messages=formatted_messages,
                temperature=0.7,
                max_tokens=4096,
            )

            content = response.choices[0].message.content
            usage = self.create_usage(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
            )

            logger.info(
                f"GLM call completed: model={self._model}, "
                f"tokens={usage.total_tokens}"
            )

            return LLMResponse(content=content, usage=usage, raw_response=response)

        except Exception as e:
            logger.error(f"GLM call failed: {e}")
            raise
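

if __name__ == "__main__":
    # Minimal smoke test, assuming glm_api_key / glm_model are configured in
    # settings; the prompt is illustrative.
    from langchain_core.messages import HumanMessage

    client = GLMClient()
    result = client.invoke([HumanMessage(content="Say hello in one sentence.")])
    print(result.content)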