"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import unittest
from typing import Optional, List, Any, Dict
from langchain_core.callbacks import CallbackManagerForLLMRun
from pydantic import Field, BaseModel, ValidationError, field_validator, ConfigDict
from mx_rag.utils.common import validate_params, validate_list_str
class Person():
MAX_LIMIT = 100
name: str
number: int = 10
@validate_params(
age=dict(validator=lambda x: x > 10),
weight=dict(validator=lambda x: 90 <= x <= 150)
)
def __init__(self, age: int, weight: int, ranker: int = 1):
self.age = age
self.weight = weight
self.ranker = ranker
def call_back_fun(self, func, *args, **kwargs):
func(*args, **kwargs)
@validate_params(
param1=dict(validator=lambda x: 0.0 < x < 1.0),
)
def validate_call_back_fun(self, param1: float, func, *args):
func(*args)
@validate_params(
param1=dict(validator=lambda x: x < Person.MAX_LIMIT),
)
def validate_self_var(self, param1):
pass
@validate_params(
name=dict(validator=lambda x: isinstance(x, str)),
weight=dict(validator=lambda x: 0 <= x <= 50)
)
class Animal(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str
master: Person
weight: int = 10
@property
def _llm_type(self) -> str:
pass
def _call(self, prompt: str, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any) -> str:
pass
class Cat(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
age: int = Field(ge=0, le=10)
gender: str
color: str = Field(min_length=0, max_length=10, default="white")
master: Optional[Person] = None
attribute: Optional[str] = None
multidict: List[Dict[str, str]] = Field(min_length=0, max_length=1, default=None)
@field_validator('gender')
def validate_gender(cls, input_value):
if input_value not in ["male", "female"]:
raise ValueError(f"month value must be in [male, female]")
return input_value
@validate_params(
param1=dict(validator=lambda x: isinstance(x, str) and len(x) >= 5),
param2=dict(validator=lambda x: x > 0)
)
def non_class_funciton(param1, param2: int, param3=None):
pass
@validate_params(
param3=dict(validator=lambda x: x > 0)
)
def non_class_funciton1(param1, param2: int, param3: int = 50):
pass
@validate_params(
param1=dict(validator=lambda x: isinstance(x, int)),
param2=dict(validator=lambda x: x > 0),
param3=dict(validator=
lambda x: all(len(item) > 1 for item in x),
message="check rule: lambda x: all(len(item) > 1 for item in x)")
)
def non_class_funciton2(param1: int, param2: int, param3: List[dict]):
pass
@validate_params(
param1={"validator": lambda x: validate_list_str(x, [1, 3], [5, 10]),
"message": "param type is not List[str] or length of list not in [1, 3] "
"or length of str in list not in [5, 10]"}
)
def non_class_funciton3(param1: List[str]):
pass
class TestValidateParams(unittest.TestCase):
def test_class_scope(self):
Person(18, 140, 2)
Person(18, weight=140, ranker=-1)
with self.assertRaises(ValueError):
Person(18, 85, 10)
def test_non_calss_funciton(self):
non_class_funciton("hello", 1)
with self.assertRaises(ValueError):
non_class_funciton("hello", param2=-1, param3=5)
with self.assertRaises(ValueError):
non_class_funciton(1, param2=-1, param3=5)
def test_call_back_function(self):
person = Person(18, 140, 2)
person.call_back_fun(non_class_funciton, "world!", param2=3, param3=5)
person.call_back_fun(non_class_funciton, param1="world!", param2=3, param3=5)
person.validate_call_back_fun(0.5, non_class_funciton, "world!", 5)
with self.assertRaises(ValueError):
person.validate_call_back_fun(1.1, non_class_funciton, "world!", 5)
def test_default_parm_validation(self):
non_class_funciton1(1, 2)
def test_validate_self_var(self):
person = Person(18, 140, 2)
person.validate_self_var(80)
with self.assertRaises(ValueError):
person.validate_self_var(110)
def test_class_variable(self):
person = Person(18, 140, 2)
Animal(name="panda", master=person)
with self.assertRaises(ValidationError):
Animal(name="panda", master=123)
Animal(name="panda", master=person, weight=0)
with self.assertRaises(ValueError):
Animal(name="panda", master=person, weight=-1)
def test_pydantic(self):
person = Person(18, 140, 2)
cat = Cat(age=10, gender="male", master=person, attribute="attribute")
self.assertEqual(cat.color, "white")
with self.assertRaises(ValidationError):
Cat(age=10, gender="median", attribute=None)
with self.assertRaises(ValidationError):
Cat(age=20, gender="male", attribute=None)
Cat(age=0, gender="male", attribute=None)
Cat(age=0, gender="male")
with self.assertRaises(ValidationError):
Cat(age=10, gender="male", master=person, attribute="attribute", color="hello world!")
with self.assertRaises(ValidationError):
Cat(age=10, gender="male", master=person, attribute="attribute", multidict=[{"1": "a"}, {"2": "b"}])
Cat(age=10, gender="male", master=person, attribute="attribute", multidict=[])
def test_log_info(self):
try:
non_class_funciton2("1", 1, [{1: "a", 2: "b", 3: "c"}, {4: "a", 5: "b", 6: "c"}])
except Exception as e:
self.assertGreater(str(e).find("'param1' of function 'non_class_funciton2' is invalid"), -1)
try:
non_class_funciton2(1, -1, [{1: "a", 2: "b", 3: "c"}, {4: "a", 5: "b", 6: "c"}])
except Exception as e:
self.assertGreater(str(e).find("'param2' of function 'non_class_funciton2' is invalid"), -1)
try:
non_class_funciton2(1, 1, [{1: "a"}, {4: "a"}])
except Exception as e:
self.assertGreater(str(e).find("lambda x: all(len(item) > 1 for item in x)"), -1)
def test_validate_list_str(self):
non_class_funciton3(["hello!", "world!", "beautiful"])
with self.assertRaises(ValueError):
non_class_funciton3([123, "world!", "beautiful"])
with self.assertRaises(ValueError):
non_class_funciton3(["hello!", "world!", "beautiful", "xxxxx"])
with self.assertRaises(ValueError):
non_class_funciton3(["hi!", "world!", "beautiful"])