from mindspeed.features_manager.feature import MindSpeedFeature


class GroupedMatmulFeature(MindSpeedFeature):
    def __init__(self):
        super().__init__('grouped-matmul', optimization_level=0)

    def register_patches(self, patch_manager, args):
        from mindspeed.core.fusions.grouped_matmul import Ops, grouped_gemm_is_available, \
            assert_grouped_gemm_is_available, get_device_capability
        patch_manager.register_patch('megatron.core.transformer.moe.grouped_gemm_util.ops', Ops)
        patch_manager.register_patch('megatron.core.transformer.moe.grouped_gemm_util.grouped_gemm_is_available',
                                     grouped_gemm_is_available)
        patch_manager.register_patch('megatron.core.transformer.moe.grouped_gemm_util.assert_grouped_gemm_is_available',
                                     assert_grouped_gemm_is_available)
        patch_manager.register_patch('torch.cuda.get_device_capability', get_device_capability)