import os
import shutil
import pytest
from rec_sdk_common.validator.safe_checker import (
str_safe_check,
int_safe_check,
class_safe_check,
dir_safe_check,
float_safe_check,
file_safe_check,
)
from rec_sdk_common.constants.constants import NumCheckValueMethod
from rec_sdk_common.communication.hccl import get_rank_id
class TestStrSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.str_safe_check'."""

    @staticmethod
    def test_ok():
        """A valid string within the length bounds is accepted without raising."""
        try:
            str_safe_check("table_name", "user_table", min_len=0, max_len=255)
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")

    @staticmethod
    def test_type_err():
        """A non-str value raises ValueError mentioning 'is not str'."""
        with pytest.raises(ValueError) as excinfo:
            str_safe_check("table_name", 1)
        # Inspect the exception only after the `with` block: any statement after
        # the raising call inside the block would never execute.
        assert "is not str" in str(excinfo.value)

    @staticmethod
    def test_len_less_err():
        """A string shorter than ``min_len`` raises ValueError."""
        with pytest.raises(ValueError) as excinfo:
            str_safe_check("table_name", "user_table", min_len=20)
        assert "'table_name' length is less than" in str(excinfo.value)

    @staticmethod
    def test_len_greater_err():
        """A string longer than ``max_len`` raises ValueError."""
        with pytest.raises(ValueError) as excinfo:
            str_safe_check("table_name", "user_table", max_len=5)
        assert "'table_name' length is bigger than" in str(excinfo.value)

    @staticmethod
    def test_whitelist_err():
        """A string containing '*' is rejected as invalid."""
        with pytest.raises(ValueError) as excinfo:
            str_safe_check("table_name", "user_table*")
        assert "The string 'table_name' is invalid" in str(excinfo.value)

    @staticmethod
    def test_black_element_err():
        """A string containing the configured ``black_element`` is rejected."""
        with pytest.raises(ValueError) as excinfo:
            str_safe_check("table_name", "user_table", black_element="user")
        assert "contain black element" in str(excinfo.value)
class TestIntSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.int_safe_check'."""

    @staticmethod
    def test_ok():
        """An int inside [min_value, max_value] is accepted without raising."""
        try:
            int_safe_check("embedding_dim", 8, min_value=0, max_value=8192)
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")

    @staticmethod
    def test_type_err():
        """A numeric string (not an int) raises ValueError mentioning 'is not int'."""
        with pytest.raises(ValueError) as excinfo:
            int_safe_check("embedding_dim", "8")
        # Checked outside the `with` block so the assertion actually runs.
        assert "is not int" in str(excinfo.value)

    @staticmethod
    def test_max_value_err():
        """A value above ``max_value`` raises ValueError."""
        with pytest.raises(ValueError) as excinfo:
            int_safe_check("embedding_dim", 8, max_value=4)
        assert "is bigger than" in str(excinfo.value)

    @staticmethod
    def test_min_value_err():
        """A value below ``min_value`` raises ValueError."""
        with pytest.raises(ValueError) as excinfo:
            int_safe_check("embedding_dim", 8, min_value=16)
        assert "is less than" in str(excinfo.value)
class TestClassSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.class_safe_check'."""

    @staticmethod
    def test_ok():
        """A value that is an instance of the expected class is accepted."""
        try:
            class_safe_check("embedding_dim", 8, int)
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")

    @staticmethod
    def test_type_err():
        """A value of the wrong class raises ValueError naming the expected class."""
        with pytest.raises(ValueError) as excinfo:
            class_safe_check("embedding_dim", 8, str)
        # Checked outside the `with` block so the assertion actually runs.
        assert "not in <class 'str'>" in str(excinfo.value)
class TestDirSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.dir_safe_check'."""

    @staticmethod
    def test_ok():
        """An existing directory relative to the working directory passes the check."""
        path = "test_path"
        created = False
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
            created = True
        try:
            dir_safe_check("path", path)
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")
        finally:
            # pytest.fail raises, so cleanup must live in `finally` to avoid
            # leaking the directory on failure. Only remove what this test
            # created, so a pre-existing user directory is never deleted.
            if created:
                shutil.rmtree(path, ignore_errors=True)
class TestFileSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.file_safe_check'."""

    @staticmethod
    def _create_test_file(file_path: str, content: str) -> None:
        """Write ``content`` to ``file_path`` as UTF-8, creating parent dirs if needed."""
        parent_dir = os.path.dirname(file_path)
        # os.makedirs("") raises FileNotFoundError, so skip it for bare file names.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)

    @staticmethod
    def _delete_file(file_path: str) -> None:
        """Remove ``file_path`` if it exists; a missing file is silently ignored."""
        if os.path.exists(file_path):
            os.remove(file_path)

    @staticmethod
    def test_ok():
        """A freshly written file in the current working directory passes the check."""
        test_file_name = "test_file.txt"
        test_content = "This is test file.\nHello, World!\n"
        test_file_path = os.path.join(os.getcwd(), test_file_name)
        TestFileSafeCheck._create_test_file(test_file_path, test_content)
        try:
            file_safe_check("test file", test_file_path)
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")
        finally:
            # pytest.fail raises; delete in `finally` so a failed check does not
            # leave the temporary file behind in the working directory.
            TestFileSafeCheck._delete_file(test_file_path)
class TestFloatSafeCheck:
    """Test for 'rec_sdk_common.validator.safe_checker.float_safe_check'."""

    @staticmethod
    def test_default_method():
        """With the DEFAULT method, a value equal to ``max_value`` is accepted."""
        try:
            float_safe_check(
                "embedding_dim",
                8.0,
                min_value=0.0,
                max_value=8.0,
                method=NumCheckValueMethod.DEFAULT.value,
            )
        except Exception as e:
            pytest.fail(f"unexpected exception raised: {e}")

    @staticmethod
    def test_open_interval_method():
        """With OPEN_INTERVAL, a value equal to ``max_value`` is rejected."""
        with pytest.raises(ValueError) as excinfo:
            float_safe_check(
                "embedding_dim",
                8.0,
                min_value=0.0,
                max_value=8.0,
                method=NumCheckValueMethod.OPEN_INTERVAL.value,
            )
        # Checked outside the `with` block so the assertion actually runs.
        assert "is bigger than or equal" in str(excinfo.value)

    @staticmethod
    def test_left_open_interval_method():
        """With LEFT_OPEN_INTERVAL, a value equal to ``min_value`` is rejected."""
        with pytest.raises(ValueError) as excinfo:
            float_safe_check(
                "embedding_dim",
                0.0,
                min_value=0.0,
                max_value=8.0,
                method=NumCheckValueMethod.LEFT_OPEN_INTERVAL.value,
            )
        assert "is less than or equal" in str(excinfo.value)

    @staticmethod
    def test_right_open_interval_method():
        """With RIGHT_OPEN_INTERVAL, a value equal to ``max_value`` is rejected."""
        with pytest.raises(ValueError) as excinfo:
            float_safe_check(
                "embedding_dim",
                8.0,
                min_value=0.0,
                max_value=8.0,
                method=NumCheckValueMethod.RIGHT_OPEN_INTERVAL.value,
            )
        assert "is bigger than or equal" in str(excinfo.value)

    @staticmethod
    def test_method_not_support():
        """An unknown ``method`` string raises ValueError listing supported methods."""
        with pytest.raises(ValueError) as excinfo:
            float_safe_check(
                "embedding_dim",
                8.0,
                min_value=0.0,
                max_value=8.0,
                method="xxx",
            )
        assert "the check method supports" in str(excinfo.value)