# doubao.py
"""
豆包 (字节跳动) 集成 - 备选 LLM
"""

import logging
import httpx
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage

from .base import BaseLLMClient, LLMResponse
from ..config import get_settings

logger = logging.getLogger(__name__)


class DoubaoClient(BaseLLMClient):
    """Doubao (ByteDance) client — fallback LLM provider.

    Talks to the Volcengine Ark OpenAI-compatible chat-completions endpoint
    over plain HTTP via ``httpx``.
    """

    # LangChain message ``type`` -> OpenAI-style role.  Any other type
    # (e.g. "system") falls through to "system".
    _ROLE_BY_TYPE = {"human": "user", "ai": "assistant"}

    def __init__(self, api_key: str | None = None, model: str | None = None):
        """Initialize the client.

        Args:
            api_key: Ark API key; falls back to ``settings.doubao_api_key``.
            model: Model/endpoint id; falls back to ``settings.doubao_model``.
        """
        settings = get_settings()
        self._api_key = api_key or settings.doubao_api_key
        self._model = model or settings.doubao_model
        # Volcengine Ark OpenAI-compatible API base (Beijing region).
        self._base_url = "https://ark.cn-beijing.volces.com/api/v3"

    @property
    def provider(self) -> str:
        """Provider identifier used by the LLM registry."""
        return "doubao"

    @property
    def model_name(self) -> str:
        """Configured Doubao model/endpoint id."""
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """Return a LangChain chat model.

        Raises:
            NotImplementedError: LangChain integration is not implemented yet.
        """
        raise NotImplementedError("豆包 LangChain 集成待实现")

    @classmethod
    def _format_messages(cls, messages: list[BaseMessage]) -> list[dict]:
        """Convert LangChain messages to OpenAI-style role/content dicts.

        Messages without a ``type`` attribute default to the "user" role.
        """
        formatted = []
        for msg in messages:
            if hasattr(msg, "type"):
                role = cls._ROLE_BY_TYPE.get(msg.type, "system")
            else:
                role = "user"
            formatted.append({"role": role, "content": msg.content})
        return formatted

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Call the Doubao chat-completions API synchronously.

        Args:
            messages: Conversation history as LangChain messages.

        Returns:
            LLMResponse with the first choice's content, token usage, and
            the raw JSON payload.

        Raises:
            httpx.HTTPStatusError: on non-2xx responses (via raise_for_status).
            Exception: any other failure is logged with traceback and re-raised.
        """
        try:
            formatted_messages = self._format_messages(messages)

            with httpx.Client() as client:
                response = client.post(
                    f"{self._base_url}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self._api_key}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": self._model,
                        "messages": formatted_messages,
                        "temperature": 0.7,
                        "max_tokens": 4096,
                    },
                    timeout=60.0,
                )
                response.raise_for_status()
                data = response.json()

            content = data["choices"][0]["message"]["content"]
            # Usage block may be absent; default token counts to 0.
            usage_data = data.get("usage", {})
            usage = self.create_usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
            )

            # Lazy %-style args avoid formatting cost when INFO is disabled.
            logger.info(
                "豆包调用完成: model=%s, tokens=%s", self._model, usage.total_tokens
            )

            return LLMResponse(content=content, usage=usage, raw_response=data)

        except Exception as e:
            # logger.exception preserves the traceback, unlike logger.error.
            logger.exception("豆包调用失败: %s", e)
            raise