from mindspeed.features_manager.feature import MindSpeedFeature
class MHCFeature(MindSpeedFeature):
    """Feature plugin for the ``mhc`` module.

    Registers the mhc command-line options, swaps in the mhc-aware pipeline
    schedule helpers when ``--enable-mhc`` is set, and validates that the
    options mhc depends on are turned on.
    """

    # Context-parallel algorithms accepted by validate_args when
    # context_parallel_size > 1.  The error message is built from this tuple
    # so the check and the message cannot drift apart.
    _SUPPORTED_CP_ALGOS = ('ulysses_cp_algo', 'kvallgather_cp_algo')

    def __init__(self):
        super().__init__(feature_name="mhc", optimization_level=0)

    def register_args(self, parser):
        """Add the mhc option group to ``parser``.

        Args:
            parser: An ``argparse`` parser (or compatible object) exposing
                ``add_argument_group``.
        """
        group = parser.add_argument_group(title=self.feature_name)

        # Boolean switches; all of them default to off.
        group.add_argument('--enable-mhc', action='store_true', default=False,
                           help='add mhc module in model.')
        group.add_argument('--use-triton-rmsnorm-without-weight', action='store_true', default=False,
                           help='use triton rmsnorm-without-weight.')
        group.add_argument('--use-triton-sinkhorn', action='store_true', default=False,
                           help='use triton sinkhorn.')
        group.add_argument('--use-triton-mhc', action='store_true', default=False,
                           help='use triton for pre/pos/only pre.')
        group.add_argument('--mhc-recompute', action='store_true', default=False,
                           help='Fine-grained recompute for attention pre/pos and mlp pre.')

        # Numeric hyper-parameters of the mhc module.
        # NOTE(review): the help strings describe the flags by name only; the
        # exact semantics live in the mhc implementation - confirm wording.
        group.add_argument('--hc-mult', type=int, default=4,
                           help='mhc multiplier (hc_mult).')
        group.add_argument('--hc-sinkhorn-iters', type=int, default=20,
                           help='number of sinkhorn iterations used by mhc.')
        group.add_argument('--hc-eps', type=float, default=1e-6,
                           help='epsilon value used by mhc.')

    def register_patches(self, patch_manager, args):
        """Register the mhc replacements for Megatron's pipeline schedules.

        Nothing is patched unless ``args.enable_mhc`` is truthy.  The
        interleaved schedule is additionally patched when
        ``args.num_layers_per_virtual_pipeline_stage`` is set to a truthy
        value.

        Args:
            patch_manager: Object exposing ``register_patch(target, replacement)``.
            args: Parsed arguments namespace.
        """
        if not args.enable_mhc:
            return

        # Imported lazily so the mhc implementation is only loaded on demand.
        from mindspeed_llm.tasks.models.transformer.deepseek4.mhc.mhc import get_tensor_shapes_in_mhc
        patch_manager.register_patch(
            'megatron.core.pipeline_parallel.schedules.get_tensor_shapes',
            get_tensor_shapes_in_mhc,
        )

        # NOTE(review): the interleaved patch is applied only together with
        # mhc; confirm this nesting is the intended one.
        # A missing attribute, None and 0 all mean "no virtual pipeline".
        if getattr(args, "num_layers_per_virtual_pipeline_stage", None):
            from mindspeed_llm.tasks.models.transformer.deepseek4.mhc.mhc import (
                forward_backward_pipelining_with_interleaving_in_mhc,
            )
            patch_manager.register_patch(
                'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving',
                forward_backward_pipelining_with_interleaving_in_mhc,
            )

    def validate_args(self, args):
        """Check that the options required by mhc are set.

        Args:
            args: Parsed arguments namespace.

        Raises:
            ValueError: If ``--enable-mhc`` is set and multi-latent attention
                is off, flash attention is off, or context parallelism is
                used with an algorithm outside ``_SUPPORTED_CP_ALGOS``.
        """
        if not args.enable_mhc:
            return

        if not args.multi_latent_attention:
            raise ValueError(
                "--enable-mhc is currently only supported with MLA, "
                "please check model_spec and enable --multi-latent-attention."
            )
        if not args.use_flash_attn:
            raise ValueError(
                "--enable-mhc is currently only supported with FA, "
                "please enable --use-flash-attn."
            )
        if args.context_parallel_size > 1 and args.context_parallel_algo not in self._SUPPORTED_CP_ALGOS:
            supported = ", ".join(f"`{algo}`" for algo in self._SUPPORTED_CP_ALGOS)
            raise ValueError(
                f"--enable-mhc currently only supports {supported} "
                "when context parallel is used."
            )