doubao.py
2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
豆包 (字节跳动) 集成 - 备选 LLM
"""
import logging
import httpx
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
from .base import BaseLLMClient, LLMResponse
from ..config import get_settings
logger = logging.getLogger(__name__)
class DoubaoClient(BaseLLMClient):
    """Doubao (ByteDance) client — fallback LLM provider.

    Calls the Volcengine Ark OpenAI-compatible ``/chat/completions`` endpoint
    directly over HTTP with ``httpx`` (no LangChain chat-model yet; see
    :meth:`get_chat_model`).
    """

    # Map LangChain message ``type`` values to OpenAI-style roles.
    _ROLE_MAP = {"human": "user", "ai": "assistant", "system": "system"}

    def __init__(self, api_key: str | None = None, model: str | None = None):
        """Initialize the client.

        Args:
            api_key: Ark API key; falls back to ``settings.doubao_api_key``.
            model: Model identifier; falls back to ``settings.doubao_model``.
        """
        settings = get_settings()
        self._api_key = api_key or settings.doubao_api_key
        self._model = model or settings.doubao_model
        self._base_url = "https://ark.cn-beijing.volces.com/api/v3"

    @property
    def provider(self) -> str:
        """Provider identifier used for routing/logging."""
        return "doubao"

    @property
    def model_name(self) -> str:
        """Configured model identifier."""
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """Return a LangChain chat model (not yet implemented for Doubao)."""
        raise NotImplementedError("豆包 LangChain 集成待实现")

    @classmethod
    def _format_messages(cls, messages: list[BaseMessage]) -> list[dict]:
        """Convert LangChain messages to OpenAI-style role/content dicts.

        Messages without a ``type`` attribute, or with an unknown type,
        default to the ``user`` role (matches the original behavior).
        """
        formatted: list[dict] = []
        for msg in messages:
            role = cls._ROLE_MAP.get(getattr(msg, "type", None), "user")
            formatted.append({"role": role, "content": msg.content})
        return formatted

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Send ``messages`` to Doubao and return the completion.

        Args:
            messages: Conversation history as LangChain messages.

        Returns:
            LLMResponse with the completion text, token usage, and the raw
            JSON payload.

        Raises:
            httpx.HTTPStatusError: On non-2xx responses (via raise_for_status).
            Exception: Any transport/parsing error is logged and re-raised.
        """
        try:
            formatted_messages = self._format_messages(messages)
            with httpx.Client() as client:
                response = client.post(
                    f"{self._base_url}/chat/completions",
                    headers={
                        "Authorization": f"Bearer {self._api_key}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": self._model,
                        "messages": formatted_messages,
                        "temperature": 0.7,
                        "max_tokens": 4096,
                    },
                    timeout=60.0,
                )
            response.raise_for_status()
            data = response.json()
            content = data["choices"][0]["message"]["content"]
            # Usage block may be absent; default missing counts to 0.
            usage_data = data.get("usage", {})
            usage = self.create_usage(
                prompt_tokens=usage_data.get("prompt_tokens", 0),
                completion_tokens=usage_data.get("completion_tokens", 0),
            )
            # Lazy %-style args: formatting is skipped if INFO is disabled.
            logger.info(
                "豆包调用完成: model=%s, tokens=%s", self._model, usage.total_tokens
            )
            return LLMResponse(content=content, usage=usage, raw_response=data)
        except Exception:
            # logger.exception records the full traceback, unlike logger.error.
            logger.exception("豆包调用失败")
            raise