from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from enum import Enum
import os
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelCapability(Enum):
    """Enumeration of the capabilities a model provider may expose."""

    # Text embedding.
    EMBED = "embed"
    # Re-ranking of documents against a query.
    RERANK = "rerank"
    # Chat / text generation.
    CHAT = "chat"
class BaseModel(ABC):
    """Abstract base class for model providers.

    Gives every provider the same access surface:
    - embed: text embedding
    - rerank: document re-ranking
    - chat: conversation / generation

    Subclasses implement the provider-specific API calls.
    """

    def __init__(
        self,
        api_key: str = None,
        base_url: str = None,
        chat_model: str = None,
        embed_model: str = None,
        rerank_model: str = None,
        timeout: int = 30,
        max_retries: int = 3,
        **kwargs
    ):
        """Store the configuration, filling gaps from the provider defaults.

        Args:
            api_key: API key. When falsy, the environment variable named by
                ``_get_api_key_env_name()`` is read (empty string if unset).
            base_url: API base URL (optional, for custom endpoints).
            chat_model: Model name used for chat.
            embed_model: Model name used for embeddings.
            rerank_model: Model name used for re-ranking.
            timeout: Request timeout in seconds.
            max_retries: Maximum number of retries.
            **kwargs: Extra configuration, kept as-is in ``self.extra``.
        """
        # Each fallback is only evaluated when the explicit argument is falsy.
        if not api_key:
            api_key = os.getenv(self._get_api_key_env_name(), "")
        self.api_key = api_key
        self.base_url = base_url if base_url else self._get_default_base_url()
        self.chat_model = chat_model if chat_model else self._get_default_model()
        self.embed_model = (
            embed_model if embed_model else self._get_default_embed_model()
        )
        self.rerank_model = (
            rerank_model if rerank_model else self._get_default_rerank_model()
        )
        self.timeout = timeout
        self.max_retries = max_retries
        self.extra = kwargs
        # Slot for a lazily created provider client.
        self._client = None

    @property
    @abstractmethod
    def provider(self) -> str:
        """Return the provider name."""

    @property
    def capabilities(self) -> List[ModelCapability]:
        """Return the capabilities for which a model name is configured."""
        configured = (
            (self.embed_model, ModelCapability.EMBED),
            (self.rerank_model, ModelCapability.RERANK),
            (self.chat_model, ModelCapability.CHAT),
        )
        return [capability for model_name, capability in configured if model_name]

    def _get_api_key_env_name(self) -> str:
        """Return the name of the environment variable holding the API key."""
        prefix = self.provider.upper()
        return f"{prefix}_API_KEY"

    def _get_default_base_url(self) -> str:
        """Return the default API base URL (empty in the base class)."""
        return ""

    def _get_default_model(self) -> str:
        """Return the default chat model (empty in the base class)."""
        return ""

    def _get_default_embed_model(self) -> str:
        """Return the default embedding model (empty in the base class)."""
        return ""

    def _get_default_rerank_model(self) -> str:
        """Return the default re-rank model (empty in the base class)."""
        return ""

    @abstractmethod
    def embed(
        self,
        texts: Union[str, List[str]],
        **kwargs
    ) -> List[List[float]]:
        """Embed text.

        Args:
            texts: A single text or a list of texts.
            **kwargs: Extra parameters.

        Returns:
            A list of embedding vectors.
        """

    @abstractmethod
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: int = None,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Re-rank documents against a query.

        Args:
            query: Query text.
            documents: Documents to rank.
            top_n: Maximum number of results to return.
            **kwargs: Extra parameters.

        Returns:
            Ranked results; each element contains:
            - index: position in the original ``documents`` list
            - score: relevance score
            - text: document text (optional)
        """

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, str]],
        **kwargs
    ) -> str:
        """Run a chat / generation request.

        Args:
            messages: Message list shaped like
                ``[{"role": "user", "content": "..."}]``.
            **kwargs: Extra parameters (temperature, max_tokens, ...).

        Returns:
            The generated text.
        """

    def embed_single(self, text: str, **kwargs) -> List[float]:
        """Embed one text; return an empty list if ``embed`` yields nothing."""
        vectors = self.embed([text], **kwargs)
        if not vectors:
            return []
        return vectors[0]

    def rerank_with_scores(
        self,
        query: str,
        documents: List[str],
        **kwargs
    ) -> List[float]:
        """Return re-rank scores aligned with the original document order.

        Documents missing from the ``rerank`` result keep a score of 0.0.
        """
        ranked = self.rerank(query, documents, **kwargs)
        aligned = [0.0] * len(documents)
        for entry in ranked:
            aligned[entry["index"]] = entry["score"]
        return aligned

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate text for a single-turn prompt sent as one user message."""
        single_turn = [{"role": "user", "content": prompt}]
        return self.chat(single_turn, **kwargs)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(provider={self.provider}, model={self.chat_model})"
class OpenAIModel(BaseModel):
    """OpenAI model implementation.

    Supports:
    - Chat: gpt-4o, gpt-4o-mini, gpt-4-turbo, gpt-3.5-turbo
    - Embed: text-embedding-3-small, text-embedding-3-large,
      text-embedding-ada-002

    Rerank is not available; ``rerank`` always raises NotImplementedError.

    Example:
        >>> model = OpenAIModel(api_key="sk-xxx")
        >>> embeddings = model.embed(["Hello, world!"])
        >>> response = model.chat([{"role": "user", "content": "Hi"}])
    """

    @property
    def provider(self) -> str:
        """Return the provider name, ``"openai"``."""
        return "openai"

    def _get_default_base_url(self) -> str:
        return "https://api.openai.com/v1"

    def _get_default_model(self) -> str:
        return "gpt-4o-mini"

    def _get_default_embed_model(self) -> str:
        return "text-embedding-3-small"

    @property
    def client(self):
        """Return the OpenAI client, creating it on first access.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if self._client is None:
            # Only the import is guarded, so an error raised while building
            # the client is not misreported as a missing package.
            try:
                from openai import OpenAI
            except ImportError as exc:
                raise ImportError(
                    "OpenAI requires openai package. Install with: pip install openai"
                ) from exc
            self._client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
                timeout=self.timeout,
                max_retries=self.max_retries,
            )
        return self._client

    def embed(
        self,
        texts: Union[str, List[str]],
        model: str = None,
        **kwargs
    ) -> List[List[float]]:
        """Embed text with an OpenAI embedding model.

        Args:
            texts: A single text or a list of texts.
            model: Model name; defaults to ``self.embed_model``.
            **kwargs: Extra parameters forwarded to ``embeddings.create``.

        Returns:
            One embedding vector per item in the response data. An empty
            input list returns ``[]`` without calling the API.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            # Nothing to embed; skip the request entirely.
            return []
        response = self.client.embeddings.create(
            model=model or self.embed_model,
            input=texts,
            **kwargs
        )
        return [item.embedding for item in response.data]

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: int = None,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Not supported by this provider.

        Raises:
            NotImplementedError: Always; points to providers with rerank.
        """
        raise NotImplementedError(
            "OpenAI does not provide a rerank API. "
            "Use a provider with native rerank support "
            "(e.g. DashScope with gte-rerank-v2, Cohere, Jina)."
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = None,
        temperature: float = None,
        max_tokens: int = None,
        top_p: float = None,
        frequency_penalty: float = None,
        presence_penalty: float = None,
        response_format: Dict = None,
        **kwargs
    ) -> str:
        """Call the OpenAI Chat Completions API.

        Args:
            messages: Message list.
            model: Model name; defaults to ``self.chat_model``.
            temperature: Sampling temperature [0, 2].
            max_tokens: Maximum number of generated tokens.
            top_p: Nucleus sampling probability threshold.
            frequency_penalty: Frequency penalty [-2.0, 2.0].
            presence_penalty: Presence penalty [-2.0, 2.0].
            response_format: Output format, e.g. ``{"type": "json_object"}``.
            **kwargs: Extra parameters (e.g. tools, seed, stop). Applied last,
                so they override the named parameters on key collisions.

        Returns:
            The text of the first choice. If that message carries no text
            content (``content`` is ``None``), an empty string is returned so
            the declared ``str`` return type holds.
        """
        params = {
            "model": model or self.chat_model,
            "messages": messages,
        }
        # Only forward sampling options the caller actually set.
        optional = (
            ("temperature", temperature),
            ("max_tokens", max_tokens),
            ("top_p", top_p),
            ("frequency_penalty", frequency_penalty),
            ("presence_penalty", presence_penalty),
            ("response_format", response_format),
        )
        params.update((key, value) for key, value in optional if value is not None)
        params.update(kwargs)
        response = self.client.chat.completions.create(**params)
        content = response.choices[0].message.content
        return content if content is not None else ""
class DashScopeModel(BaseModel):
    """Alibaba Cloud DashScope (Qwen) model implementation.

    Supports:
    - Chat: qwen-max, qwen-plus, qwen-turbo, ... (Generation API)
    - Embed: text-embedding-v3 (TextEmbedding API)
    - Rerank: qwen3-rerank, gte-rerank-v2 (TextReRank API)

    Docs:
    - Chat: https://help.aliyun.com/zh/model-studio/qwen-api-via-dashscope
    - Rerank: https://help.aliyun.com/zh/model-studio/text-rerank-api

    Example:
        >>> model = DashScopeModel(api_key="sk-xxx")
        >>> response = model.chat([{"role": "user", "content": "你好"}])
        >>> embeddings = model.embed(["Hello, world!"])
        >>> results = model.rerank("查询", ["文档1", "文档2"])
    """

    @property
    def provider(self) -> str:
        """Return the provider name, ``"dashscope"``."""
        return "dashscope"

    def _get_api_key_env_name(self) -> str:
        return "DASHSCOPE_API_KEY"

    def _get_default_base_url(self) -> str:
        return "https://dashscope.aliyuncs.com/api/v1"

    def _get_default_model(self) -> str:
        return "qwen-plus"

    def _get_default_embed_model(self) -> str:
        return "text-embedding-v3"

    def _get_default_rerank_model(self) -> str:
        return "gte-rerank-v2"

    @staticmethod
    def _ensure_dashscope():
        """Raise a helpful ImportError if the dashscope package is missing."""
        try:
            import dashscope  # noqa: F401 -- imported only to probe availability
        except ImportError as exc:
            raise ImportError(
                "DashScopeModel requires dashscope package. "
                "Install with: pip/pip3 install -U dashscope"
            ) from exc

    @staticmethod
    def _extract_document_text(result) -> Optional[str]:
        """Return the document text carried by one rerank result, if any.

        Uses subscript access only (like the ``index`` lookup) and treats a
        missing or ``None`` document as "no text".
        """
        try:
            document = result["document"]
        except (KeyError, TypeError):
            return None
        if isinstance(document, str):
            return document
        try:
            return document["text"]
        except (KeyError, TypeError):
            return None

    def embed(
        self,
        texts: Union[str, List[str]],
        model: str = None,
        text_type: str = "document",
        batch_size: int = 10,
        **kwargs
    ) -> List[List[float]]:
        """Embed text with the DashScope TextEmbedding API.

        Args:
            texts: A single text or a list of texts.
            model: Model name; defaults to ``self.embed_model``.
            text_type: Text type, ``"query"`` or ``"document"``.
            batch_size: Maximum number of texts sent per request. The default
                of 10 is meant to match the documented per-request row limit
                of text-embedding-v3 -- NOTE(review): confirm against the
                current API reference for the model in use.
            **kwargs: Extra parameters forwarded to ``TextEmbedding.call``.

        Returns:
            Embedding vectors, batches concatenated in input order. An empty
            input list returns ``[]`` without calling the API.

        Raises:
            ImportError: If the dashscope package is not installed.
            ValueError: If ``batch_size`` is smaller than 1.
            RuntimeError: If any request returns a non-200 status.
        """
        self._ensure_dashscope()
        from dashscope import TextEmbedding

        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        vectors: List[List[float]] = []
        # Split into batches so long inputs do not exceed the per-request limit.
        for start in range(0, len(texts), batch_size):
            response = TextEmbedding.call(
                api_key=self.api_key,
                model=model or self.embed_model,
                input=texts[start:start + batch_size],
                text_type=text_type,
                **kwargs
            )
            if response.status_code != 200:
                raise RuntimeError(
                    f"DashScope embedding failed: {response.code} - {response.message}"
                )
            vectors.extend(
                item['embedding'] for item in response.output['embeddings']
            )
        return vectors

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: int = None,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Re-rank documents with the DashScope TextReRank API.

        Docs: https://help.aliyun.com/zh/model-studio/text-rerank-api

        Args:
            query: Query text.
            documents: Documents to rank.
            top_n: Maximum number of results to return.
            **kwargs: Extra parameters (e.g. return_documents, instruct). A
                ``model`` key overrides ``self.rerank_model``.

        Returns:
            Results in the order returned by the API. Each element has
            ``index`` and ``score``; ``text`` is added when the response
            carries the document (e.g. with ``return_documents=True``).
            An empty ``documents`` list returns ``[]`` without calling the API.

        Raises:
            ImportError: If the dashscope package is not installed.
            RuntimeError: If the request returns a non-200 status.
        """
        self._ensure_dashscope()
        from dashscope import TextReRank

        if not documents:
            return []

        model = kwargs.pop("model", None) or self.rerank_model
        params = dict(
            api_key=self.api_key,
            model=model,
            query=query,
            documents=documents,
            **kwargs
        )
        if top_n is not None:
            params["top_n"] = top_n
        response = TextReRank.call(**params)
        if response.status_code != 200:
            raise RuntimeError(
                f"DashScope rerank failed: {response.code} - {response.message}"
            )

        ranked = []
        for r in response.output["results"]:
            entry = {
                "index": r["index"],
                "score": r["relevance_score"],
            }
            text = self._extract_document_text(r)
            if text is not None:
                entry["text"] = text
            ranked.append(entry)
        return ranked

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = None,
        temperature: float = None,
        top_p: float = None,
        top_k: int = None,
        max_tokens: int = None,
        repetition_penalty: float = None,
        enable_search: bool = False,
        stream: bool = False,
        result_format: str = "message",
        **kwargs
    ) -> str:
        """Call the DashScope Generation API.

        Docs: https://help.aliyun.com/zh/model-studio/qwen-api-via-dashscope

        Args:
            messages: Message list shaped like
                ``[{"role": "user", "content": "..."}]``.
            model: Model name; defaults to ``self.chat_model``.
            temperature: Sampling temperature [0, 2); controls randomness.
            top_p: Nucleus sampling probability threshold (0, 1.0].
            top_k: Number of candidate tokens.
            max_tokens: Maximum number of generated tokens.
            repetition_penalty: Repetition penalty (> 0).
            enable_search: Whether to enable web search.
            stream: Whether to stream; ``incremental_output`` is then always
                enabled because the chunks are concatenated.
            result_format: Response format, ``"message"`` or ``"text"``.
            **kwargs: Extra parameters (e.g. enable_thinking, tools).

        Returns:
            The generated text (the full concatenation when streaming).

        Raises:
            ImportError: If the dashscope package is not installed.
            RuntimeError: If the API reports a non-200 status.
        """
        self._ensure_dashscope()

        params = {
            "api_key": self.api_key,
            "model": model or self.chat_model,
            "messages": messages,
            "result_format": result_format,
        }
        # Only forward sampling options the caller actually set.
        optional = (
            ("temperature", temperature),
            ("top_p", top_p),
            ("top_k", top_k),
            ("max_tokens", max_tokens),
            ("repetition_penalty", repetition_penalty),
        )
        params.update((key, value) for key, value in optional if value is not None)
        if enable_search:
            params["enable_search"] = True
        params.update(kwargs)

        if stream:
            params["stream"] = True
            # Set after kwargs: _chat_stream concatenates chunks, so cumulative
            # (non-incremental) chunks would duplicate text in the result.
            params["incremental_output"] = True
            return self._chat_stream(params, result_format)
        return self._chat_sync(params, result_format)

    def _chat_sync(self, params: Dict, result_format: str) -> str:
        """Call the Generation API once and return the generated text."""
        from dashscope import Generation

        response = Generation.call(**params)
        if response.status_code != 200:
            raise RuntimeError(
                f"DashScope chat failed: {response.code} - {response.message}"
            )
        if result_format == "message":
            return response.output.choices[0].message.content
        return response.output.text

    @staticmethod
    def _chat_stream(params: Dict, result_format: str) -> str:
        """Call the Generation API in streaming mode and join all chunks."""
        from dashscope import Generation

        pieces = []
        for chunk in Generation.call(**params):
            if chunk.status_code != 200:
                raise RuntimeError(
                    f"DashScope chat failed: {chunk.code} - {chunk.message}"
                )
            if result_format == "message":
                content = chunk.output.choices[0].message.content
            else:
                content = chunk.output.text
            if content:
                pieces.append(content)
        # Join once instead of growing a string inside the loop.
        return "".join(pieces)
class OllamaModel(BaseModel):
    """Ollama (locally deployed) model implementation.

    Runs local models through an Ollama server:
    - Chat: llama3, qwen2, mistral, ... (/api/chat)
    - Embed: nomic-embed-text, mxbai-embed-large, ... (/api/embed)

    Rerank is not available; ``rerank`` always raises NotImplementedError.
    Requests use ``self.timeout``; ``self.max_retries`` is not used by this
    class.

    Docs: https://docs.ollama.com/api

    Example:
        >>> model = OllamaModel(chat_model="llama3")
        >>> response = model.chat([{"role": "user", "content": "Hi"}])
        >>> embeddings = model.embed(["Hello", "World"])
    """

    @property
    def provider(self) -> str:
        """Return the provider name, ``"ollama"``."""
        return "ollama"

    def _get_default_base_url(self) -> str:
        return "http://localhost:11434"

    def _get_default_model(self) -> str:
        return "llama3"

    def _get_default_embed_model(self) -> str:
        return "nomic-embed-text"

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST a JSON payload to ``base_url + path`` and return parsed JSON.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            requests.HTTPError: If the server answers with an error status.
        """
        try:
            import requests
        except ImportError as exc:
            raise ImportError(
                "OllamaModel requires requests package. "
                "Install with: pip install requests"
            ) from exc
        # Drop trailing slashes so a base URL like "http://host:11434/" does
        # not produce "//api/..." in the request path.
        url = f"{self.base_url.rstrip('/')}{path}"
        response = requests.post(url, json=payload, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def embed(
        self,
        texts: Union[str, List[str]],
        model: str = None,
        **kwargs
    ) -> List[List[float]]:
        """Embed text with the Ollama Embed API (batched in one request).

        Docs: https://docs.ollama.com/api/embed

        Args:
            texts: A single text or a list of texts.
            model: Model name; defaults to ``self.embed_model``.
            **kwargs: Extra top-level request fields
                (e.g. truncate, dimensions, keep_alive).

        Returns:
            The ``embeddings`` list from the response. An empty input list
            returns ``[]`` without calling the server.
        """
        if isinstance(texts, str):
            texts = [texts]
        if not texts:
            return []
        payload = {
            "model": model or self.embed_model,
            "input": texts,
            **kwargs
        }
        return self._post("/api/embed", payload)["embeddings"]

    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: int = None,
        **kwargs
    ) -> List[Dict[str, Any]]:
        """Not supported by this provider.

        Raises:
            NotImplementedError: Always; points to providers with rerank.
        """
        raise NotImplementedError(
            "Ollama does not provide a rerank API. "
            "Use a provider with native rerank support "
            "(e.g. DashScope with gte-rerank-v2, Cohere, Jina)."
        )

    def chat(
        self,
        messages: List[Dict[str, str]],
        model: str = None,
        temperature: float = None,
        top_p: float = None,
        top_k: int = None,
        max_tokens: int = None,
        repetition_penalty: float = None,
        output_format: str = None,
        keep_alive: str = None,
        **kwargs
    ) -> str:
        """Call the Ollama Chat API (non-streaming).

        Docs: https://docs.ollama.com/api/chat

        Args:
            messages: Message list shaped like
                ``[{"role": "user", "content": "..."}]``.
            model: Model name; defaults to ``self.chat_model``.
            temperature: Sampling temperature.
            top_p: Nucleus sampling probability threshold.
            top_k: Number of candidate tokens.
            max_tokens: Maximum generated tokens (sent as ``num_predict``).
            repetition_penalty: Repetition penalty (sent as ``repeat_penalty``).
            output_format: Output format, e.g. ``"json"`` (sent as ``format``).
            keep_alive: How long the model stays loaded, e.g. ``"5m"``.
            **kwargs: Extra entries for the ``options`` object
                (e.g. seed, num_ctx); they override the named options.

        Returns:
            The ``message.content`` field of the response.
        """
        payload = {
            "model": model or self.chat_model,
            "messages": messages,
            "stream": False,
        }
        if output_format is not None:
            payload["format"] = output_format
        if keep_alive is not None:
            payload["keep_alive"] = keep_alive

        # Map the generic parameter names onto Ollama's option names and keep
        # only the ones the caller actually set.
        named_options = (
            ("temperature", temperature),
            ("top_p", top_p),
            ("top_k", top_k),
            ("num_predict", max_tokens),
            ("repeat_penalty", repetition_penalty),
        )
        options = {key: value for key, value in named_options if value is not None}
        options.update(kwargs)
        if options:
            payload["options"] = options

        return self._post("/api/chat", payload)["message"]["content"]
# Registry mapping provider name -> model class. create_model() looks names up
# here and register_provider() adds or replaces entries.
_PROVIDERS: Dict[str, type] = {
"openai": OpenAIModel,
"dashscope": DashScopeModel,
"ollama": OllamaModel,
}
def register_provider(name: str, cls: type):
    """Register a custom model provider.

    An existing entry with the same name is replaced.

    Args:
        name: Provider name.
        cls: Model class; must be a subclass of BaseModel.

    Raises:
        TypeError: If ``cls`` is not a class derived from BaseModel.

    Example:
        >>> class MyModel(BaseModel): ...
        >>> register_provider("my_provider", MyModel)
        >>> model = create_model("my_provider", api_key="xxx")
    """
    # isinstance is checked first so issubclass never sees a non-class.
    is_model_class = isinstance(cls, type) and issubclass(cls, BaseModel)
    if not is_model_class:
        raise TypeError(f"cls must be a subclass of BaseModel, got {cls}")
    _PROVIDERS[name] = cls
def list_providers() -> List[str]:
    """Return the names of all registered model providers as a new list."""
    return [provider_name for provider_name in _PROVIDERS]
def create_model(
    provider: str,
    **kwargs
) -> BaseModel:
    """Factory that builds a model instance for a registered provider.

    Explicit keyword arguments take precedence; BaseModel.__init__ falls back
    to an environment variable for ``api_key`` and to the subclass defaults
    for the remaining settings.

    Args:
        provider: Provider name. Built-in names:
            - openai: OpenAI GPT family
            - dashscope: Alibaba Cloud Qwen family
            - ollama: local Ollama deployment
        **kwargs: Passed to the model constructor
            (api_key, base_url, chat_model, embed_model, rerank_model, ...).

    Returns:
        A BaseModel instance.

    Raises:
        ValueError: If ``provider`` is not registered.

    Example:
        >>> model = create_model("openai", api_key="sk-xxx")
        >>> model = create_model("dashscope", api_key="sk-xxx")
        >>> model = create_model("ollama", chat_model="llama3")
    """
    if provider in _PROVIDERS:
        model_cls = _PROVIDERS[provider]
        return model_cls(**kwargs)
    raise ValueError(
        f"Unknown provider: {provider}. "
        f"Available: {list(_PROVIDERS.keys())}"
    )