from argparse import ArgumentParser
from mindspeed.features_manager import CoCFeature as MindSpeedCoCFeature


class CoCFeature(MindSpeedCoCFeature):
    """Project-level wrapper around MindSpeed's ``CoCFeature``.

    Adds a compatibility check that forbids enabling CoC together with LoRA
    target modules. When ``args.use_ascend_coc`` is set, it registers the
    MindSpeed CoC adaptors in place of Megatron's tensor-parallel linear
    layers.
    """

    def __init__(self):
        # No extra state of our own; defer entirely to the upstream initializer.
        super().__init__()

    def validate_args(self, args):
        """Validate CoC-related command-line arguments.

        Runs the upstream validation first. It then rejects the combination of
        ``use_ascend_coc`` with a non-empty ``lora_target_modules``.

        Args:
            args: Parsed argument namespace. ``lora_target_modules`` may be
                absent (e.g. when no LoRA options were registered).

        Raises:
            AssertionError: If CoC is enabled while LoRA target modules are set.
        """
        super().validate_args(args)

        lora_modules = getattr(args, 'lora_target_modules', None)
        if not lora_modules:
            # Without LoRA target modules there is nothing to conflict with;
            # ``use_ascend_coc`` is intentionally not consulted in this case.
            return

        if args.use_ascend_coc:
            raise AssertionError('CoC is not compatible in lora training.')

    def register_patches(self, patch_manager, args):
        """Register the CoC parallel-linear adaptors when CoC is enabled.

        When ``args.use_ascend_coc`` is truthy, this replaces
        ``ColumnParallelLinear`` and ``RowParallelLinear`` in
        ``megatron.core.tensor_parallel.layers``. The replacements are the
        MindSpeed CoC adaptor classes, registered via
        ``patch_manager.register_patch``. Otherwise it does nothing.

        NOTE(review): this override does not call
        ``super().register_patches``; presumably deliberate so that only these
        adaptors are installed — confirm against the upstream feature.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed argument namespace providing ``use_ascend_coc``.
        """
        if not args.use_ascend_coc:
            return

        # Imported lazily so the adaptor module is only loaded when CoC is on.
        from mindspeed.core.tensor_parallel.coc_feature.adaptor import (
            MindSpeedCoCColumnParallelLinear,
            MindSpeedCoCRowParallelLinear,
        )

        layer_overrides = (
            ('megatron.core.tensor_parallel.layers.ColumnParallelLinear',
             MindSpeedCoCColumnParallelLinear),
            ('megatron.core.tensor_parallel.layers.RowParallelLinear',
             MindSpeedCoCRowParallelLinear),
        )
        for target_path, adaptor_cls in layer_overrides:
            patch_manager.register_patch(target_path, adaptor_cls)