from argparse import Namespace
from mindspeed.features_manager.feature import MindSpeedFeature
from mindspeed.patch_utils import is_megatron_training_available, MindSpeedPatchesManager
class OptimizeP2PCommFeature(MindSpeedFeature):
    """Feature that wraps Megatron's transformer-config builder for P2P comm.

    When its conditions are met, ``register_patches`` registers
    ``core_transformer_config_from_args_wrapper`` as a patch for
    ``megatron.training.arguments.core_transformer_config_from_args``.
    """

    def __init__(
        self,
        feature_name: str = "pipeline-model-parallel-size",
        optimization_level: int = 2,
    ):
        """Pass the feature name and optimization level on to the base class."""
        super().__init__(feature_name, optimization_level)

    def register_patches(
        self,
        patch_manager: MindSpeedPatchesManager,
        args: Namespace,
    ):
        """Register the config-builder wrapper when every condition holds.

        The patch is registered only if all of the following are true:

        * the attribute of ``args`` named by ``self.feature_name`` is truthy
          and ``args.pipeline_model_parallel_size`` is at least 2;
        * ``num_layers_per_virtual_pipeline_stage`` is absent or ``None``;
        * ``pipeline_model_parallel_layout`` and
          ``virtual_pipeline_model_parallel_size`` are not both set;
        * ``is_megatron_training_available()`` returns a truthy value.

        Args:
            patch_manager: Manager that receives the patch registration.
            args: Parsed arguments inspected for the conditions above.
        """
        # Function-scope import: it runs on every call, whether or not a
        # patch ends up being registered.
        from mindspeed.core.pipeline_parallel.optimize_p2p_comm.adaptor import (
            core_transformer_config_from_args_wrapper,
        )

        # NOTE(review): the default feature name is hyphenated but is used as
        # an attribute name here; presumably the base class normalizes it to
        # underscores -- confirm against MindSpeedFeature.
        if not getattr(args, self.feature_name, None):
            return
        if int(args.pipeline_model_parallel_size) < 2:
            return

        # An explicit per-stage layer count disables this patch.
        if getattr(args, "num_layers_per_virtual_pipeline_stage", None) is not None:
            return

        # So does a layout combined with a virtual pipeline size.
        has_layout = getattr(args, "pipeline_model_parallel_layout", None) is not None
        if has_layout and getattr(args, "virtual_pipeline_model_parallel_size", None) is not None:
            return

        if is_megatron_training_available():
            patch_manager.register_patch(
                "megatron.training.arguments.core_transformer_config_from_args",
                core_transformer_config_from_args_wrapper,
            )