from mindspeed.features_manager.moe.tp_extend_ep import MoETpExtendEpFeature as MindSpeedMoETpExtendEpFeature





class MoETpExtendEpFeature(MindSpeedMoETpExtendEpFeature):
    """Feature subclass that overrides patch registration for the MoE layer.

    Only ``register_patches`` is overridden here; everything else is
    inherited from ``MindSpeedMoETpExtendEpFeature``.
    """

    def register_patches(self, patch_manager, args):
        """Conditionally register the alltoall_seq TP-extend-EP MoE layer patch.

        The patch targets ``megatron.core.transformer.moe.moe_layer.MoELayer``
        and is registered only when all of the following hold:

        * ``args.moe_token_dispatcher_type`` exists and equals
          ``'alltoall_seq'``;
        * ``args.moe_tp_extend_ep`` is truthy;
        * ``args.moe_alltoall_overlap_comm`` is falsy.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed configuration namespace carrying the MoE flags above.

        Note:
            This override does not delegate to ``super().register_patches``.
        """
        # Imported lazily inside the method (and unconditionally), so the
        # adaptor module is loaded whenever patches are registered.
        from mindspeed.core.transformer.moe.moe_feature.adaptor import MindSpeedAlltoAllSEQTptoEpMoELayer

        # Guard 1: the dispatcher attribute may be absent on some arg sets.
        dispatcher_type = getattr(args, 'moe_token_dispatcher_type', None)
        if dispatcher_type != 'alltoall_seq':
            return

        # Guard 2: the TP-extend-EP feature must be switched on.
        if not args.moe_tp_extend_ep:
            return

        # Guard 3: skip when all-to-all communication overlap is enabled.
        if args.moe_alltoall_overlap_comm:
            return

        patch_manager.register_patch(
            'megatron.core.transformer.moe.moe_layer.MoELayer',
            MindSpeedAlltoAllSEQTptoEpMoELayer,
        )