"""大模型提供商:OpenAI"""
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any, cast
import httpx
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
ChatCompletionMessageParam,
)
from typing_extensions import override
from apps.llm.token import token_calculator
from apps.models import LLMType
from apps.schemas.llm import LLMChunk, LLMFunctions, LLMToolCall
from .base import BaseProvider
# Module-level logger, named after this module, used by OpenAIProvider below.
_logger = logging.getLogger(__name__)
class OpenAIProvider(BaseProvider):
    """LLM provider that talks to an OpenAI-compatible API.

    Supports chat completions (streaming and non-streaming, with optional
    tool definitions) and embeddings. Token usage is accumulated on
    ``input_tokens`` / ``output_tokens`` across calls to :meth:`chat`.
    """

    _client: AsyncOpenAI
    _http_client: httpx.AsyncClient
    input_tokens: int
    output_tokens: int
    _allow_chat: bool
    _allow_function: bool
    _allow_embedding: bool

    @override
    def _check_type(self) -> None:
        """Derive the capability flags from the configured model types.

        Raises:
            RuntimeError: If the configuration declares a vision model,
                which this provider does not support.
        """
        model_types = self.config.llmType
        if LLMType.VISION in model_types:
            err = "[OpenAIProvider] 当前暂不支持视觉模型"
            _logger.error(err)
            raise RuntimeError(err)

        self._allow_chat = LLMType.CHAT in model_types
        self._allow_function = LLMType.FUNCTION in model_types
        self._allow_embedding = LLMType.EMBEDDING in model_types

    @override
    def _init_client(self) -> None:
        """Create the HTTP client and the ``AsyncOpenAI`` API client."""
        # NOTE(security): verify=False disables TLS certificate verification.
        # Presumably deliberate for self-signed internal endpoints; confirm
        # before pointing this provider at a public service.
        self._http_client = httpx.AsyncClient(verify=False)

        client_kwargs: dict[str, Any] = {
            "base_url": self.config.baseUrl,
            "timeout": self._timeout,
            "http_client": self._http_client,
        }
        # Only pass api_key when one is configured, so the SDK's own default
        # key resolution still applies otherwise.
        if self.config.apiKey:
            client_kwargs["api_key"] = self.config.apiKey
        self._client = AsyncOpenAI(**client_kwargs)

    def _parse_tool_calls(self, message: ChatCompletionMessage) -> list[LLMToolCall]:
        """Convert the tool calls of a completion message into LLMToolCall objects.

        Only ``function`` tool calls are converted. A call whose arguments
        cannot be decoded is logged and skipped; an empty argument string is
        treated as "no arguments" instead of being dropped.

        Args:
            message: The assistant message of a non-streaming completion.

        Returns:
            The successfully parsed tool calls, possibly empty.
        """
        parsed: list[LLMToolCall] = []
        for tool_call in getattr(message, "tool_calls", None) or []:
            if tool_call.type != "function" or not hasattr(tool_call, "function"):
                continue
            func = tool_call.function
            try:
                raw_arguments = func.arguments
                if isinstance(raw_arguments, str):
                    # Some backends send "" for parameterless functions.
                    arguments = json.loads(raw_arguments) if raw_arguments.strip() else {}
                else:
                    arguments = raw_arguments
                parsed.append(LLMToolCall(
                    id=tool_call.id,
                    name=func.name,
                    arguments=arguments,
                ))
            except (json.JSONDecodeError, TypeError):
                _logger.warning(
                    "[OpenAIProvider] 工具调用参数解析失败: tool_call_id=%s, name=%s",
                    tool_call.id,
                    func.name,
                )
        return parsed

    def _estimate_usage_locally(self, messages: list[dict[str, str]]) -> None:
        """Add locally estimated token counts for the current call.

        The prompt is estimated from ``messages``; the completion from the
        accumulated thinking and answer text.
        """
        self.input_tokens += token_calculator.calculate_token_length(messages)
        self.output_tokens += token_calculator.calculate_token_length([{
            "role": "assistant",
            "content": "<think>" + self.full_thinking + "</think>" + self.full_answer,
        }])

    def _apply_usage_or_estimate(self, usage: Any, messages: list[dict[str, str]]) -> None:
        """Add server-reported usage to the counters, or estimate it locally.

        The decision is made per call from ``usage`` itself, never from the
        cumulative counters, so earlier calls cannot suppress the fallback
        and a partially filled report cannot be counted twice.

        Args:
            usage: The usage object of the response/chunk, or ``None``.
            messages: The request messages, used for the local estimate.
        """
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        completion_tokens = getattr(usage, "completion_tokens", None)
        reported = (
            isinstance(prompt_tokens, int)
            and isinstance(completion_tokens, int)
            # An all-zero report carries no information; treat it as missing.
            and (prompt_tokens > 0 or completion_tokens > 0)
        )
        if reported:
            self.input_tokens += prompt_tokens
            self.output_tokens += completion_tokens
            return

        _logger.warning("[OpenAIProvider] 推理框架未返回使用数据,使用本地估算逻辑")
        self._estimate_usage_locally(messages)

    def _handle_usage_chunk(self, chunk: ChatCompletionChunk | None, messages: list[dict[str, str]]) -> None:
        """Account for token usage after a streaming response.

        Args:
            chunk: The chunk expected to carry usage data, or ``None`` when
                no such chunk was received.
            messages: The request messages, used when usage must be estimated.
        """
        self._apply_usage_or_estimate(getattr(chunk, "usage", None), messages)

    def _handle_usage_response(self, response: ChatCompletion, messages: list[dict[str, str]]) -> None:
        """Account for token usage after a non-streaming response.

        Args:
            response: The completed chat completion.
            messages: The request messages, used when usage must be estimated.
        """
        self._apply_usage_or_estimate(getattr(response, "usage", None), messages)

    def _build_request_kwargs(
        self,
        messages: list[dict[str, str]],
        *,
        streaming: bool,
        temperature: float = 0.7,
    ) -> dict:
        """Build the keyword arguments for ``chat.completions.create``.

        Values from ``config.extraConfig`` may override ``max_tokens`` and
        ``temperature``, but never ``stream``: the caller picks the response
        handler from ``streaming``, so the two must agree.

        Args:
            messages: Validated chat messages.
            streaming: Whether a streaming response is requested.
            temperature: Sampling temperature.

        Returns:
            The request keyword arguments.
        """
        request_kwargs: dict[str, Any] = {
            "messages": self._convert_messages(messages),
            "max_tokens": self.config.maxToken,
            "temperature": temperature,
            # Tolerate a missing/None extraConfig.
            **(self.config.extraConfig or {}),
            "stream": streaming,
        }
        if self.config.modelName:
            request_kwargs["model"] = self.config.modelName
        if streaming:
            # Ask the server to append a usage-bearing chunk to the stream.
            request_kwargs["stream_options"] = {"include_usage": True}
        return request_kwargs

    def _add_tools_to_request(
        self,
        request_kwargs: dict,
        tools: list[LLMFunctions],
    ) -> None:
        """Attach tool definitions to the request, mutating ``request_kwargs``.

        Args:
            request_kwargs: The request arguments to extend in place.
            tools: The function definitions to expose to the model.
        """
        request_kwargs["tools"] = [
            {
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.param_schema,
                },
            }
            for tool in tools
        ]

    async def _process_streaming_chunk(
        self,
        chunk: ChatCompletionChunk,
        *,
        include_thinking: bool,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Turn one streamed chunk into zero, one or two LLMChunk objects.

        Reasoning text (if requested) and answer text are appended to
        ``full_thinking`` / ``full_answer`` and yielded as they arrive.

        NOTE(review): ``delta.tool_calls`` is not read here, so tool calls
        are not surfaced in streaming mode.

        Args:
            chunk: The chunk received from the stream.
            include_thinking: Whether to emit ``reasoning_content``.
        """
        choices = getattr(chunk, "choices", None)
        if not choices:
            # Usage-only chunks carry no choices.
            return

        delta = choices[0].delta
        reasoning_content = getattr(delta, "reasoning_content", None)
        if include_thinking and reasoning_content:
            self.full_thinking += reasoning_content
            yield LLMChunk(reasoning_content=reasoning_content)

        content = getattr(delta, "content", None)
        if content:
            self.full_answer += content
            yield LLMChunk(content=content)

    async def _handle_streaming_response(
        self,
        request_kwargs: dict,
        messages: list[dict[str, str]],
        *,
        include_thinking: bool,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Send a streaming request and yield its chunks as they arrive.

        Token usage is recorded once the stream has been fully consumed.

        Args:
            request_kwargs: Arguments for ``chat.completions.create``.
            messages: The request messages, used for usage estimation.
            include_thinking: Whether to emit ``reasoning_content``.
        """
        stream: AsyncStream[ChatCompletionChunk] = await self._client.chat.completions.create(**request_kwargs)
        usage_chunk: ChatCompletionChunk | None = None
        # Close the stream deterministically, even if the consumer stops early.
        async with stream:
            async for chunk in stream:
                # Keep the latest chunk that actually carries usage data.
                if getattr(chunk, "usage", None):
                    usage_chunk = chunk
                async for llm_chunk in self._process_streaming_chunk(
                    chunk,
                    include_thinking=include_thinking,
                ):
                    yield llm_chunk
        self._handle_usage_chunk(usage_chunk, messages)

    async def _handle_non_streaming_response(
        self,
        request_kwargs: dict,
        messages: list[dict[str, str]],
        tools: list[LLMFunctions] | None,
        *,
        include_thinking: bool,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Send a non-streaming request and yield a single LLMChunk.

        Args:
            request_kwargs: Arguments for ``chat.completions.create``.
            messages: The request messages, used for usage estimation.
            tools: The tools offered to the model; tool calls are only
                parsed when this is non-empty.
            include_thinking: Whether to keep ``reasoning_content``.
        """
        response: ChatCompletion = await self._client.chat.completions.create(**request_kwargs)
        tool_calls_list: list[LLMToolCall] = []

        if response.choices:
            message = response.choices[0].message
            reasoning_content = getattr(message, "reasoning_content", None)
            if include_thinking and reasoning_content:
                self.full_thinking = reasoning_content
            content = getattr(message, "content", None)
            if content:
                self.full_answer = content
            if tools:
                tool_calls_list = self._parse_tool_calls(message)

        self._handle_usage_response(response, messages)
        yield LLMChunk(
            content=self.full_answer,
            reasoning_content=self.full_thinking,
            tool_call=tool_calls_list or None,
        )

    @override
    async def chat(
        self, messages: list[dict[str, str]],
        tools: list[LLMFunctions] | None = None,
        *, include_thinking: bool = False,
        streaming: bool = True,
        temperature: float = 0.7,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Run a chat completion and yield the result as LLMChunk objects.

        Args:
            messages: The conversation so far.
            tools: Optional function definitions offered to the model.
            include_thinking: Whether to emit the model's reasoning text.
            streaming: Stream partial chunks (True) or yield one final chunk.
            temperature: Sampling temperature.

        Raises:
            RuntimeError: If the configured model does not support chat.
        """
        if not self._allow_chat:
            err = "[OpenAIProvider] 当前模型不支持Chat"
            _logger.error(err)
            raise RuntimeError(err)

        # Counters are cumulative; create them only on first use.
        if not hasattr(self, "input_tokens"):
            self.input_tokens = 0
        if not hasattr(self, "output_tokens"):
            self.output_tokens = 0
        # Per-call accumulators for the generated text.
        self.full_thinking = ""
        self.full_answer = ""

        messages = self._validate_messages(messages)
        request_kwargs = self._build_request_kwargs(messages, streaming=streaming, temperature=temperature)
        if tools:
            self._add_tools_to_request(request_kwargs, tools)

        if streaming:
            async for chunk in self._handle_streaming_response(
                request_kwargs,
                messages,
                include_thinking=include_thinking,
            ):
                yield chunk
        else:
            async for chunk in self._handle_non_streaming_response(
                request_kwargs,
                messages,
                tools,
                include_thinking=include_thinking,
            ):
                yield chunk

    @override
    async def embedding(self, text: list[str]) -> list[list[float]]:
        """Compute one embedding vector per input string.

        Args:
            text: The strings to embed.

        Returns:
            The embedding vectors, in the order returned by the API; an
            empty list when ``text`` is empty.

        Raises:
            RuntimeError: If the configured model does not support embeddings.
        """
        if not self._allow_embedding:
            err = "[OpenAIProvider] 当前模型不支持Embedding"
            _logger.error(err)
            raise RuntimeError(err)

        if not text:
            # Nothing to embed; skip the API call for an empty input list.
            return []

        embedding_kwargs: dict[str, Any] = {"input": text}
        if self.config.modelName:
            embedding_kwargs["model"] = self.config.modelName
        response = await self._client.embeddings.create(**embedding_kwargs)
        return [data.embedding for data in response.data]

    def _convert_messages(self, messages: list[dict[str, str]]) -> list[ChatCompletionMessageParam]:
        """Check roles and make sure every tool message has a ``tool_call_id``.

        A message without a ``role`` key is treated as a user message for
        validation. Tool messages lacking ``tool_call_id`` are copied with an
        empty id; all other messages are passed through unchanged.

        Args:
            messages: The chat messages to convert.

        Returns:
            The messages, typed as OpenAI message params.

        Raises:
            ValueError: If a message has a role other than system, user,
                assistant or tool.
        """
        result: list[dict[str, str]] = []
        for msg in messages:
            role = msg.get("role", "user")
            if role not in ("system", "user", "assistant", "tool"):
                err = f"[OpenAIProvider] 未知角色: {role}"
                _logger.error(err)
                raise ValueError(err)
            if role == "tool" and "tool_call_id" not in msg:
                result.append({**msg, "tool_call_id": ""})
            else:
                result.append(msg)
        return cast("list[ChatCompletionMessageParam]", result)