"""
OpenAI 兼容模式客户端
支持火山引擎等 OpenAI 兼容接口
"""
import logging
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, AIMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI
from .base import BaseLLMClient, LLMResponse, LLMUsage
from ..config import get_settings
logger = logging.getLogger(__name__)
class OpenAICompatClient(BaseLLMClient):
    """OpenAI compatibility-mode client (Volcano Engine and similar endpoints).

    Wraps both the raw `openai` SDK client (for `invoke`) and a lazily
    constructed LangChain `ChatOpenAI` model (for `get_chat_model`).
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
    ):
        """Create a client; any unset argument falls back to application settings.

        Args:
            api_key: API key for the compatible endpoint; defaults to
                `settings.openai_compat_api_key`.
            base_url: Endpoint base URL; defaults to
                `settings.openai_compat_base_url`.
            model: Model identifier; defaults to `settings.openai_compat_model`.
        """
        settings = get_settings()
        self._api_key = api_key or settings.openai_compat_api_key
        self._base_url = base_url or settings.openai_compat_base_url
        self._model = model or settings.openai_compat_model
        self._client = OpenAI(
            api_key=self._api_key,
            base_url=self._base_url,
        )
        # Lazily built LangChain wrapper; created on first get_chat_model() call.
        self._chat_model = None

    @property
    def provider(self) -> str:
        """Provider identifier for this client."""
        return "openai_compat"

    @property
    def model_name(self) -> str:
        """Model identifier this client sends with every request."""
        return self._model

    def get_chat_model(self) -> BaseChatModel:
        """Return (and cache) a LangChain chat model bound to this endpoint."""
        if self._chat_model is None:
            self._chat_model = ChatOpenAI(
                api_key=self._api_key,
                base_url=self._base_url,
                model=self._model,
                temperature=0.7,
                max_tokens=4096,
            )
        return self._chat_model

    @staticmethod
    def _role_for(msg: BaseMessage) -> str:
        """Map a LangChain message to an OpenAI chat role.

        Messages without a `type` attribute default to "user"; any typed
        message that is neither "system" nor "human" is sent as "assistant".
        """
        if not hasattr(msg, "type"):
            return "user"
        if msg.type == "system":
            return "system"
        if msg.type == "human":
            return "user"
        return "assistant"

    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Call the OpenAI-compatible chat endpoint with retry and backoff.

        Rate-limit errors (detected heuristically by "429"/"rate" in the
        exception text) get exponential backoff with jitter, capped at
        `max_delay`; other errors get linear backoff. After `max_retries`
        failures the last exception is re-raised. A short pause follows every
        successful call to pace request volume.

        Args:
            messages: Conversation history as LangChain messages.

        Returns:
            LLMResponse with the completion text, token usage, and the raw
            SDK response.

        Raises:
            Exception: The last SDK/transport error after all retries fail.
        """
        import random
        import time

        max_retries = 5
        base_delay = 2.0          # seconds; base for backoff
        max_delay = 30.0          # cap on a single backoff sleep
        post_request_delay = 1.0  # pacing pause after each successful call

        formatted_messages = [
            {"role": self._role_for(msg), "content": msg.content}
            for msg in messages
        ]

        for attempt in range(max_retries):
            try:
                response = self._client.chat.completions.create(
                    model=self._model,
                    messages=formatted_messages,
                    max_tokens=4096,
                )
                content = response.choices[0].message.content
                # NOTE(review): assumes `response.usage` is always present —
                # some compatible gateways omit it; confirm before relying on it.
                usage = self.create_usage(
                    prompt_tokens=response.usage.prompt_tokens,
                    completion_tokens=response.usage.completion_tokens,
                )
                # Lazy %-style args avoid formatting when INFO is disabled.
                logger.info(
                    "OpenAI兼容接口调用完成: model=%s, tokens=%s",
                    self._model,
                    usage.total_tokens,
                )
                time.sleep(post_request_delay)
                return LLMResponse(content=content, usage=usage, raw_response=response)
            except Exception as e:
                error_str = str(e)
                is_rate_limit = "429" in error_str or "rate" in error_str.lower()
                if attempt >= max_retries - 1:
                    # Out of retries: log and re-raise the current exception.
                    logger.error(
                        "OpenAI兼容接口调用失败(已重试%d次): %s", max_retries, e
                    )
                    raise
                if is_rate_limit:
                    # Exponential backoff with jitter, capped at max_delay.
                    delay = min(
                        base_delay * (2 ** attempt) + random.uniform(0, 1),
                        max_delay,
                    )
                    logger.warning(
                        "API速率限制 (尝试 %d/%d), 等待 %.1f秒 后重试...",
                        attempt + 1,
                        max_retries,
                        delay,
                    )
                else:
                    # Non-rate-limit failure: linear backoff.
                    delay = base_delay * (attempt + 1)
                    logger.warning(
                        "API调用失败 (尝试 %d/%d): %s, 等待 %.1f秒 后重试...",
                        attempt + 1,
                        max_retries,
                        e,
                        delay,
                    )
                time.sleep(delay)