import uuid
from typing import Callable, List, Optional
from functools import wraps, partial
import torch
from megatron.training import get_args
from mindspeed.core.memory.compress.compress_optimizer import compress_optimizer_step_impl
from mindspeed.core.memory.compress.compress_activation import CompressHook, GlobalContext, GlobalContextConfig
def compress_optimizer_step(self, closure=None):
    """Run the compressed optimizer step with autograd recording disabled.

    Delegates to ``compress_optimizer_step_impl`` and hands its return
    value straight back to the caller.

    Args:
        self: The optimizer instance this function is bound to.
        closure: Optional callable forwarded unchanged to the implementation.

    Returns:
        Whatever ``compress_optimizer_step_impl`` returns.
    """
    with torch.no_grad():
        return compress_optimizer_step_impl(self, closure)
def layer_forward_wrapper(forward) -> Callable:
    """Main entrance of adaptive activation compression.

    Decorates a layer's ``forward`` method. Layers whose ``layer_number`` is
    not listed in ``get_args().compress_activation`` are executed unchanged.
    For the listed layers, the forward pass runs inside a ``CompressHook``
    and a gradient hook is registered on the first output so that
    ``backward_start`` is invoked when that output's gradient is computed.

    Args:
        forward: The original ``forward(self, *args, **kwargs)`` method.

    Returns:
        The wrapped method, with the metadata of ``forward`` preserved.

    Raises:
        RuntimeError: (from the wrapper) if the wrapped forward returns
            ``None`` or its first output has no ``grad_fn``.
    """
    @wraps(forward)
    def wrapper(self, *args, **kwargs):
        # Layers that are not configured for compression take the plain path.
        if self.layer_number not in get_args().compress_activation:
            return forward(self, *args, **kwargs)

        ctx = get_global_context()
        ctx.statistic = get_statistic()
        absolute_order = get_absolute_order(ctx)

        # Assign a identifier to the layer the first time it is seen; it is
        # reused on every later call of the same layer object.
        if not hasattr(self, "layer_uuid"):
            self.layer_uuid = uuid.uuid1()
        order_layer_uuid = (self.layer_uuid, absolute_order)

        ctx.pack_start(order_layer_uuid)
        with CompressHook(order_layer_uuid, ctx):
            result = forward(self, *args, **kwargs)

        if result is None or result[0].grad_fn is None:
            # The trailing space keeps the two sentences separated after
            # implicit string-literal concatenation.
            raise RuntimeError(
                "Result is None or grad_fn is missing. "
                "Ensure the model is in training mode and gradients are enabled."
            )

        # The hook fires when the gradient of the first output is computed
        # and triggers ``unpack_start`` for this (layer, order) key.
        result[0].register_hook(partial(backward_start, order_layer_uuid))
        return result

    return wrapper
def backward_start(order_layer_uuid, *unused) -> None:
    """Notify the global context that unpacking starts for the given key.

    Intended to be used as a tensor gradient hook (see
    ``layer_forward_wrapper``); any extra positional arguments supplied by
    the hook machinery are ignored.

    Args:
        order_layer_uuid: The key previously passed to ``pack_start``.
        *unused: Ignored.
    """
    context = get_global_context()
    context.unpack_start(order_layer_uuid)
def get_global_context() -> GlobalContext:
    """Return the shared ``GlobalContext``, creating it on first use.

    The instance is cached as the ``global_context`` attribute of the object
    returned by ``get_args()``, so later calls return the same instance.
    """
    args = get_args()
    if hasattr(args, "global_context"):
        return args.global_context

    # First call: build the configuration and cache the new context on args.
    config = GlobalContextConfig(
        get_moe_average_token_num(),
        args.num_experts,
        [filter_func],
        ["matmul", "allgather", "all2all"],
    )
    args.global_context = GlobalContext(config)
    return args.global_context
def get_statistic() -> bool:
    """Decide whether the PDF statistics used for encoding must be re-collected.

    The decision is based on ``curr_iteration - iteration`` read from
    ``get_args()`` (presumably the number of steps since training started
    or resumed -- confirm against where these fields are set).

    Returns:
        ``True`` when that difference is at most 3 or is a multiple of 100,
        otherwise ``False``.
    """
    args = get_args()
    increment = args.curr_iteration - args.iteration
    return increment <= 3 or increment % 100 == 0
def get_moe_average_token_num() -> int:
    """Compute the average per-card token count for the MoE model.

    The value is the product of ``micro_batch_size``, ``seq_length``,
    ``tensor_model_parallel_size`` and ``moe_router_topk`` taken from
    ``get_args()``.
    """
    cfg = get_args()
    tokens_per_micro_batch = cfg.micro_batch_size * cfg.seq_length
    tokens_across_tp = tokens_per_micro_batch * cfg.tensor_model_parallel_size
    return tokens_across_tp * cfg.moe_router_topk
def get_absolute_order(ctx: GlobalContext) -> int:
    """Ask the context for the absolute order of the current layer call.

    Args:
        ctx: The global compression context.

    Returns:
        The value of ``ctx.get_absolute_order(is_first_step)``, where
        ``is_first_step`` is true when ``curr_iteration`` equals
        ``iteration`` in ``get_args()``.
    """
    args = get_args()
    elapsed = args.curr_iteration - args.iteration
    return ctx.get_absolute_order(elapsed == 0)
def filter_func(tensor: torch.Tensor) -> bool:
    """Tell whether a tensor is eligible for compression.

    A tensor is rejected (``False``) when any of the following holds:

    * it has no ``grad_fn``;
    * its dtype is ``torch.float32``;
    * its element count is below 32768 or not a multiple of 64;
    * its storage is empty, it starts at a non-zero storage offset, or its
      storage holds a different number of elements than the tensor itself;
    * the string form of its ``grad_fn`` contains ``"CheckpointWithout"``.

    Args:
        tensor: Candidate tensor.

    Returns:
        ``True`` if the tensor passes every check, otherwise ``False``.
    """
    producer = tensor.grad_fn
    if producer is None:
        return False

    if tensor.dtype == torch.float32:
        return False

    element_count = tensor.numel()
    # ``>= 32768`` already rules out the empty tensor.
    if not (element_count >= 32768 and element_count % 64 == 0):
        return False

    storage_len = tensor.storage().size()
    if storage_len == 0:
        return False
    if tensor.storage_offset() != 0 or storage_len != element_count:
        return False

    return "CheckpointWithout" not in str(producer)