# base.py
"""
LLM 基础抽象类
定义统一的 LLM 接口
"""

from abc import ABC, abstractmethod
from typing import Any
from dataclasses import dataclass, field
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage


@dataclass
class LLMUsage:
    """Token-usage statistics reported for one or more LLM calls."""
    provider: str = ""           # provider identifier; empty when unknown
    model: str = ""              # model identifier; empty when unknown
    prompt_tokens: int = 0       # tokens consumed by the prompt
    completion_tokens: int = 0   # tokens produced in the completion
    total_tokens: int = 0        # prompt + completion, as reported

    def add(self, other: "LLMUsage") -> "LLMUsage":
        """Return a new LLMUsage combining this record with *other*.

        Token counts are summed; provider/model keep this record's
        value and fall back to *other*'s only when ours is empty.
        Neither operand is mutated.
        """
        combined = LLMUsage(
            provider=self.provider if self.provider else other.provider,
            model=self.model if self.model else other.model,
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            total_tokens=self.total_tokens + other.total_tokens,
        )
        return combined


@dataclass
class LLMResponse:
    """Result of a single LLM invocation."""
    # Text content produced by the model.
    content: str
    # Token accounting for this call; defaults to an all-zero record.
    usage: LLMUsage = field(default_factory=LLMUsage)
    # Unprocessed provider response object, if the client chose to keep it.
    raw_response: Any = None


class BaseLLMClient(ABC):
    """Abstract base class defining the unified LLM client interface.

    Concrete subclasses supply the provider/model identity, expose a
    LangChain chat model, and implement the actual invocation.
    """

    @property
    @abstractmethod
    def provider(self) -> str:
        """Identifier of the LLM provider."""
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying model."""
        ...

    @abstractmethod
    def get_chat_model(self) -> BaseChatModel:
        """Return the LangChain chat model backing this client."""
        ...

    @abstractmethod
    def invoke(self, messages: list[BaseMessage]) -> LLMResponse:
        """Send *messages* to the LLM and return the resulting response."""
        ...

    def create_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> LLMUsage:
        """Build an LLMUsage record stamped with this client's identity.

        The total is derived from the two counts; provider and model
        come from the subclass's properties.
        """
        total = prompt_tokens + completion_tokens
        return LLMUsage(
            provider=self.provider,
            model=self.model_name,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total,
        )