from typing import List
import pytest
from common.utils.openai import OpenaiClient
def test_check_openai(admin_client, setup_sys_config):
    """Check that ``/check/openai/`` reports whether OpenAI is configured.

    With only ``openai_base_url`` set, the endpoint is expected to answer
    with ``data`` false; once ``openai_api_key`` is set as well, it is
    expected to answer with ``data`` true.
    """
    setup_sys_config.set("openai_base_url", "https://api.openai.com")
    response = admin_client.get("/check/openai/")
    assert response.status_code == 200
    # Identity check: the payload must be a real JSON boolean. ``== False``
    # would also accept other falsy-equal values such as 0.
    assert response.json()["data"] is False

    setup_sys_config.set("openai_api_key", "sk-xxxx")
    response = admin_client.get("/check/openai/")
    assert response.status_code == 200
    assert response.json()["data"] is True
@pytest.fixture
def openai_client(setup_sys_config):
    """Yield an ``OpenaiClient`` built from a known test configuration.

    The base URL, API key and default chat model are written through the
    ``setup_sys_config`` fixture before the client is constructed. On
    teardown the wrapped ``client`` is closed so that tests which never
    close it themselves do not leave it open.
    """
    # Seed the system configuration through the setup_sys_config fixture.
    setup_sys_config.set("openai_base_url", "https://api.openai.com")
    setup_sys_config.set("openai_api_key", "sk-xxxx")
    setup_sys_config.set("default_chat_model", "gpt-3.5-turbo")
    client = OpenaiClient()
    yield client
    # Teardown: pytest resumes the generator here after the test finishes.
    # NOTE(review): test_init also calls client.client.close(); confirm that
    # a second close() on the underlying client is a no-op.
    client.client.close()
def test_init(openai_client):
    """Check that ``OpenaiClient`` exposes the configured values as attributes."""
    try:
        assert openai_client.base_url == "https://api.openai.com"
        assert openai_client.api_key == "sk-xxxx"
        assert openai_client.default_chat_model == "gpt-3.5-turbo"
    finally:
        # Close the wrapped client even when an assertion above fails, so a
        # failing run does not leave it open.
        openai_client.client.close()
def test_request_chat_completion(openai_client, mocker):
    """Check that ``request_chat_completion`` returns what ``create`` returns.

    ``chat.completions.create`` on the wrapped client is patched to return
    a canned payload; the wrapper's result must equal that payload.
    """
    canned_completion = {
        "id": "cmpl-123",
        "object": "text_completion",
        "created": 1234567890,
        "choices": [{"message": {"content": "SELECT * FROM table"}}],
    }
    completions_api = openai_client.client.chat.completions
    mocker.patch.object(completions_api, "create", return_value=canned_completion)

    conversation = [{"role": "user", "content": "test message"}]
    assert (
        openai_client.request_chat_completion(messages=conversation)
        == canned_completion
    )
class ChatCompletionMessage:
    """Minimal test stand-in for a chat message; exposes only ``content``."""

    def __init__(self, content) -> None:
        # Stored unchanged; the tests in this file pass a SQL string here.
        self.content = content
class Choice:
    """Minimal test stand-in for one completion choice; exposes ``message``."""

    def __init__(self, message: ChatCompletionMessage) -> None:
        # Kept by reference so ``choice.message.content`` can be read back.
        self.message = message
class ChatCompletion:
    """Minimal test stand-in for a completion response; exposes ``choices``."""

    def __init__(self, choices: List[Choice]) -> None:
        # The list is stored as given, without copying.
        self.choices = choices
def test_generate_sql_by_openai(openai_client, mocker):
    """Cover the success and failure paths of ``generate_sql_by_openai``.

    ``request_chat_completion`` is patched on the client in both phases, so
    no real request is issued by this test.
    """
    expected_sql = "SELECT * FROM table"
    # Positional arguments: db_type, table_schema, query_desc.
    call_args = ("MySQL", "table_schema_description", "query_description")

    # Success path: the content of the first choice's message is returned.
    stub_message = ChatCompletionMessage(content=expected_sql)
    stub_completion = ChatCompletion(choices=[Choice(message=stub_message)])
    mocker.patch.object(
        openai_client, "request_chat_completion", return_value=stub_completion
    )
    assert openai_client.generate_sql_by_openai(*call_args) == expected_sql

    # Failure path: an error from the request must surface as a ValueError
    # carrying the expected message.
    mocker.patch.object(
        openai_client, "request_chat_completion", side_effect=ValueError("API Error")
    )
    with pytest.raises(ValueError) as excinfo:
        openai_client.generate_sql_by_openai(*call_args)
    # Checked after the ``with`` block so that the assertion actually runs.
    assert str(excinfo.value) == "请求openai生成查询语句失败: API Error"
@pytest.mark.parametrize(
    ("data", "expected_msg"),
    [
        # Empty payload: neither query_desc nor db_type is supplied.
        ({}, "query_desc or db_type不存在"),
        # db_type is present but empty.
        (
            {
                "db_type": "",
                "query_desc": "获取所有用户名为test的记录",
                "instance_name": "some_ins",
            },
            "query_desc or db_type不存在",
        ),
        # Required fields are filled in; the expected message concerns the
        # named instance.
        (
            {
                "db_type": "MySQL",
                "query_desc": "获取所有用户名为test的记录",
                "instance_name": "test_instance",
            },
            "实例不存在",
        ),
    ],
)
def test_generate_sql(admin_client, db_instance, data, expected_msg):
    """Check the ``msg`` returned by ``/query/generate_sql/`` for each payload.

    Every case expects HTTP 200 and compares the ``msg`` field of the JSON
    body with ``expected_msg``.
    """
    # ``db_instance`` is requested but not referenced in the body.
    # NOTE(review): presumably needed for its setup side effects - confirm.
    resp = admin_client.post("/query/generate_sql/", data=data)
    assert resp.status_code == 200
    body = resp.json()
    assert body["msg"] == expected_msg