from functools import wraps
from typing import List
import torch
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.optimizer.optimizer_config import OptimizerConfig
from megatron.core.pipeline_parallel.schedules import custom_backward
from megatron.core.transformer.module import MegatronModule
from mindspeed.core.performance.auto_pipeline_perf.schedules import (
backward_step_decorator,
)
from mindspeed.megatron_adaptor import get_mindspeed_args
from mindspeed.patch_utils import MindSpeedPatchesManager as mspm
def _get_megatron_optimizer_based_on_param_groups_wrapper(fn):
@wraps(fn)
def wrapper(
config: OptimizerConfig,
model_chunks: List[MegatronModule],
param_groups: List,
*args,
**kwargs
):
if len(param_groups) == 0:
param_groups = [
{
"params": torch.nn.parameter.Parameter(
data=torch.Tensor(), requires_grad=True
),
"is_expert_parallel": False,
"is_decoupled_lr": False,
}
]
return fn(config, model_chunks, param_groups, *args, **kwargs)
return wrapper
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage).
"""
if config.timers is not None:
config.timers("backward-compute", log_level=2).start()
unwrap_input_tensor_grad = False
if not isinstance(input_tensor, list):
input_tensor = [input_tensor]
unwrap_input_tensor_grad = True
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
if output_tensor_grad[0] is None and config.grad_scale_func is not None:
output_tensor[0] = config.grad_scale_func(output_tensor[0])
if output_tensor[0].grad_fn is not None:
if config.deallocate_pipeline_outputs:
custom_backward(output_tensor[0], output_tensor_grad[0])
else:
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
input_tensor_grad = [None]
if input_tensor is not None:
input_tensor_grad = []
for x in input_tensor:
if x is None:
input_tensor_grad.append(None)
else:
input_tensor_grad.append(x.grad)
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_stage_after_split()
and model_type == ModelType.encoder_and_decoder
):
if output_tensor_grad[1] is not None:
input_tensor_grad[-1].add_(output_tensor_grad[1])
if unwrap_input_tensor_grad:
input_tensor_grad = input_tensor_grad[0]
if config.timers is not None:
config.timers("backward-compute").stop()
return input_tensor_grad
mindspeed_args = get_mindspeed_args()
if (
hasattr(mindspeed_args, "enable_dummy_optimizer")
and mindspeed_args.enable_dummy_optimizer
):
mspm.register_patch(
"megatron.core.optimizer._get_megatron_optimizer_based_on_param_groups",
_get_megatron_optimizer_based_on_param_groups_wrapper,
)
mspm.register_patch(
"megatron.core.pipeline_parallel.schedules.backward_step",
backward_step,
force_patch=True,
)
mspm.apply_patches()