"""Ollama大模型服务"""
import logging
from collections.abc import AsyncGenerator, Sequence
from typing import Any
from ollama import AsyncClient, ChatResponse
from typing_extensions import override
from apps.llm.token import token_calculator
from apps.models import LLMType
from apps.schemas.llm import LLMChunk, LLMFunctions
from .base import BaseProvider
# Module-level logger; records are emitted under this module's dotted name.
_logger = logging.getLogger(__name__)
class OllamaProvider(BaseProvider):
    """LLM provider that talks to an Ollama server through ``ollama.AsyncClient``.

    Supports chat (streaming and non-streaming), tool/function calling and
    text embedding. Vision models are rejected in ``_check_type``.
    """

    @override
    def _check_type(self) -> None:
        """Validate the configured model types and derive the capability flags.

        Sets ``_allow_chat``, ``_allow_function`` and ``_allow_embedding``
        according to which ``LLMType`` members appear in ``config.llmType``.

        Raises:
            RuntimeError: If ``LLMType.VISION`` is configured; vision models
                are not supported by this provider.
        """
        llm_types = self.config.llmType
        if LLMType.VISION in llm_types:
            err = "[OllamaProvider] 当前暂不支持视觉模型"
            _logger.error(err)
            raise RuntimeError(err)

        self._allow_chat = LLMType.CHAT in llm_types
        self._allow_function = LLMType.FUNCTION in llm_types
        self._allow_embedding = LLMType.EMBEDDING in llm_types

    @override
    def _init_client(self) -> None:
        """Create the ``AsyncClient`` used for all requests.

        An ``Authorization: Bearer <apiKey>`` header is attached only when an
        API key is configured.
        """
        client_kwargs: dict[str, Any] = {
            "host": self.config.baseUrl,
            "timeout": self._timeout,
            # NOTE(review): TLS certificate verification is disabled here, which
            # exposes HTTPS endpoints to man-in-the-middle attacks. Kept as-is to
            # preserve behaviour (self-signed deployments presumably rely on it);
            # consider making this configurable.
            "verify": False,
        }
        if self.config.apiKey:
            client_kwargs["headers"] = {
                "Authorization": f"Bearer {self.config.apiKey}",
            }
        self._client = AsyncClient(**client_kwargs)

    def _process_usage_data(self, last_chunk: ChatResponse | None, messages: list[dict[str, str]]) -> None:
        """Accumulate token usage for the request that just finished.

        Args:
            last_chunk: The final response object (or last streamed chunk), or
                ``None`` when no usable response is available.
            messages: The messages that were sent; used only for the local
                estimate when the server statistics are unavailable.

        When ``last_chunk`` is present and marked ``done``, the counts reported
        by Ollama are added to the running totals. Otherwise the totals are
        estimated locally with ``token_calculator``.
        """
        if last_chunk and last_chunk.done:
            # Either counter may be missing from the response; treat as zero.
            self.input_tokens += last_chunk.prompt_eval_count or 0
            self.output_tokens += last_chunk.eval_count or 0
            _logger.info(
                "[OllamaProvider] 使用Ollama统计数据: input_tokens=%s, output_tokens=%s",
                self.input_tokens, self.output_tokens,
            )
            return

        _logger.warning("[OllamaProvider] 无法获取Ollama统计数据,使用本地估算逻辑")
        self.input_tokens += token_calculator.calculate_token_length(messages)
        # The reasoning text is counted as output too, wrapped in <think> tags.
        self.output_tokens += token_calculator.calculate_token_length([{
            "role": "assistant",
            "content": "<think>" + self.full_thinking + "</think>" + self.full_answer,
        }])

    def _build_tools_param(self, tools: list[LLMFunctions]) -> list[dict[str, Any]]:
        """Convert ``LLMFunctions`` definitions into the ``tools`` request payload.

        Args:
            tools: Function definitions exposing ``name``, ``description`` and
                ``param_schema``.

        Returns:
            One ``{"type": "function", "function": {...}}`` dict per tool.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.param_schema,
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _collect_tool_calls(message: Any) -> dict[str, Any]:
        """Map each tool call on ``message`` from function name to its arguments.

        Returns an empty dict when the message carries no tool calls. Calls that
        share a function name collapse to the last one, because the result is
        keyed by name.
        """
        return {
            call.function.name: call.function.arguments
            for call in (message.tool_calls or [])
        }

    async def _handle_streaming_response(
        self, response: AsyncGenerator[ChatResponse, None],
        messages: list[dict[str, str]],
        tools: list[LLMFunctions] | None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Translate a streamed Ollama response into ``LLMChunk`` objects.

        Args:
            response: Async stream of partial ``ChatResponse`` objects.
            messages: The messages that were sent (for usage estimation).
            tools: The tools offered to the model; tool calls are only
                forwarded when this is non-empty.

        Yields:
            Separate chunks for reasoning text, answer text and tool calls, in
            that order for each received chunk.

        Usage accounting runs after the stream is exhausted, so it is skipped if
        the consumer stops iterating early.
        """
        last_chunk = None
        async for chunk in response:
            last_chunk = chunk
            message = chunk.message

            if message.thinking:
                self.full_thinking += message.thinking
                yield LLMChunk(reasoning_content=message.thinking)

            if message.content:
                self.full_answer += message.content
                yield LLMChunk(content=message.content)

            if tools:
                tool_calls = self._collect_tool_calls(message)
                if tool_calls:
                    yield LLMChunk(tool_call=tool_calls)

        self._process_usage_data(last_chunk, messages)

    async def _handle_non_streaming_response(
        self, response: ChatResponse,
        messages: list[dict[str, str]],
        tools: list[LLMFunctions] | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Translate a complete (non-streamed) Ollama response into one ``LLMChunk``.

        Args:
            response: The full chat response.
            messages: The messages that were sent (for usage estimation).
            tools: The tools offered to the model. When non-empty and the
                response contains tool calls, they are attached to the yielded
                chunk as ``tool_call``.

        Yields:
            Exactly one chunk carrying the answer, the reasoning text and, when
            present, the tool calls.
        """
        message = response.message
        if message.thinking:
            self.full_thinking = message.thinking
        if message.content:
            self.full_answer = message.content

        self._process_usage_data(response if response.done else None, messages)

        chunk_fields: dict[str, Any] = {
            "content": self.full_answer,
            "reasoning_content": self.full_thinking,
        }
        # Previously tool calls were silently dropped in non-streaming mode even
        # though the request advertised tools; forward them like the streaming
        # path does. The key is only set when there is something to report.
        if tools:
            tool_calls = self._collect_tool_calls(message)
            if tool_calls:
                chunk_fields["tool_call"] = tool_calls
        yield LLMChunk(**chunk_fields)

    def _build_chat_kwargs(
        self, messages: list[dict[str, str]],
        *,
        streaming: bool,
        tools: list[LLMFunctions] | None,
        temperature: float = 0.7,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for ``AsyncClient.chat``.

        Args:
            messages: Conversation messages to send.
            streaming: Whether to request a streamed response.
            tools: Optional tools to advertise to the model.
            temperature: Sampling temperature placed in ``options``.

        Returns:
            The keyword-argument dict. Entries from ``config.extraConfig`` are
            merged after ``messages``/``options``/``stream`` and may override
            them; ``model`` and ``tools`` are set afterwards and win.
        """
        chat_kwargs: dict[str, Any] = {
            "messages": messages,
            "options": {
                "temperature": temperature,
                "num_predict": self.config.maxToken,
            },
            "stream": streaming,
            # Tolerate an unset extraConfig instead of failing on ``**None``.
            **(self.config.extraConfig or {}),
        }
        if self.config.modelName:
            chat_kwargs["model"] = self.config.modelName
        if tools:
            chat_kwargs["tools"] = self._build_tools_param(tools)
        return chat_kwargs

    @override
    async def chat(
        self, messages: list[dict[str, str]],
        tools: list[LLMFunctions] | None = None,
        *, include_thinking: bool = False,
        streaming: bool = True,
        temperature: float = 0.7,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Run a chat completion and yield the result as ``LLMChunk`` objects.

        Args:
            messages: Conversation messages; passed through
                ``_validate_messages`` before sending.
            tools: Optional tools the model may call.
            include_thinking: Accepted for interface compatibility.
                NOTE(review): currently unused by this provider — reasoning
                text is yielded whenever the server returns it.
            streaming: Stream partial chunks when true; otherwise yield a
                single chunk with the complete response.
            temperature: Sampling temperature.

        Yields:
            ``LLMChunk`` objects with answer text, reasoning text and/or tool
            calls.

        Raises:
            RuntimeError: If the configured model does not allow chat.

        ``input_tokens``/``output_tokens`` are created on first use and keep
        accumulating across calls; ``full_thinking``/``full_answer`` are reset
        on every call.
        """
        if not self._allow_chat:
            err = "[OllamaProvider] 当前模型不支持Chat"
            _logger.error(err)
            raise RuntimeError(err)

        # Lazily create the running totals without resetting existing values.
        if not hasattr(self, "input_tokens"):
            self.input_tokens = 0
        if not hasattr(self, "output_tokens"):
            self.output_tokens = 0

        self.full_thinking = ""
        self.full_answer = ""

        messages = self._validate_messages(messages)
        chat_kwargs = self._build_chat_kwargs(messages, streaming=streaming, tools=tools, temperature=temperature)
        response = await self._client.chat(**chat_kwargs)

        if streaming:
            chunks = self._handle_streaming_response(response, messages, tools)
        else:
            chunks = self._handle_non_streaming_response(response, messages, tools)
        async for chunk in chunks:
            yield chunk

    def _seq_to_list(self, seq: Sequence[Any]) -> list[Any]:
        """Recursively convert a (possibly nested) sequence into plain lists.

        Nested ``list``/``tuple``/``range`` items are converted to lists and
        kept as nested elements, so the shape of the input is preserved
        (``[(1, 2), (3, 4)]`` becomes ``[[1, 2], [3, 4]]``). The earlier
        implementation used ``extend`` here, which flattened every embedding
        into one long vector and broke the ``list[list[float]]`` contract of
        ``embedding``.

        Args:
            seq: The sequence to convert.

        Returns:
            A new list mirroring the structure of ``seq``.
        """
        return [
            self._seq_to_list(item) if isinstance(item, (list, tuple, range)) else item
            for item in seq
        ]

    @override
    async def embedding(self, text: list[str]) -> list[list[float]]:
        """Embed each input text into a vector.

        Args:
            text: The texts to embed.

        Returns:
            One embedding vector per input text, as plain nested lists.

        Raises:
            RuntimeError: If the configured model does not allow embedding.
        """
        if not self._allow_embedding:
            err = "[OllamaProvider] 当前模型不支持Embedding"
            _logger.error(err)
            raise RuntimeError(err)

        embed_kwargs: dict[str, Any] = {"input": text}
        if self.config.modelName:
            embed_kwargs["model"] = self.config.modelName

        result = await self._client.embed(**embed_kwargs)
        return self._seq_to_list(result.embeddings)