"""参数槽位管理"""
import copy
import json
import logging
import traceback
from collections.abc import Mapping
from typing import Any
from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError
from jsonschema.protocols import Validator
from jsonschema.validators import extend
from apps.scheduler.call.choice.schema import Type
from apps.scheduler.slot.parser import (
SlotConstParser,
SlotDateParser,
SlotTimestampParser,
)
from apps.scheduler.slot.util import escape_path, patch_json
from apps.schemas.response_data import ParamsNode
logger = logging.getLogger(__name__)

# Parsers registered as custom JSON Schema types: Slot._construct_validator
# feeds each parser's `.type` / `.type_validate` into the type checker.
_TYPE_CHECKER = [SlotDateParser, SlotTimestampParser]

# Custom format checkers (same `.type` / `.type_validate` shape); none yet.
_FORMAT_CHECKER = []

# Extra keyword validators, passed as `validators=` to jsonschema's `extend`.
_KEYWORD_CHECKER = {"const": SlotConstParser.keyword_validate}

# Parsers used by Slot.process_json: a value whose schema "type" equals the
# parser's `.name` is passed through the parser's `.convert`.
_TYPE_CONVERTER = [SlotDateParser, SlotTimestampParser]

# Keyword -> converter mapping.
# NOTE(review): not referenced by the code visible in this file section.
_KEYWORD_CONVERTER = {"const": SlotConstParser}
class Slot:
    """
    Parameter slot.

    (1) Checks the validity of the supplied JSON and JSON Schema.
    (2) Finds the JSON fields that do not satisfy the schema and extracts them
        into a flat JSON Schema for the front end to fill in.
    (3) Applies special handling (type conversion) to fields of custom types.
    """

    def __init__(self, schema: dict) -> None:
        """
        Initialize the slot processor.

        :param schema: JSON Schema describing the slot; checked with the
            extended Draft 7 validator built by ``_construct_validator``.
        :raises ValueError: if the validator class cannot be constructed, or if
            ``schema`` fails ``check_schema``.
        """
        try:
            self._validator_cls = Slot._construct_validator()
        except Exception as e:
            err = f"Invalid JSON Schema validator: {e!s}\n{traceback.format_exc()}"
            raise ValueError(err) from e

        self._json = {}
        try:
            self._validator_cls.check_schema(schema)
        except Exception as e:
            err = f"Invalid JSON Schema: {e!s}"
            raise ValueError(err) from e

        self._validator = self._validator_cls(schema)
        self._schema = schema

    @staticmethod
    def _construct_validator() -> type[Validator]:
        """
        Build a Draft 7 validator class extended with the custom types
        (``_TYPE_CHECKER``), formats (``_FORMAT_CHECKER``) and keyword
        validators (``_KEYWORD_CHECKER``).
        """
        type_checker = Draft7Validator.TYPE_CHECKER
        for checker in _TYPE_CHECKER:
            type_checker = type_checker.redefine(checker.type, checker.type_validate)

        # NOTE(review): `redefine` appears to be a TypeChecker API; jsonschema's
        # FormatChecker does not seem to offer it. This loop is dead while
        # _FORMAT_CHECKER is empty -- verify the registration call before
        # adding entries.
        format_checker = Draft7Validator.FORMAT_CHECKER
        for checker in _FORMAT_CHECKER:
            format_checker = format_checker.redefine(checker.type, checker.type_validate)

        return extend(
            Draft7Validator,
            type_checker=type_checker,
            format_checker=format_checker,
            validators=_KEYWORD_CHECKER,
        )

    def process_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]:
        """
        Convert the supplied JSON according to the slot schema.

        A value whose schema ``type`` equals the ``name`` of one of the
        ``_TYPE_CONVERTER`` parsers is passed through that parser's ``convert``
        (with ``spec[name]`` as keyword arguments when present). Everything else
        is copied through unchanged.

        :param json_data: JSON text or an already-decoded value.
        :return: the processed value.
        """
        if isinstance(json_data, str):
            json_data = json.loads(json_data)

        def _process_json_value(json_value: Any, spec_data: dict[str, Any]) -> Any:
            """
            Recursively process one part of the value.

            :param json_value: the (sub-)value being processed
            :param spec_data: the JSON Schema node describing ``json_value``
            :return: the processed (sub-)value
            """
            if "allOf" in spec_data:
                # Apply every subschema in turn. Merging independently processed
                # copies would let a later subschema overwrite conversions made
                # by an earlier one, and fails outright for non-object values.
                processed = json_value
                for item in spec_data["allOf"]:
                    processed = _process_json_value(processed, item)
                return processed

            for key in ("anyOf", "oneOf"):
                if key in spec_data:
                    for item in spec_data[key]:
                        processed = _process_json_value(json_value, item)
                        if processed is not None:
                            return processed

            if "type" in spec_data:
                spec_type = spec_data["type"]
                if spec_type == "array" and isinstance(json_value, list):
                    if "items" not in spec_data:
                        return json_value
                    return [_process_json_value(item, spec_data["items"]) for item in json_value]

                if spec_type == "object" and isinstance(json_value, dict):
                    if "properties" not in spec_data:
                        return json_value
                    properties = spec_data["properties"]
                    # Keys without a property schema are copied through as-is.
                    return {
                        key: _process_json_value(val, properties[key]) if key in properties else val
                        for key, val in json_value.items()
                    }

                for converter in _TYPE_CONVERTER:
                    if converter.name == spec_type:
                        if converter.name in spec_data:
                            return converter.convert(json_value, **spec_data[converter.name])
                        return converter.convert(json_value)

            return json_value

        return _process_json_value(json_data, self._schema)

    @staticmethod
    def _generate_example(schema_node: dict) -> Any:
        """
        Generate a placeholder value for ``schema_node``.

        Precedence:

        1. The first non-None example of ``anyOf`` (or ``oneOf`` when there is
           no ``anyOf``).
        2. The merge of the ``allOf`` examples (None unless all are dicts).
        3. A deep copy of ``default``.
        4. An empty value for the node's ``type``: an object of property
           placeholders, a one-element list, ``""``, ``0`` or ``False``.

        For a list of types, the first entry is used. Returns None when nothing
        applies. The returned value never aliases the schema.
        """
        if "anyOf" in schema_node or "oneOf" in schema_node:
            alternatives = schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]
            for item in alternatives:
                example = Slot._generate_example(item)
                if example is not None:
                    return example

        if "allOf" in schema_node:
            example = None
            for item in schema_node["allOf"]:
                if example is None:
                    example = Slot._generate_example(item)
                    continue
                other_example = Slot._generate_example(item)
                if isinstance(example, dict) and isinstance(other_example, dict):
                    example.update(other_example)
                else:
                    example = None
                    break
            return example

        if "default" in schema_node:
            # Copy so callers (and the allOf merge above) cannot mutate the schema.
            return copy.deepcopy(schema_node["default"])

        if "type" not in schema_node:
            return None

        type_value = schema_node["type"]
        if isinstance(type_value, list):
            # A one-element list such as ["string"] means the same as "string".
            type_value = type_value[0] if type_value else None

        if type_value == "object":
            return {
                name: Slot._generate_example(prop_schema)
                for name, prop_schema in schema_node.get("properties", {}).items()
            }
        if type_value == "array":
            return [Slot._generate_example(schema_node.get("items", {}))]
        if type_value == "string":
            return ""
        if type_value in ("number", "integer"):
            return 0
        if type_value == "boolean":
            return False
        return None

    def create_empty_slot(self) -> dict[str, Any]:
        """Create a placeholder value for the whole slot schema (see ``_generate_example``)."""
        return self._generate_example(self._schema)

    def _extract_type_desc(self, schema_node: dict[str, Any]) -> dict[str, Any]:
        """
        Recursively describe the types in ``schema_node``.

        Each description is ``{"type": ..., "description": ..., "items": {...}}``.
        ``items`` holds the child descriptions: object properties by name, and
        array / combinator / union members as ``item_<n>`` (or ``item`` for a
        single array item schema). For plain typed nodes, ``items`` is omitted
        when empty.
        """
        for key in ("anyOf", "allOf", "oneOf"):
            if key in schema_node:
                data = {
                    "type": key,
                    "description": schema_node.get("description", ""),
                    "items": {},
                }
                for type_index, item in enumerate(schema_node[key]):
                    if isinstance(item, dict):
                        data["items"][f"item_{type_index}"] = self._extract_type_desc(item)
                    else:
                        data["items"][f"item_{type_index}"] = {"type": item, "description": ""}
                return data

        type_val = schema_node.get("type", "")
        description = schema_node.get("description", "")
        if isinstance(type_val, list):
            if len(type_val) > 1:
                data = {"type": "union", "description": description, "items": {}}
                for type_index, t in enumerate(type_val):
                    if t == "object":
                        data["items"][f"item_{type_index}"] = {
                            key: self._extract_type_desc(val)
                            for key, val in schema_node.get("properties", {}).items()
                        }
                    elif t == "array":
                        items_schema = schema_node.get("items", {})
                        data["items"][f"item_{type_index}"] = self._extract_type_desc(items_schema)
                    else:
                        data["items"][f"item_{type_index}"] = {"type": t, "description": description}
                return data
            type_val = type_val[0] if len(type_val) == 1 else ""

        data = {"type": type_val, "description": description, "items": {}}
        if type_val == "object":
            for key, val in schema_node.get("properties", {}).items():
                data["items"][key] = self._extract_type_desc(val)
        elif type_val == "array":
            items_schema = schema_node.get("items", {})
            if isinstance(items_schema, list):
                for item_index, item in enumerate(items_schema):
                    data["items"][f"item_{item_index}"] = self._extract_type_desc(item)
            else:
                data["items"]["item"] = self._extract_type_desc(items_schema)
        if not data["items"]:
            del data["items"]
        return data

    def extract_type_desc_from_schema(self) -> dict[str, Any]:
        """Extract the nested type description of the slot schema (see ``_extract_type_desc``)."""
        return self._extract_type_desc(self._schema)

    def _extract_params_node_recursive(
        self, schema_node: dict[str, Any], name: str = "", path: str = "",
    ) -> ParamsNode | None:
        """
        Recursively convert a schema node into a ``ParamsNode``.

        :param schema_node: the schema node to convert
        :param name: parameter name reported for this node
        :param path: slash-separated path reported for this node
        :return: the node, or None when it has no ``type``, has a list of types,
            or has an unsupported type. Only objects with ``properties`` get a
            ``subParams`` list (children converting to None are skipped); every
            other node gets ``subParams=None``.
        """
        if "type" not in schema_node:
            return None
        param_type = schema_node["type"]
        if isinstance(param_type, list):
            return None

        if param_type == "object":
            param_type = Type.DICT
        elif param_type == "array":
            param_type = Type.LIST
        elif param_type == "string":
            param_type = Type.STRING
        elif param_type in ["number", "integer"]:
            param_type = Type.NUMBER
        elif param_type == "boolean":
            param_type = Type.BOOL
        else:
            err = f"[Slot] 不支持的参数类型: {param_type}"
            logger.warning(err)
            return None

        sub_params = []
        if param_type == Type.DICT and "properties" in schema_node:
            for key, value in schema_node["properties"].items():
                sub_param = self._extract_params_node_recursive(value, name=key, path=f"{path}/{key}")
                if sub_param:
                    sub_params.append(sub_param)
        else:
            sub_params = None

        return ParamsNode(paramName=name,
                          paramPath=path,
                          paramType=param_type,
                          subParams=sub_params)

    def get_params_node_from_schema(self, root: str = "") -> ParamsNode | None:
        """
        Build the ``ParamsNode`` tree for the slot schema.

        :param root: name and path given to the root node
        :return: the root node, or None if extraction fails (the failure is logged)
        """
        try:
            return self._extract_params_node_recursive(self._schema, name=root, path=root)
        except Exception:
            logger.exception("[Slot] 提取ParamsNode失败")
            return None

    def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
        """
        Flatten a JSON Schema.

        Results of ``allOf``/``anyOf``/``oneOf`` members are merged in; a typed
        non-object node is recorded under its ``type`` name.
        """
        result = {}
        required = []
        for key in ("allOf", "anyOf", "oneOf"):
            if key in schema:
                for item in schema[key]:
                    sub_result, sub_required = self._flatten_schema(item)
                    result.update(sub_result)
                    required.extend(sub_required)
        if "type" in schema:
            if schema["type"] == "object" and "properties" in schema:
                # NOTE(review): this passes the ``properties`` mapping
                # (name -> subschema), not a schema, so the recursion only
                # reacts to properties literally named "allOf"/"anyOf"/"oneOf"/
                # "type"; object properties are effectively not flattened.
                # Confirm the intended behaviour before relying on this method.
                sub_result, sub_required = self._flatten_schema(schema["properties"])
                result.update(sub_result)
                required.extend(sub_required)
            else:
                result[schema["type"]] = schema
                required.append(schema["type"])
        return result, required

    @staticmethod
    def _missing_required_key(error: ValidationError) -> str | None:
        """Return the property name a ``required`` error reports as missing, or None."""
        # Prefer matching the message against the declared names: a key that
        # itself contains quotes breaks naive splitting on "'".
        required = error.validator_value
        if isinstance(required, list):
            for candidate in required:
                if isinstance(candidate, str) and error.message == f"{candidate!r} is a required property":
                    return candidate
        try:
            return error.message.split("'")[1]
        except IndexError:
            logger.exception("[Slot] 错误信息不合法: %s", error.message)
            return None

    def _strip_error(self, error: ValidationError) -> tuple[dict[str, Any], list[str]]:
        """
        Cut out the schema fragment relevant to ``error`` and return any extra path.

        For a ``required`` error, the fragment is a copy of the missing
        property's schema with ``default`` set to ``""``, and the extra path is
        ``[<missing key>]``. For other errors, the fragment is a shallow copy of
        ``error.schema`` with no extra path. Returns ``({}, [])`` when the
        fragment cannot be determined.
        """
        if error.validator == "required":
            key = self._missing_required_key(error)
            if key is None:
                return {}, []
            properties = error.schema.get("properties") if isinstance(error.schema, Mapping) else None
            if isinstance(properties, Mapping) and key in properties:
                prop_schema = properties[key]
                # Copy before adding the placeholder default. error.schema
                # presumably references the validator's own schema
                # (self._schema), and writing into it would leak into later
                # calls such as create_empty_slot.
                schema = dict(prop_schema) if isinstance(prop_schema, Mapping) else {}
                schema["default"] = ""
                return schema, [key]
            logger.error("[Slot] 错误schema不合法: %s", error.schema)
            return {}, []

        if isinstance(error.schema, Mapping):
            return dict(error.schema.items()), []
        logger.error("[Slot] 错误schema不合法: %s", error.schema)
        return {}, []

    @staticmethod
    def _child_schema(schema: Any, segment: str, *, is_index: bool) -> Any:
        """Return the subschema describing ``segment`` under ``schema``, or ``{}`` if unknown."""
        if not isinstance(schema, Mapping):
            return {}
        if is_index:
            items = schema.get("items", {})
            if isinstance(items, list):  # tuple-style "items": one schema per position
                index = int(segment)
                items = items[index] if index < len(items) else {}
            return items if isinstance(items, Mapping) else {}
        properties = schema.get("properties", {})
        child = properties.get(segment, {}) if isinstance(properties, Mapping) else {}
        return child if isinstance(child, Mapping) else {}

    @staticmethod
    def _new_container(schema: Any) -> dict | list:
        """Return an empty intermediate node: ``[]`` for array schemas, otherwise ``{}``."""
        type_value = schema.get("type") if isinstance(schema, Mapping) else None
        types = type_value if isinstance(type_value, list) else [type_value]
        return [] if "array" in types and "object" not in types else {}

    def _assemble_patch(
        self,
        key: str,
        val: Any,
        json_data: Any,
        schema: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Build the JSON Patch operations that place one user-filled value.

        :param key: slash-separated path of the value, e.g. ``/a/0/b``
        :param val: the value to place at ``key``
        :param json_data: the JSON built so far; only read, to find which
            intermediate nodes already exist
        :param schema: schema used to shape missing intermediate nodes
        :return: ``add`` operations creating missing intermediate nodes,
            followed by the ``add`` of ``val`` itself
        :raises ValueError: if the path walks through a value that is neither an
            object nor (for numeric segments) an array, or skips array positions
        """
        # NOTE(review): segments are used verbatim as object keys; if
        # check_json's escape_path applies JSON Pointer escaping, escaped keys
        # are not unescaped here -- confirm.
        segments = [segment for segment in key.split("/") if segment]
        patch_list: list[dict[str, Any]] = []
        current_path = "/"
        current_schema: Any = schema
        last = len(segments) - 1
        for depth, segment in enumerate(segments):
            if depth == last:
                patch_list.append({"op": "add", "path": current_path + segment, "value": val})
                break

            if segment.isdigit() and isinstance(json_data, list):
                current_schema = self._child_schema(current_schema, segment, is_index=True)
                index = int(segment)
                if index < len(json_data):
                    json_data = json_data[index]
                elif index == len(json_data):
                    # Append one placeholder element. "-" is the JSON Patch
                    # end-of-array position, which here is exactly `index`.
                    new_item = self._generate_example(current_schema)
                    if not isinstance(new_item, (dict, list)):
                        new_item = self._new_container(current_schema)
                    patch_list.append({"op": "add", "path": current_path + "-", "value": new_item})
                    json_data = new_item
                else:
                    err = f"[Slot] 错误的路径: {key}"
                    logger.error(err)
                    raise ValueError(err)
            elif isinstance(json_data, dict):
                current_schema = self._child_schema(current_schema, segment, is_index=False)
                if segment in json_data:
                    json_data = json_data[segment]
                else:
                    # Create the missing node with the shape the schema expects,
                    # so that numeric segments below it index a real array.
                    new_node = self._new_container(current_schema)
                    patch_list.append({"op": "add", "path": current_path + segment, "value": new_node})
                    json_data = new_node
            else:
                err = f"[Slot] 错误的路径: {key}"
                logger.error(err)
                raise ValueError(err)
            current_path = current_path + segment + "/"

        logger.info("[Slot] 组装patch: %s", patch_list)
        return patch_list

    def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]:
        """
        Turn user-filled slot parameters into the real JSON.

        Keys starting with ``/`` are treated as paths into the result and are
        placed via JSON Patch. Any other key is copied to the top level as-is.

        :param json_data: JSON text or a decoded dict of filled parameters.
        :return: the assembled JSON object.
        """
        json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data
        final_json: dict[str, Any] = {}
        for key, val in json_dict.items():
            if key.startswith("/"):
                patch_list = self._assemble_patch(key, val, final_json, self._schema)
                final_json = patch_json(patch_list, final_json)
            else:
                final_json[key] = val
        return final_json

    def check_json(self, json_data: dict[str, Any]) -> dict[str, Any]:
        """
        Check whether ``json_data`` is valid and completely filled.

        :return: ``{}`` when valid. Otherwise, an object schema whose properties
            are keyed by the (escaped) path of each invalid or missing field and
            whose ``required`` lists each of those paths once.
        """
        schema_template: dict[str, Any] = {
            "type": "object",
            "properties": {},
            "required": [],
        }
        for error in self._validator.iter_errors(json_data):
            slot_schema, additional_path = self._strip_error(error)
            pointer = "/" + "/".join(escape_path(str(v)) for v in error.path)
            if additional_path:
                # Escape the missing key the same way as the error path segments.
                extra = "/".join(escape_path(str(p)) for p in additional_path)
                pointer = pointer.rstrip("/") + "/" + extra
            schema_template["properties"][pointer] = slot_schema
            # Several errors can target the same field; list it once.
            if pointer not in schema_template["required"]:
                schema_template["required"].append(pointer)

        return schema_template if schema_template["required"] else {}

    def add_null_to_basic_types(self) -> dict[str, Any]:
        """Return a deep copy of the slot schema in which basic types also accept ``null``."""
        schema_copy = copy.deepcopy(self._schema)
        return add_null_to_basic_types_func(schema_copy)
def add_null_to_basic_types_func(schema: dict[str, Any]) -> dict[str, Any]:
    """
    Recursively make basic JSON Schema types also accept ``null``.

    ``boolean``/``number``/``string``/``integer`` types, given either as a
    string or inside a list, gain a ``"null"`` alternative. The function
    descends into ``properties``, ``additionalProperties`` (when it is a
    schema), ``items`` (a single schema or a list of per-position schemas) and
    ``anyOf``/``oneOf``/``allOf``.

    The schema is modified in place and also returned; pass a copy to keep the
    original intact.

    :param schema: original JSON Schema; non-dict values are returned unchanged
    :return: the modified JSON Schema
    """
    if not isinstance(schema, dict):
        return schema

    basic_types = ("boolean", "number", "string", "integer")
    if "type" in schema:
        type_value = schema["type"]
        if isinstance(type_value, str):
            if type_value in basic_types:
                schema["type"] = [type_value, "null"]
        elif isinstance(type_value, list):
            has_basic = any(isinstance(t, str) and t in basic_types for t in type_value)
            if has_basic and "null" not in type_value:
                type_value.append("null")

    properties = schema.get("properties")
    if isinstance(properties, dict):
        for prop, prop_schema in properties.items():
            properties[prop] = add_null_to_basic_types_func(prop_schema)

    # `additionalProperties` may also be a boolean; only a schema is descended into.
    if isinstance(schema.get("additionalProperties"), dict):
        schema["additionalProperties"] = add_null_to_basic_types_func(schema["additionalProperties"])

    if "items" in schema:
        items = schema["items"]
        if isinstance(items, list):
            # Tuple-style validation: one schema per array position.
            schema["items"] = [add_null_to_basic_types_func(item) for item in items]
        else:
            schema["items"] = add_null_to_basic_types_func(items)

    for keyword in ("anyOf", "oneOf", "allOf"):
        if isinstance(schema.get(keyword), list):
            schema[keyword] = [add_null_to_basic_types_func(sub_schema) for sub_schema in schema[keyword]]

    return schema