base.py
"""
LLM 基础抽象类
定义统一的 LLM 接口
"""
from abc import ABC, abstractmethod
from typing import Any
from dataclasses import dataclass, field
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
@dataclass
class LLMUsage:
"""LLM 使用统计"""
provider: str = ""
model: str = ""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
def add(self, other: "LLMUsage") -> "LLMUsage":
"""累加使用量"""
return LLMUsage(
provider=self.provider or other.provider,
model=self.model or other.model,
prompt_tokens=self.prompt_tokens + other.prompt_tokens,
completion_tokens=self.completion_tokens + other.completion_tokens,
total_tokens=self.total_tokens + other.total_tokens,
)
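
# Illustrative example (not part of the original module): accumulating usage
# across several LLM calls. ``LLMUsage.add`` returns a new instance, so a
# running total can be folded together without mutating the per-call records:
#
#     total = LLMUsage()
#     for usage in per_call_usages:   # hypothetical list of LLMUsage records
#         total = total.add(usage)
#     print(total.prompt_tokens, total.completion_tokens, total.total_tokens)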


@dataclass
class LLMResponse:
    """LLM response returned by a client."""

    content: str
    usage: LLMUsage = field(default_factory=LLMUsage)
    raw_response: Any = None


class BaseLLMClient(ABC):
    """Base class for LLM clients."""

    @property
    @abstractmethod
    def provider(self) -> str:
        """Provider name."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Model name."""

    @abstractmethod
    def get_chat_model(self) -> BaseChatModel:
        """Return the underlying LangChain chat model."""

    @abstractmethod
    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Invoke the LLM with a list of messages and return its response."""

    def create_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> LLMUsage:
        """Create a usage record tagged with this client's provider and model."""
        return LLMUsage(
            provider=self.provider,
            model=self.model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
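

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# client built on langchain_core's FakeListChatModel, showing how a subclass
# is expected to wire `get_chat_model`, `invoke`, and `create_usage` together.
# The class name and canned reply below are assumptions for demonstration
# only; real clients would wrap a provider-specific chat model instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_core.language_models import FakeListChatModel
    from langchain_core.messages import HumanMessage

    class EchoLLMClient(BaseLLMClient):
        """Hypothetical client returning canned responses, for demos/tests."""

        def __init__(self) -> None:
            self._chat_model = FakeListChatModel(responses=["Hello from the fake model."])

        @property
        def provider(self) -> str:
            return "fake"

        @property
        def model_name(self) -> str:
            return "fake-list-chat"

        def get_chat_model(self) -> BaseChatModel:
            return self._chat_model

        def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
            ai_message = self._chat_model.invoke(messages)
            # FakeListChatModel reports no token usage, so the counts stay 0;
            # a real client would read them from the provider's response metadata.
            usage = self.create_usage(prompt_tokens=0, completion_tokens=0)
            return LLMResponse(content=ai_message.content, usage=usage, raw_response=ai_message)

    client = EchoLLMClient()
    response = client.invoke([HumanMessage(content="ping")])
    print(response.content, response.usage)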