"""Base abstractions for LLM clients.

Defines a unified LLM interface.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage


@dataclass
class LLMUsage:
    """Usage statistics for LLM calls."""

    # Provider and model the usage is attributed to ("" when not set).
    provider: str = ""
    model: str = ""
    # Token counters.
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

    def add(self, other: "LLMUsage") -> "LLMUsage":
        """Return a new ``LLMUsage`` that accumulates ``self`` and ``other``.

        The three token counters are summed field by field. ``provider`` and
        ``model`` are taken from ``self`` when truthy, otherwise from
        ``other``. Neither operand is modified.
        """
        merged_provider = self.provider or other.provider
        merged_model = self.model or other.model
        prompt_sum = self.prompt_tokens + other.prompt_tokens
        completion_sum = self.completion_tokens + other.completion_tokens
        total_sum = self.total_tokens + other.total_tokens
        return LLMUsage(
            provider=merged_provider,
            model=merged_model,
            prompt_tokens=prompt_sum,
            completion_tokens=completion_sum,
            total_tokens=total_sum,
        )


@dataclass
class LLMResponse:
    """Response returned by an LLM call."""

    # Text content of the response.
    content: str
    # Usage statistics; a fresh, zeroed LLMUsage per instance by default.
    usage: LLMUsage = field(default_factory=LLMUsage)
    # Raw response object, if any (None by default).
    raw_response: Any = None


class BaseLLMClient(ABC):
    """Abstract base class for LLM clients."""

    @property
    @abstractmethod
    def provider(self) -> str:
        """Provider name."""

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Model name."""

    @abstractmethod
    def get_chat_model(self) -> BaseChatModel:
        """Return the LangChain chat model."""

    @abstractmethod
    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Invoke the LLM with ``messages`` and return an ``LLMResponse``."""

    def create_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> LLMUsage:
        """Create usage statistics tagged with this client's provider and model.

        ``total_tokens`` is set to ``prompt_tokens + completion_tokens``.
        """
        provider_name = self.provider
        model = self.model_name
        combined = prompt_tokens + completion_tokens
        return LLMUsage(
            provider=provider_name,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=combined,
        )