anthropic_compat.py
6.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Anthropic 兼容模式客户端
支持智谱AI的Anthropic兼容接口
"""
import logging
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
import anthropic
from .base import BaseLLMClient, LLMResponse, LLMUsage
from ..config import get_settings
logger = logging.getLogger(__name__)
class ChatAnthropicCompat(BaseChatModel):
"""Anthropic兼容模式 LangChain 包装器"""
client: anthropic.Anthropic = None
model: str = "glm-4.7"
temperature: float = 0.7
max_tokens: int = 4096
def __init__(self, api_key: str, base_url: str, model: str = "glm-4.7", **kwargs):
super().__init__(**kwargs)
self.client = anthropic.Anthropic(
api_key=api_key,
base_url=base_url,
)
self.model = model
@property
def _llm_type(self) -> str:
return "anthropic_compat"
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
from langchain_core.outputs import ChatGeneration, ChatResult
formatted_messages = []
system_content = None
for msg in messages:
if hasattr(msg, 'type'):
if msg.type == "system":
system_content = msg.content
continue
role = "user" if msg.type == "human" else "assistant"
else:
role = "user"
formatted_messages.append({"role": role, "content": msg.content})
create_kwargs = {
"model": self.model,
"messages": formatted_messages,
"max_tokens": self.max_tokens,
}
if system_content:
create_kwargs["system"] = system_content
response = self.client.messages.create(**create_kwargs)
content = response.content[0].text
generation = ChatGeneration(message=AIMessage(content=content))
return ChatResult(
generations=[generation],
llm_output={
"token_usage": {
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
}
}
)
class AnthropicCompatClient(BaseLLMClient):
"""Anthropic兼容模式客户端"""
def __init__(self, api_key: str = None, base_url: str = None, model: str = None):
settings = get_settings()
self._api_key = api_key or settings.anthropic_api_key
self._base_url = base_url or settings.anthropic_base_url
self._model = model or settings.anthropic_model
self._client = anthropic.Anthropic(
api_key=self._api_key,
base_url=self._base_url,
)
self._chat_model = None
@property
def provider(self) -> str:
return "anthropic_compat"
@property
def model_name(self) -> str:
return self._model
def get_chat_model(self) -> BaseChatModel:
"""获取 LangChain 聊天模型"""
if self._chat_model is None:
self._chat_model = ChatAnthropicCompat(
api_key=self._api_key,
base_url=self._base_url,
model=self._model,
)
return self._chat_model
def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
"""调用 Anthropic 兼容接口(带自定义重试和速率限制处理)"""
import time
import random
max_retries = 5
base_delay = 2.0
max_delay = 30.0
post_request_delay = 1.0
formatted_messages = []
system_content = None
for msg in messages:
if hasattr(msg, 'type'):
if msg.type == "system":
system_content = msg.content
continue
role = "user" if msg.type == "human" else "assistant"
else:
role = "user"
formatted_messages.append({"role": role, "content": msg.content})
create_kwargs = {
"model": self._model,
"messages": formatted_messages,
"max_tokens": 4096,
}
if system_content:
create_kwargs["system"] = system_content
last_exception = None
for attempt in range(max_retries):
try:
response = self._client.messages.create(**create_kwargs)
content = response.content[0].text
usage = self.create_usage(
prompt_tokens=response.usage.input_tokens,
completion_tokens=response.usage.output_tokens,
)
logger.info(
f"Anthropic兼容接口调用完成: model={self._model}, "
f"tokens={usage.total_tokens}"
)
time.sleep(post_request_delay)
return LLMResponse(content=content, usage=usage, raw_response=response)
except Exception as e:
last_exception = e
error_str = str(e)
is_rate_limit = "429" in error_str or "rate" in error_str.lower() or "1302" in error_str
if attempt < max_retries - 1:
if is_rate_limit:
delay = min(base_delay * (2 ** attempt) + random.uniform(0, 1), max_delay)
logger.warning(
f"API速率限制 (尝试 {attempt + 1}/{max_retries}), "
f"等待 {delay:.1f}秒 后重试..."
)
else:
delay = base_delay * (attempt + 1)
logger.warning(
f"API调用失败 (尝试 {attempt + 1}/{max_retries}): {e}, "
f"等待 {delay:.1f}秒 后重试..."
)
time.sleep(delay)
else:
logger.error(f"Anthropic兼容接口调用失败(已重试{max_retries}次): {e}")
raise