"""
We can't use assert in our code for codecheck, so create this auxiliary function to wrap
the assert case in ut for ci.
"""


def judge_expression(expression, message=None):
    """Raise ``AssertionError`` when *expression* is falsy.

    Drop-in replacement for ``assert expression[, message]`` in unit tests,
    since the codecheck forbids bare ``assert`` statements.

    Args:
        expression: Value to check. Any falsy value (``False``, ``None``,
            ``0``, an empty container, ...) fails the check.
        message: Optional description attached to the raised error. When it
            is omitted, a bare ``AssertionError`` with no arguments is raised,
            exactly as before this parameter existed.

    Raises:
        AssertionError: If ``expression`` is falsy.
    """
    if expression:
        return
    if message is None:
        raise AssertionError
    raise AssertionError(message)


class TestConfig(object):
    """Attribute-style view over a (possibly nested) configuration dict.

    Each key of ``entries`` becomes an instance attribute. Nested ``dict``
    values are wrapped recursively, so that
    ``TestConfig({"a": {"b": 1}}).a.b == 1``. ``to_dict`` reverses the
    conversion.
    """

    # The class name starts with "Test"; this tells pytest not to try to
    # collect this helper as a test class in modules that import it.
    __test__ = False

    def __init__(self, entries):
        """Populate attributes from ``entries``.

        Args:
            entries: Mapping of attribute name to value. ``dict`` values are
                converted to nested ``TestConfig`` objects. All other values
                (including lists, even ones containing dicts) are stored as-is.
        """
        for k, v in entries.items():
            if isinstance(v, dict):
                self.__dict__[k] = TestConfig(v)
            else:
                self.__dict__[k] = v

    def to_dict(self):
        """Return a plain nested ``dict`` equivalent of this config.

        Nested config objects are converted recursively. Other values are
        placed in the result by reference (no copy is made).
        """
        ret = {}
        for k, v in self.__dict__.items():
            # __init__ always builds nested nodes as plain TestConfig, so the
            # check must be against the base class. isinstance(v, self.__class__)
            # would leave them unconverted whenever self is a subclass instance.
            if isinstance(v, TestConfig):
                ret[k] = v.to_dict()
            else:
                ret[k] = v
        return ret


def initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
    pipeline_model_parallel_split_rank=None,
    context_parallel_size=1,
):
    """Reset and re-initialize Megatron-Core model-parallel state.

    Calls ``megatron.core.parallel_state.destroy_model_parallel()`` and then
    ``megatron.core.parallel_state.initialize_model_parallel()``. Every
    argument below is forwarded to the latter by keyword, unchanged.
    Presumably the destroy call exists so each test starts from a clean
    setup instead of inheriting a previous test's groups — verify against
    the Megatron version in use.

    Args:
        tensor_model_parallel_size: Forwarded unchanged.
        pipeline_model_parallel_size: Forwarded unchanged.
        virtual_pipeline_model_parallel_size: Forwarded unchanged.
        pipeline_model_parallel_split_rank: Forwarded unchanged.
        context_parallel_size: Forwarded unchanged.
    """
    # Imported lazily so that merely importing this helper module does not
    # require Megatron.
    import megatron.core.parallel_state as parallel_state

    parallel_kwargs = {
        "tensor_model_parallel_size": tensor_model_parallel_size,
        "pipeline_model_parallel_size": pipeline_model_parallel_size,
        "virtual_pipeline_model_parallel_size": virtual_pipeline_model_parallel_size,
        "pipeline_model_parallel_split_rank": pipeline_model_parallel_split_rank,
        "context_parallel_size": context_parallel_size,
    }
    parallel_state.destroy_model_parallel()
    parallel_state.initialize_model_parallel(**parallel_kwargs)


def clear_module(module_name):
    """Remove a module and all of its submodules from ``sys.modules``.

    Python caches imported modules in ``sys.modules``. After an entry is
    removed, the next ``import`` of that name executes the module again.
    Objects that already hold a reference keep the old module object.

    Args:
        module_name: Dotted module name, e.g. ``"megatron.core"``.
            - Removes that exact entry plus every ``"<module_name>.*"``
              submodule.
            - With a trailing ``"."`` (``"megatron.core."``), only the
              submodules are removed.
            - Modules that merely share a textual prefix (``"megatron_ext"``
              for ``"megatron"``) are left alone.
            - An empty name matches nothing.
    """
    import sys

    # A bare string-prefix match would also drop unrelated sibling packages
    # (clearing "torch" would remove "torch_npu" and "torchvision"). So match
    # only the exact name and its dotted descendants.
    submodule_prefix = module_name.rstrip(".") + "."
    modules_to_delete = [
        key
        for key in list(sys.modules)
        if key == module_name or key.startswith(submodule_prefix)
    ]
    for key in modules_to_delete:
        # pop() tolerates an entry that has already disappeared.
        sys.modules.pop(key, None)