"""
Test docstring of external APIs.
How to run this:
pytest tests/st/test_docs/test_docstring.py
"""
import importlib
import inspect
import re
from typing import Callable
import pytest
from mindspore import nn
def dynamic_import(full_path: str) -> type:
"""Dynamically import a class or method."""
try:
module_name, obj_name = full_path.rsplit('.', 1)
module = importlib.import_module(module_name)
obj = getattr(module, obj_name)
return obj
except (ImportError, AttributeError) as e:
raise ImportError(f"Cannot import {full_path}: \"{e}\"."
f"If its import path is modified, modify the interface_map in this test case. ")
def has_custom_init(cls: type) -> bool:
"""check whether the class has a custom __init__"""
return '__init__' in cls.__dict__
def get_method_signature_of_cls(cls: type, method_name: str) -> inspect.Signature:
"""Gets the full signature of the class methods, including those inherited from the parent class."""
method = getattr(cls, method_name, None)
if method is None:
raise ValueError(f"The class {cls} does not have a method names {method_name}.")
for base_cls in cls.__mro__:
if method_name in base_cls.__dict__:
func = base_cls.__dict__[method_name]
if isinstance(func, (classmethod, staticmethod)):
return inspect.signature(func.__func__)
return inspect.signature(func)
raise ValueError(f"Method {method_name} not found in MRO (method resolution order) of class {cls.__name__}.")
def parse_args_in_docstring(api_name: str, docstring: str) -> dict:
"""Parses a Google-style docstring to extract parameter definitions in the Args section."""
args_pattern = re.compile(r"Args:\s*\n(.*?)(?=\n\s*(Inputs|Returns)|$)", re.DOTALL)
args_section = args_pattern.search(docstring)
if not args_section:
return {}
param_pattern = re.compile(r"^\s*\*{0,2}(\w+)\s\(([\w\s.,\[\]()`]+)\):\s(.*?)(?=\n\s*\w+\s*\(|$)", re.MULTILINE)
matches = param_pattern.findall(args_section.group(1))
if not matches:
raise ValueError(f"No formatted parameter information was found "
f"in the Args section of the interface {api_name}.")
params = {}
for param_name, param_type, _ in matches:
is_optional = "optional" in param_type.lower()
params[param_name] = {
"type": param_type.replace("optional", "").strip(','),
"required": not is_optional,
}
return params
def parse_inputs_in_docstring(api_name: str, docstring: str) -> dict:
"""Parses a Google-style docstring to extract parameter definitions in the Inputs section."""
inputs_pattern = re.compile(r"Inputs:\s*\n(.*?)(?=\n\s*Outputs|$)", re.DOTALL)
inputs_section = inputs_pattern.search(docstring)
if not inputs_section:
raise ValueError(f"The docstring of interface {api_name} does not define the Inputs section.")
param_pattern = re.compile(
r"^\s*-\s\*\*(?:\\\*)?(\w+)\*\*\s\(([\w\s.,\[\]()`]+)\)\s-\s(.*?)(?=\n\s*-\s*\*\*\w+\*\*\s*\(|$)",
re.MULTILINE)
matches = param_pattern.findall(inputs_section.group(1))
if not matches:
raise ValueError(f"No formatted parameter information was found "
f"in the Inputs section of the interface {api_name}.")
params = {}
for param_name, param_type, _ in matches:
is_optional = "optional" in param_type.lower()
params[param_name] = {
"type": param_type.replace("optional", "").strip(','),
"required": not is_optional,
}
return params
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
class TestDocstring:
"""A test class for testing docstring."""
def setup_class(self):
"""Init interface map."""
self.interface_map = {
"mindformers.AutoConfig": ["from_pretrained", "register", "show_support_list"],
"mindformers.AutoModel": ["from_config", "from_pretrained", "register"],
"mindformers.AutoModelForCausalLM": ["from_config", "from_pretrained", "register"],
"mindformers.AutoProcessor": ["from_pretrained", "register"],
"mindformers.AutoTokenizer": ["from_pretrained", "register"],
"mindformers.Trainer": ["evaluate", "finetune", "predict", "train"],
"mindformers.TrainingArguments": ["convert_args_to_mindformers_config", "get_moe_config",
"get_parallel_config", "get_recompute_config", "get_warmup_steps",
"set_dataloader", "set_logging", "set_lr_scheduler", "set_optimizer",
"set_save", "set_training"],
"mindformers.pipeline": None,
"mindformers.run_check": None,
"mindformers.ModelRunner": None,
"mindformers.core.build_context": None,
"mindformers.core.get_context": None,
"mindformers.core.init_context": None,
"mindformers.core.set_context": None,
"mindformers.core.CrossEntropyLoss": None,
"mindformers.core.AdamW": None,
"mindformers.core.Came": None,
"mindformers.core.LearningRateWiseLayer": None,
"mindformers.core.ConstantWarmUpLR": None,
"mindformers.core.LinearWithWarmUpLR": None,
"mindformers.core.CosineWithWarmUpLR": None,
"mindformers.core.CosineWithRestartsAndWarmUpLR": None,
"mindformers.core.PolynomialWithWarmUpLR": None,
"mindformers.core.CosineAnnealingLR": None,
"mindformers.core.CosineAnnealingWarmRestarts": None,
"mindformers.core.CheckpointMonitor": None,
"mindformers.core.EvalCallBack": None,
"mindformers.core.MFLossMonitor": None,
"mindformers.core.TrainingStateMonitor": None,
"mindformers.core.ProfileMonitor": None,
"mindformers.core.SummaryMonitor": None,
"mindformers.core.EntityScore": ["clear", "eval", "update"],
"mindformers.core.EmF1Metric": ["clear", "eval", "update"],
"mindformers.core.PerplexityMetric": ["clear", "eval", "update"],
"mindformers.core.PromptAccMetric": ["clear", "eval", "update"],
"mindformers.dataset.CausalLanguageModelDataset": None,
"mindformers.dataset.KeyWordGenDataset": None,
"mindformers.dataset.MultiTurnDataset": None,
"mindformers.generation.GenerationConfig": None,
"mindformers.generation.GenerationMixin": ["chat", "forward", "generate", "infer", "postprocess"],
"mindformers.models.PreTrainedModel": ["can_generate", "from_pretrained", "post_init",
"register_for_auto_class", "save_pretrained"],
"mindformers.models.PretrainedConfig": ["from_dict", "from_json_file", "from_pretrained",
"get_config_dict", "save_pretrained", "to_dict",
"to_diff_dict", "to_json_file", "to_json_string"],
"mindformers.models.PreTrainedTokenizer": ["convert_ids_to_tokens", "convert_tokens_to_ids",
"get_added_vocab", "num_special_tokens_to_add",
"prepare_for_tokenization", "tokenize"],
"mindformers.models.PreTrainedTokenizerFast": ["convert_ids_to_tokens", "convert_tokens_to_ids",
"get_added_vocab", "num_special_tokens_to_add",
"set_truncation_and_padding", "train_new_from_iterator"],
"mindformers.models.multi_modal.ModalContentTransformTemplate": ["batch", "build_conversation_input_text",
"build_labels", "build_modal_context",
"get_need_update_output_items",
"post_process", "process_predict_query",
"process_train_item"],
"mindformers.models.LlamaForCausalLM": None,
"mindformers.models.LlamaConfig": None,
"mindformers.models.LlamaTokenizer": ["build_inputs_with_special_tokens",
"create_token_type_ids_from_sequences",
"get_special_tokens_mask"],
"mindformers.models.LlamaTokenizerFast": ["build_inputs_with_special_tokens",
"save_vocabulary",
"update_post_processor"],
"mindformers.models.ChatGLM2ForConditionalGeneration": None,
"mindformers.models.ChatGLM2Config": None,
"mindformers.models.ChatGLM3Tokenizer": None,
"mindformers.models.ChatGLM4Tokenizer": None,
"mindformers.modules.OpParallelConfig": None,
"mindformers.pet.models.LoraModel": None,
"mindformers.pet.pet_config.PetConfig": None,
"mindformers.pet.pet_config.LoraConfig": None,
"mindformers.pipeline.MultiModalToTextPipeline": None,
"mindformers.tools.register.MindFormerModuleType": None,
"mindformers.tools.register.MindFormerRegister": ["get_cls", "get_instance", "get_instance_from_cfg",
"is_exist", "register", "register_cls"],
"mindformers.tools.MindFormerConfig": ["merge_from_dict"],
"mindformers.wrapper.MFTrainOneStepCell": None,
"mindformers.wrapper.MFPipelineWithLossScaleCell": None,
}
def setup_method(self):
"""Init errors list."""
self.error_list = []
def check_class_docstring(self, cls: type, full_path: str) -> None:
"""Checks that the arguments in the class docstring match the arguments in the class definition."""
docstring = inspect.getdoc(cls)
if not docstring:
self.error_list.append(f"{full_path}: Missing docstring.")
return
try:
doc_params = parse_args_in_docstring(full_path, docstring)
except ValueError as e:
self.error_list.append(str(e))
return
if has_custom_init(cls):
target_method_name = "__init__"
else:
target_method_name = "__new__"
try:
sig = get_method_signature_of_cls(cls, target_method_name)
except ValueError as e:
self.error_list.append(str(e))
return
init_params = sig.parameters
for name, param_doc in doc_params.items():
if name not in init_params:
self.error_list.append(f"{full_path}: Argument '{name}' in docstring is not defined in "
f"{target_method_name} method.")
elif param_doc["required"] and init_params[name].default is not inspect.Parameter.empty:
self.error_list.append(f"{full_path}: The docstring specifies that the argument '{name}' should be "
f"required, but there is a default value in the {target_method_name} method.")
for name in init_params:
if name not in doc_params and name != 'cls' and name != 'self' and name != 'args' and name != 'kwargs':
self.error_list.append(f"{full_path}: There is a redundant argument '{name}' in the "
f"{target_method_name} method that is not defined in the docstring or is "
f"defined in an incorrect format.")
if not issubclass(cls, nn.Cell):
return
if 'construct' not in cls.__dict__:
return
try:
inputs_params = parse_inputs_in_docstring(cls.__name__, docstring)
except ValueError as e:
self.error_list.append(str(e))
return
construct_method = getattr(cls, 'construct', None)
if not callable(construct_method):
raise ValueError(f"{full_path}: The class {cls.__name__} does not define construct method.")
sig = inspect.signature(construct_method)
construct_params = sig.parameters
for name, param_doc in inputs_params.items():
if name not in construct_params:
self.error_list.append(f"{full_path}: Argument '{name}' in docstring is not defined in construct "
f"method.")
elif param_doc["required"] and construct_params[name].default is not inspect.Parameter.empty:
self.error_list.append(f"{full_path}: The docstring specifies that the argument '{name}' should be "
f"required, but there is a default value in the function signature.")
for name in construct_params:
if name not in inputs_params and name != 'self' and name != 'kwargs':
self.error_list.append(f"{full_path}: There is a redundant argument '{name}' in the construct "
f"method that is not defined in the docstring or is defined in an incorrect "
f"format.")
def check_method_docstring(self, func: Callable, full_path: str, cls: type = None) -> None:
"""Checks that the arguments in the method docstring match the arguments in the method definition."""
docstring = inspect.getdoc(func)
if not docstring:
self.error_list.append(f"{full_path}: Missing docstring.")
return
if cls:
try:
sig = get_method_signature_of_cls(cls, func.__name__)
except ValueError as e:
self.error_list.append(str(e))
return
else:
sig = inspect.signature(func)
func_params = sig.parameters
try:
doc_params = parse_args_in_docstring(full_path, docstring)
except ValueError as e:
self.error_list.append(str(e))
return
for name, param_doc in doc_params.items():
if name not in func_params:
self.error_list.append(f"{full_path}: Argument '{name}' in docstring is not defined in the function "
f"signature.")
elif param_doc["required"] and func_params[name].default is not inspect.Parameter.empty:
self.error_list.append(f"{full_path}: The docstring specifies that the argument '{name}' should be "
f"required, but there is a default value in the function signature.")
for name in func_params:
if name not in doc_params and name != 'cls' and name != 'self' and name != 'args' and name != 'kwargs':
self.error_list.append(f"{full_path}: There is a redundant argument '{name}' in the function signature "
f"that is not defined in the docstring or is defined in an incorrect format.")
def test_docstring(self):
"""Test docstring for all the interfaces."""
for class_or_func_path, methods in self.interface_map.items():
obj = dynamic_import(class_or_func_path)
if inspect.isclass(obj):
self.check_class_docstring(obj, class_or_func_path)
if not methods:
continue
for method_name in methods:
if not hasattr(obj, method_name):
raise AttributeError(f"The class {class_or_func_path} does not have method {method_name}. "
f"If the method has been deleted, "
f"modify the interface_map in this test case. ")
method = getattr(obj, method_name)
self.check_method_docstring(method, class_or_func_path + "." + method_name, obj)
else:
self.check_method_docstring(obj, class_or_func_path)
if self.error_list:
raise AssertionError(f"Checking arguments consistency failed! The inconsistencies are listing below:\n" +
" ".join(f"{i + 1}. {s}" for i, s in enumerate(self.error_list)))