import os
import importlib
from functools import lru_cache
from typing import Type, Any, Dict, List, Optional
from einops import rearrange
import torch
import torch.distributed
import numpy as np
from megatron.core import mpu
from megatron.training import get_args
from megatron.core.packed_seq_params import PackedSeqParams
from mindspeed.utils import get_actual_seq_len
def _ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def build_padded_lens_from_cu_seqlens(
    actual_seq_len: torch.Tensor,
    cp_size: int,
    pad_multiple: Optional[int] = None,
):
    """Derive per-sequence raw and padded lengths from cumulative sequence lengths.

    Args:
        actual_seq_len: 1D tensor of cumulative sequence end positions,
            e.g. ``[4, 9, 15]``.
        cp_size: Context-parallel size; only used to derive the default
            ``pad_multiple``.
        pad_multiple: Each length is rounded up to a multiple of this value.
            Defaults to ``2 * cp_size``.

    Returns:
        ``(raw_lens, padded_lens)``: two int64 tensors on the device of
        ``actual_seq_len``. For the example above ``raw_lens`` is ``[4, 5, 6]``
        and, with ``pad_multiple=4``, ``padded_lens`` is ``[4, 8, 8]``.

    Raises:
        ValueError: If the effective ``pad_multiple`` is not positive.
    """
    if pad_multiple is None:
        pad_multiple = 2 * cp_size
    if pad_multiple <= 0:
        raise ValueError(f"pad_multiple must be a positive integer, got {pad_multiple}")

    cu_seqlens = actual_seq_len.to(torch.int64)
    # Start offset of every sequence: 0 followed by all but the last cumulative end.
    starts = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens[:-1]])
    raw_lens = cu_seqlens - starts

    # Vectorised ceil-division keeps the computation on the tensor's device
    # instead of pulling every length to the host one element at a time.
    num_blocks = torch.div(raw_lens + (pad_multiple - 1), pad_multiple, rounding_mode="floor")
    padded_lens = num_blocks * pad_multiple
    return raw_lens, padded_lens
def get_packed_seq_len(actual_seq_len, cp_size):
    """Return the total packed length after padding each sequence to a multiple of ``2 * cp_size``.

    Args:
        actual_seq_len: 1D tensor of cumulative sequence end positions.
        cp_size: Context-parallel size.

    Returns:
        The last element of the cumulative sum of the padded lengths
        (a 0-dim tensor).
    """
    padded_lens = build_padded_lens_from_cu_seqlens(
        actual_seq_len,
        cp_size,
        2 * cp_size,
    )[1]
    running_total = torch.cumsum(padded_lens, dim=0)
    return running_total[-1]
def get_packed_seq_params(
    actual_seq_len: torch.Tensor,
    cp_size: int,
    pad_multiple: Optional[int] = None,
):
    """Constructs PackedSeqParams and shapes for ringattn_context_parallel (TND).

    Args:
        actual_seq_len: 1D tensor of cumulative sequence end positions; used
            unchanged as both ``cu_seqlens_q`` and ``cu_seqlens_kv``.
        cp_size: Context-parallel size.
        pad_multiple: Each sequence length is padded up to a multiple of this
            value. Defaults to ``2 * cp_size``.

    Returns:
        ``(packed_seq_params, shapes)`` where ``shapes`` is a list holding
        ``cp_size`` copies of the per-rank token count, i.e. the sum over
        sequences of ``padded_len // cp_size``.
    """
    if pad_multiple is None:
        pad_multiple = 2 * cp_size

    raw_lens, padded_lens = build_padded_lens_from_cu_seqlens(
        actual_seq_len=actual_seq_len,
        cp_size=cp_size,
        pad_multiple=pad_multiple,
    )
    cu_seqlens_padded = torch.cumsum(padded_lens, dim=0)
    # The maximum is taken over the unpadded lengths; 0 when there are no sequences.
    max_seqlen = int(raw_lens.max().item()) if raw_lens.numel() > 0 else 0

    packed_seq_params = PackedSeqParams(
        qkv_format='thd',
        cu_seqlens_q=actual_seq_len,
        cu_seqlens_kv=actual_seq_len,
    )
    packed_seq_params.cu_seqlens_q_padded = cu_seqlens_padded
    packed_seq_params.cu_seqlens_kv_padded = cu_seqlens_padded
    packed_seq_params.max_seqlen_q = max_seqlen
    packed_seq_params.max_seqlen_kv = max_seqlen
    packed_seq_params.q_index = None
    packed_seq_params.kv_index = None

    local_total_len = int((padded_lens // cp_size).sum().item())
    shapes = [local_total_len] * cp_size
    return packed_seq_params, shapes
class Registry:
    """Class registry keyed by class name.

    Classes are stored under their ``__name__``; callers cannot choose a
    different key, and registering a second class with an already-used name
    is rejected.
    """

    # Maps each registered class name to its class object.
    _REGISTRY: Dict[str, Type[Any]] = {}

    @classmethod
    def register(cls, target_class: Type[Any]) -> Type[Any]:
        """Register ``target_class`` under its own name; usable as a class decorator.

        Args:
            target_class: The class to add to the registry.

        Returns:
            ``target_class`` itself, so the decorated definition is preserved.

        Raises:
            ValueError: If a class with the same name is already registered.
        """
        key = target_class.__name__
        if key not in cls._REGISTRY:
            cls._REGISTRY[key] = target_class
            return target_class

        previous = cls._REGISTRY[key]
        raise ValueError(
            f"Class name conflict: '{key}' already registered by {previous}, "
            f"attempting to register: {target_class}"
        )

    @classmethod
    def get_class(cls, name: str) -> Type[Any]:
        """Look up a previously registered class.

        Args:
            name: Name the class was registered under.

        Returns:
            The registered class object.

        Raises:
            ValueError: If no class is registered under ``name``.
        """
        if name in cls._REGISTRY:
            return cls._REGISTRY[name]

        known_names = list(cls._REGISTRY)
        raise ValueError(
            f"Class '{name}' not found in registry. Available classes: {known_names}"
        )
@lru_cache
def is_npu_available():
    """Checks if `torch_npu` is installed and potentially if a NPU is in the environment.

    The result is cached, so the probe runs at most once per process.

    Returns:
        ``False`` when ``torch_npu`` cannot be found or when querying the NPU
        runtime raises ``RuntimeError``; otherwise the value of
        ``torch.npu.is_available()``.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` submodule
    # is loaded, so import it explicitly before calling ``find_spec``.
    import importlib.util

    if importlib.util.find_spec("torch_npu") is None:
        return False

    # Imported before ``torch.npu`` is touched; presumably this import is what
    # makes the ``torch.npu`` namespace available -- verify against torch_npu docs.
    import torch_npu  # noqa: F401

    try:
        _ = torch.npu.device_count()
        return torch.npu.is_available()
    except RuntimeError:
        return False
def get_device(device="npu"):
    """Resolve a device specification into a ``torch.device``.

    Only support npu and cpu device, default npu.
    Device format: ``cpu``, ``npu``, or ``npu:0`` (case and surrounding
    whitespace are ignored).

    Args:
        device: A ``torch.device`` (returned unchanged) or a device string.

    Returns:
        The corresponding ``torch.device``.

    Raises:
        ValueError: If the string is not a supported format, the index is not
            an integer, or the index is outside ``[0, device_count)``.
        RuntimeError: If an NPU device is requested but no NPU is available.
    """
    if isinstance(device, torch.device):
        return device

    device = device.lower().strip()
    if device == "cpu":
        return torch.device(device)

    format_error = "only support npu and cpu device. device format: cpu, npu, or npu:0"
    device_name, *index_parts = device.split(":")
    if device_name != "npu":
        raise ValueError(format_error)
    if not is_npu_available():
        raise RuntimeError("NPU environment is not available")

    if not index_parts:
        return torch.device(device_name)
    if len(index_parts) > 1:
        raise ValueError(format_error)

    try:
        device_id = int(index_parts[0])
    except ValueError as err:
        raise ValueError(f"invalid npu device index: {index_parts[0]!r}. {format_error}") from err

    num_devices = torch.npu.device_count()
    # A negative index would pass a plain ``< num_devices`` check, so reject it explicitly.
    if device_id < 0:
        raise ValueError(f"device_id: {device_id} must be non-negative")
    if device_id >= num_devices:
        raise ValueError(f"device_id: {device_id} must less than device nums: {num_devices}")
    return torch.device(f"{device_name}:{device_id}")
def get_dtype(dtype):
    """Translate a dtype name into the matching ``torch.dtype``.

    Args:
        dtype: Either a ``torch.dtype`` (returned unchanged) or one of the
            names ``int32``, ``float64``, ``float32``, ``float16``, ``fp32``,
            ``fp16``, ``half``, ``bf16``.

    Returns:
        The resolved ``torch.dtype``.

    Raises:
        ValueError: If the name is not one of the supported names.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    names_to_dtypes = {
        "int32": torch.int32,
        "float64": torch.float64,
        "float32": torch.float32,
        "fp32": torch.float32,
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bf16": torch.bfloat16,
    }
    resolved = names_to_dtypes.get(dtype)
    if resolved is None:
        raise ValueError("Unsupported data type")
    return resolved
def video_to_image(func):
    """Decorator that lets a method written for 4D image batches accept 5D video batches.

    When the first tensor argument ``x`` has 5 dimensions (laid out as
    ``b c t h w`` by the rearrange pattern below), the ``t`` axis is folded
    into the batch axis before calling ``func`` and unfolded again on the
    result. Any other input is passed straight through.

    Args:
        func: A method with signature ``func(self, x, *args, **kwargs)``.

    Returns:
        The wrapped method, preserving ``func``'s name and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, x, *args, **kwargs):
        if x.dim() != 5:
            return func(self, x, *args, **kwargs)

        # Remember the frame count so the output can be unfolded afterwards.
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = func(self, x, *args, **kwargs)
        return rearrange(x, "(b t) c h w -> b c t h w", t=t)

    return wrapper
def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is a tuple or list, else a tuple repeating ``t`` ``length`` times."""
    if isinstance(t, (tuple, list)):
        return t
    return (t,) * length
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    """Return ``x * sigmoid(1.702 * x)`` computed element-wise."""
    gate = torch.sigmoid(1.702 * x)
    return x * gate
def gelu_tanh(inp: torch.Tensor) -> torch.Tensor:
    """Apply ``torch.nn.functional.gelu`` to ``inp`` with ``approximate="tanh"``."""
    gelu = torch.nn.functional.gelu
    return gelu(inp, approximate="tanh")
def set_modules_requires_grad(modules, requires_grad):
    """Call ``requires_grad_(requires_grad)`` on every entry of ``modules``.

    Args:
        modules: Iterable of objects exposing ``requires_grad_``.
        requires_grad: Flag forwarded to each ``requires_grad_`` call.
    """
    for layer in modules:
        layer.requires_grad_(requires_grad)
def save_ae_checkpoint(
    epoch,
    current_step,
    optimizer_state,
    state_dict,
    scaler_state,
    sampler_state,
    checkpoint_dir,
    filename="checkpoint.ckpt",
    ema_state_dict=None,
):
    """Serialize training state into ``checkpoint_dir/filename`` with ``torch.save``.

    Args:
        epoch: Stored under the key ``"epoch"``.
        current_step: Stored under the key ``"current_step"``.
        optimizer_state: Stored under the key ``"optimizer_state"``.
        state_dict: Stored under the key ``"state_dict"``.
        scaler_state: Stored under the key ``"scaler_state"``.
        sampler_state: Stored under the key ``"sampler_state"``.
        checkpoint_dir: Target directory; created (with parents) if missing.
        filename: File name inside ``checkpoint_dir``.
        ema_state_dict: Stored under the key ``"ema_state_dict"``.

    Returns:
        The path of the written checkpoint file.
    """
    # ``exist_ok=True`` avoids the check-then-create race when several
    # processes try to create the same directory at once.
    os.makedirs(checkpoint_dir, exist_ok=True)
    filepath = os.path.join(checkpoint_dir, filename)
    checkpoint = {
        "epoch": epoch,
        "current_step": current_step,
        "optimizer_state": optimizer_state,
        "state_dict": state_dict,
        "ema_state_dict": ema_state_dict,
        "scaler_state": scaler_state,
        "sampler_state": sampler_state,
    }
    torch.save(checkpoint, filepath)
    return filepath
# Process group used by this rank for context parallelism; stays None until
# set_context_parallel_group() or initialize_context_parallel() assigns it.
_CONTEXT_PARALLEL_GROUP = None
# Context-parallel size recorded by the same two functions; None until then.
_CONTEXT_PARALLEL_SIZE = None
def is_context_parallel_initialized():
    """Return ``True`` once a context-parallel group has been stored, else ``False``."""
    return _CONTEXT_PARALLEL_GROUP is not None
def set_context_parallel_group(size, group):
    """Store an externally created context-parallel group and its size in the module globals.

    Args:
        size: Value recorded as the context-parallel size.
        group: Value recorded as the context-parallel group.
    """
    global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_SIZE
    _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_SIZE = group, size
def initialize_context_parallel(context_parallel_size):
    """Create the context-parallel process groups and remember the one this rank belongs to.

    Consecutive global ranks are grouped together: ranks
    ``[0, context_parallel_size)`` form the first group, the next
    ``context_parallel_size`` ranks the second, and so on.

    Args:
        context_parallel_size: Number of ranks per context-parallel group.

    Raises:
        AssertionError: If a context-parallel group is already initialized.
        ValueError: If ``context_parallel_size`` is not positive or does not
            divide the world size.
    """
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_SIZE
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise AssertionError("Context parallel group is already initialized")

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    if context_parallel_size <= 0 or world_size % context_parallel_size != 0:
        raise ValueError(
            f"context_parallel_size ({context_parallel_size}) must be a positive "
            f"divisor of the world size ({world_size})"
        )

    own_group = None
    for first_rank in range(0, world_size, context_parallel_size):
        ranks = list(range(first_rank, first_rank + context_parallel_size))
        # Every process must create every group, in the same order, even the
        # groups it is not a member of -- so do not stop after finding our own.
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            own_group = group

    _CONTEXT_PARALLEL_SIZE = context_parallel_size
    _CONTEXT_PARALLEL_GROUP = own_group
def get_context_parallel_group():
    """Return the stored context-parallel group.

    Raises:
        AssertionError: If no group has been stored yet.
    """
    group = _CONTEXT_PARALLEL_GROUP
    if group is not None:
        return group
    raise AssertionError("Context parallel group is not initialized")
def get_context_parallel_world_size():
    """Return the stored context-parallel size.

    Raises:
        AssertionError: If no size has been stored yet.
    """
    cp_size = _CONTEXT_PARALLEL_SIZE
    if cp_size is not None:
        return cp_size
    raise AssertionError("Context parallel size is not initialized")
def get_context_parallel_rank():
    """Return this process's position inside its context-parallel group.

    Computed as the global rank modulo the stored context-parallel size.

    Raises:
        AssertionError: If no context-parallel size has been stored yet.
    """
    cp_size = _CONTEXT_PARALLEL_SIZE
    if cp_size is None:
        raise AssertionError("Context parallel size is not initialized")
    return torch.distributed.get_rank() % cp_size
def get_context_parallel_group_rank():
    """Return the index of the context-parallel group this process belongs to.

    Computed as the global rank integer-divided by the stored
    context-parallel size.

    Raises:
        AssertionError: If no context-parallel size has been stored yet.
    """
    cp_size = _CONTEXT_PARALLEL_SIZE
    if cp_size is None:
        raise AssertionError("Context parallel size is not initialized")
    return torch.distributed.get_rank() // cp_size
class IsNotValidError(Exception):
    """Exception carrying an optional custom message.

    When no (or an empty) message is supplied, the text
    "Expression is not valid" is used instead.
    """

    def __init__(self, error_message=None):
        # Keep the caller's value as-is; the fallback is only applied when rendering.
        self.error_message = error_message
        shown = error_message if error_message else "Expression is not valid"
        super().__init__(shown)

    def __str__(self):
        if self.error_message:
            return self.error_message
        return "Expression is not valid"
def ensure_valid(expression, error_message=None):
    """Raise ``IsNotValidError(error_message)`` when ``expression`` is falsy; otherwise do nothing."""
    if expression:
        return
    raise IsNotValidError(error_message)
def dist_sort(image_num_list):
    """Plan how to even out per-rank image counts.

    Every rank ends up with either ``total // world_size`` images or one
    more. The extra images are first assigned to ranks that already hold more
    than the average, then to the ranks with the fewest images.

    Args:
        image_num_list: 1D sequence or array with one image count per rank.
            It is not modified.

    Returns:
        ``(transfer, target)`` where ``transfer`` is a
        ``(world_size, world_size)`` integer array whose entry ``[s][d]`` is
        the number of images rank ``s`` gives to rank ``d``, and ``target`` is
        a list with the final image count of every rank.
    """
    # Work on a private int64 copy: the caller's array stays untouched and
    # narrow input dtypes (e.g. int8) cannot overflow while summing.
    counts = np.array(image_num_list, dtype=np.int64).reshape(-1)
    world_size = len(counts)
    avg, remainder = divmod(int(counts.sum()), world_size)
    more_rank = avg + 1

    target = [avg] * world_size
    # Ranks already above the average keep one of the leftover images.
    for rank_idx in range(world_size):
        if remainder and counts[rank_idx] > avg:
            target[rank_idx] = more_rank
            remainder -= 1
    # Any leftover still unassigned goes to the ranks with the fewest images;
    # a stable sort makes tie-breaking deterministic.
    for rank_idx in counts.argsort(kind="stable")[:remainder]:
        target[rank_idx] = more_rank

    transfer = np.zeros((world_size, world_size), dtype=int)
    surplus = [i for i in range(world_size) if counts[i] > target[i]]
    deficit = [i for i in range(world_size) if counts[i] < target[i]]
    while surplus and deficit:
        s = surplus[-1]
        d = deficit[-1]
        # Move as much as the giver can spare or the receiver still needs.
        give = min(counts[s] - target[s], target[d] - counts[d])
        counts[s] -= give
        counts[d] += give
        transfer[s][d] += give
        if counts[s] == target[s]:
            surplus.pop()
        if counts[d] == target[d]:
            deficit.pop()
    return transfer, target
def unwrap_single(x: list):
    """Strip nested single-element lists, e.g. ``[[[1]]] -> 1``.

    Anything that is not a list of length one (including tuples and empty
    lists) is returned unchanged.
    """
    while True:
        if not isinstance(x, list) or len(x) != 1:
            return x
        (x,) = x
class EncoderBalanceComm(torch.autograd.Function):
    """Redistribute rows of a tensor across the ranks of a group so each rank holds a balanced share.

    The leading dimension of ``input_tensor`` is treated as the number of
    images held by this rank. A transfer plan from ``dist_sort`` decides how
    many trailing rows each rank sends to every other rank via ``all_to_all``.
    """

    @staticmethod
    def forward(ctx, input_tensor, group, transfer=None, nopadding=False, skip=False):
        """Balance ``input_tensor`` across ``group``.

        Args:
            input_tensor: Tensor whose dim 0 indexes the local images.
            group: Process group to balance over.
            transfer: ``None`` to compute a fresh plan from the gathered
                per-rank counts, or a ``[transfer_matrix, target]`` pair (as
                returned by an earlier call) to reuse.
            nopadding: With a supplied plan, keep ``input_tensor`` as is
                instead of truncating it to ``target[rank]`` rows.
            skip: Return the input and the plan without moving any data.

        Returns:
            The balanced tensor alone when a plan was supplied (unless
            ``skip`` is set), otherwise ``(tensor, [transposed transfer, target])``.
        """
        # "no_bk" is set exactly when this call computes its own plan.
        ctx.no_bk = transfer is None
        rank = torch.distributed.get_rank(group=group)
        ctx.shape = list(input_tensor.shape)
        if transfer is not None:
            transfer, target = transfer
            input_tensor = input_tensor[:target[rank]].contiguous() if not nopadding else input_tensor
        image_shape = input_tensor.shape
        # NOTE(review): ctx.shape is not read anywhere else in this class.
        ctx.shape[1] -= input_tensor.shape[1]
        image_num = image_shape[0]
        ishape = image_shape[1:]
        world_size = torch.distributed.get_world_size(group)
        ctx.group = group
        ctx.rank = rank
        ctx.world_size = world_size
        if transfer is None:
            # int32 instead of int8: a rank may hold more than 127 images.
            shape_input = torch.tensor([image_num], dtype=torch.int32).cuda()
            shape_output = torch.empty([world_size, *shape_input.shape], dtype=shape_input.dtype).cuda()
            torch.distributed._all_gather_base(shape_output, shape_input, group=group)
            image_num_list = shape_output.cpu().numpy().reshape(-1)
            transfer, target = dist_sort(image_num_list)
        # The transposed plan is what backward feeds into the reverse exchange.
        ctx.transfer = [transfer.T, target]
        if skip:
            return input_tensor, [transfer.T, target]
        if np.sum(transfer) == 0:
            # Nothing to move: already balanced.
            if ctx.no_bk:
                return input_tensor, [transfer.T, target]
            else:
                return input_tensor
        # Row ``rank`` of the plan lists how many images this rank sends to each peer;
        # the rows to send are taken from the tail of the local tensor.
        send_img_num = sum(transfer[rank])
        send_img = list(
            torch.split(
                input_tensor[image_num - send_img_num:].contiguous(),
                transfer[rank].tolist(),
                dim=0)
        )
        output = input_tensor[:image_num - send_img_num]
        # After transposing, row ``rank`` lists how many images this rank receives from each peer.
        transfer = transfer.T
        recv = torch.empty_like(input_tensor).resize_([sum(transfer[rank]), *ishape])
        recv = list(torch.split(recv, transfer[rank].tolist(), dim=0))
        torch.distributed.all_to_all(recv, send_img, group=group)
        recv = torch.cat([output] + recv, dim=0)
        if not ctx.no_bk:
            return recv
        return recv, [transfer, target]

    @staticmethod
    def backward(ctx, grad_output, *unused_grads):
        """Route gradients back with the transposed plan.

        ``forward`` may return a second, non-tensor output (the plan);
        autograd then passes one extra gradient argument for it, which is
        accepted and ignored through ``*unused_grads``.
        """
        if ctx.no_bk or np.sum(ctx.transfer[0]) == 0:
            return grad_output, None, None, None, None
        else:
            data = EncoderBalanceComm.apply(grad_output, ctx.group, ctx.transfer, True)
            return data, None, None, None, None
# einops rearrange patterns keyed by (source layout, destination layout).
_LAYOUT_PATTERNS = {
    ("bsh", "sbh"): "b s h -> s b h",
    ("sbnd", "sbh"): "s b n d -> s b (n d)",
    ("sbnd", "bsnd"): "s b n d -> b s n d",
    ("sbnd", "bnsd"): "s b n d -> b n s d",
    ("sbnd", "tnd"): "s b n d -> (s b) n d",
    ("bsnd", "sbh"): "b s n d -> s b (n d)",
    ("bnsd", "sbh"): "b n s d -> s b (n d)",
    ("tnd", "sbh"): "(s b) n d -> s b (n d)",
    ("sbh", "bsh"): "s b h -> b s h",
    ("bsnd", "bsh"): "b s n d -> b s (n d)",
    ("bnsd", "bsh"): "b n s d -> b s (n d)",
    ("tnd", "bsh"): "(s b) n d -> b s (n d)",
}


def change_tensor_layout(tensor, src_layout, dst_layout, batch_size=None):
    """
    Transforms the input tensor from the source layout (src_layout) to the target layout (dst_layout).

    Args:
        tensor (torch.Tensor | list | tuple): The input tensor, or a list/tuple
            of tensors that are each converted.
        src_layout (str): The source layout, e.g., "sbh" or "bsh" (case-insensitive).
        dst_layout (str): The target layout, e.g., "sbnd" or "tnd" (case-insensitive).
        batch_size (int, optional): Size of the ``b`` axis; required when
            ``src_layout`` is "tnd", because the flattened ``(s b)`` axis
            cannot be split without it.

    Returns:
        torch.Tensor | list: The tensor with the transformed layout, or a list
        of transformed tensors for list/tuple input. When both layouts are
        equal the input is returned unchanged.

    Raises:
        ValueError: If the conversion is not supported, the input type is not
            supported, or ``batch_size`` is missing for a "tnd" source.
    """
    src_layout = src_layout.lower()
    dst_layout = dst_layout.lower()
    if src_layout == dst_layout:
        return tensor

    pattern = _LAYOUT_PATTERNS.get((src_layout, dst_layout))
    if pattern is None:
        raise ValueError(f"Unsupported layout conversion from {src_layout} to {dst_layout}!")

    is_single = isinstance(tensor, torch.Tensor)
    if not is_single and not isinstance(tensor, (list, tuple)):
        raise ValueError(f"Unsupported input type {type(tensor)}")

    axis_lengths = {}
    if src_layout == "tnd":
        # Fail early with a clear message instead of an opaque einops error.
        if batch_size is None:
            raise ValueError(f"batch_size is required to convert from {src_layout} to {dst_layout}")
        axis_lengths["b"] = batch_size

    if is_single:
        return rearrange(tensor, pattern, **axis_lengths)
    return [rearrange(t, pattern, **axis_lengths) for t in tensor]
def reorder_output(attn_output, cp_rank, cp_size, cp_group, dim=0):
    """Restore natural chunk order for a tensor gathered across context-parallel ranks.

    Each rank owns chunk ``cp_rank`` and chunk ``2 * cp_size - cp_rank - 1``.
    The chunk ids of all ranks are gathered, ``attn_output`` is cut into that
    many equal chunks along ``dim``, and the chunks are concatenated in
    ascending chunk-id order.

    Args:
        attn_output: Tensor holding the gathered chunks in rank order.
        cp_rank: Rank of this process inside ``cp_group``.
        cp_size: Number of ranks in ``cp_group``.
        cp_group: Process group used for the index all-gather.
        dim: Dimension along which the chunks are laid out.

    Returns:
        The tensor with its chunks in ascending chunk-id order along ``dim``.
    """
    # int32 rather than int8: chunk ids go up to 2 * cp_size - 1, which does
    # not fit in int8 once cp_size exceeds 64.
    index_this_rank = torch.tensor(
        [cp_rank, (2 * cp_size - cp_rank - 1)],
        dtype=torch.int32,
        device=attn_output.device,
    )
    index_list = [torch.zeros_like(index_this_rank) for _ in range(cp_size)]
    torch.distributed.all_gather(index_list, index_this_rank, group=cp_group)

    # gathered_order[p] is the chunk id currently stored at position p.
    gathered_order = torch.concat(index_list).tolist()
    position_of = {chunk_id: position for position, chunk_id in enumerate(gathered_order)}

    chunks = torch.chunk(attn_output, chunks=len(gathered_order), dim=dim)
    reordered_chunks = [chunks[position_of[chunk_id]] for chunk_id in range(len(gathered_order))]
    return torch.concat(reordered_chunks, dim=dim)
def _gather(
    input_: torch.Tensor,
    pg: torch.distributed.ProcessGroup,
    dim: int = -1,
    gather_size: List = None
):
    """All-gather ``input_`` over ``pg`` and concatenate the pieces along ``dim``.

    Args:
        input_: Local tensor; made contiguous before communication.
        pg: Process group to gather over.
        dim: Dimension along which the gathered pieces are concatenated.
        gather_size: Optional per-rank sizes of ``dim`` for the case where
            ranks contribute differently sized tensors.

    Returns:
        ``input_`` itself (contiguous) when the group has a single rank,
        otherwise the contiguous concatenation of all ranks' tensors.

    Raises:
        AssertionError: If ``input_`` is not on a cpu or npu device.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group=pg)

    device_type = input_.device.type
    if not (device_type == "cpu" or device_type == "npu"):
        raise AssertionError(f"Only support cpu and npu device, got {input_.device}")
    if world_size == 1:
        return input_

    if gather_size is None:
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        base_shape = input_.size()

        def _receive_buffer(size_along_dim):
            # Same shape as the local tensor except for the gathered dimension.
            shape = list(base_shape)
            shape[dim] = size_along_dim
            return torch.empty(shape, dtype=input_.dtype, device=input_.device)

        tensor_list = [_receive_buffer(gather_size[peer]) for peer in range(world_size)]

    torch.distributed.all_gather(tensor_list, input_, group=pg)
    return torch.cat(tensor_list, dim=dim).contiguous()
class _SplitForwardGatherBackWardWithMegatronCP(torch.autograd.Function):
'''
Split the input tensor in the forward pass and gather the gradients in the backward pass.
It will be implemented in Mindspeed in the future.
'''
@staticmethod
def forward(ctx, val, cp_rank, cp_size, seq_dim, cp_group=None):
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1):],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=val.device)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2):])
ctx.cp_group = cp_group
ctx.cp_rank = cp_rank
ctx.cp_size = cp_size
ctx.seq_dim = seq_dim
return val
@staticmethod
def backward(ctx, grad_output):
grad_input = {}
grad_input = _gather(grad_output, ctx.cp_group, dim=ctx.seq_dim) / ctx.cp_size
grad_input = reorder_output(grad_input, ctx.cp_rank, ctx.cp_size, ctx.cp_group, dim=ctx.seq_dim)
return grad_input, None, None, None, None
def split_forward_gather_backward_with_megatron_cp(
    input_: torch.Tensor,
    process_group: torch.distributed.ProcessGroup,
    dim: int = 0
) -> torch.Tensor:
    """Split ``input_`` along ``dim`` for this rank of ``process_group``.

    Looks up the group size and this process's rank in the group, then
    delegates to ``_SplitForwardGatherBackWardWithMegatronCP``.

    Args:
        input_: Full tensor to split.
        process_group: Context-parallel process group.
        dim: Sequence dimension to split along.

    Returns:
        The part of ``input_`` assigned to this rank.
    """
    group_size = torch.distributed.get_world_size(group=process_group)
    rank_in_group = torch.distributed.get_rank(group=process_group)
    split_fn = _SplitForwardGatherBackWardWithMegatronCP.apply
    return split_fn(input_, rank_in_group, group_size, dim, process_group)
class _GatherForwardSplitBackWardWithMegatronCP(torch.autograd.Function):
'''
Split the input tensor in the forward pass and gather the gradients in the backward pass with megatron cp(Ring Attention)
It will be implemented in Mindspeed in the future.
'''
@staticmethod
def forward(ctx, val, seq_dim, cp_group=None):
cp_rank = torch.distributed.get_rank(group=cp_group)
cp_size = torch.distributed.get_world_size(group=cp_group)
val = _gather(val, cp_group, dim=seq_dim)
val = reorder_output(val, cp_rank, cp_size, cp_group, dim=seq_dim)
ctx.cp_group = cp_group
ctx.cp_rank = cp_rank
ctx.cp_size = cp_size
ctx.seq_dim = seq_dim
return val
@staticmethod
def backward(ctx, grad_output):
cp_group = ctx.cp_group
cp_rank = ctx.cp_rank
cp_size = ctx.cp_size
seq_dim = ctx.seq_dim
grad_output = grad_output.view(
*grad_output.shape[0:seq_dim],
2 * cp_size,
grad_output.shape[seq_dim] // (2 * cp_size),
*grad_output.shape[(seq_dim + 1):],
) * cp_size
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device=grad_output.device)
grad_output = grad_output.index_select(seq_dim, index)
grad_input = grad_output.view(*grad_output.shape[0:seq_dim], -1, *grad_output.shape[(seq_dim + 2):])
return grad_input, None, None
def gather_forward_split_backward_with_megatron_cp(
    input_: torch.Tensor,
    process_group: torch.distributed.ProcessGroup,
    dim: int = 0,
    pad_multiple=None
) -> torch.Tensor:
    """Gather ``input_`` across ``process_group`` along ``dim``.

    When ``get_actual_seq_len()`` returns a value, the packed (TND) variant
    is used and ``pad_multiple`` is forwarded to it; otherwise the regular
    variant is used and ``pad_multiple`` is ignored.

    Args:
        input_: This rank's part of the tensor.
        process_group: Context-parallel process group.
        dim: Sequence dimension to gather along.
        pad_multiple: Padding multiple for the packed variant.

    Returns:
        The gathered tensor.
    """
    packed_cu_seqlens = get_actual_seq_len()
    if packed_cu_seqlens is None:
        return _GatherForwardSplitBackWardWithMegatronCP.apply(input_, dim, process_group)
    return _GatherForwardSplitBackwardWithMegatronCPTND.apply(
        input_, dim, packed_cu_seqlens, pad_multiple, process_group
    )
def get_index(actual_seq_len_cpu, cp_rank, cp_size, device='npu'):
    """
    Build the token indices owned by one context-parallel rank.

    Every segment is cut into ``2 * cp_size`` chunks of
    ``segment_len // (2 * cp_size)`` tokens; the rank takes chunk ``cp_rank``
    counted from the segment start and chunk ``cp_rank`` counted from the
    segment end.

    Parameters:
        actual_seq_len_cpu: 1D tensor, cumulative end positions.
            For example, [4, 9, 15] indicates three segments:
            [0, 4), [4, 9), [9, 15)
        cp_rank: current rank
        cp_size: context parallel size
        device: device the returned index is moved to (default 'npu')
    Returns: index (1D int64 tensor) corresponding to the current rank; empty
        when there are no segments.
    """
    ends = actual_seq_len_cpu.to(torch.int64)
    if ends.numel() == 0:
        # torch.cat cannot concatenate an empty list, so answer directly.
        return torch.empty(0, dtype=torch.int64).to(device)

    starts = torch.cat([ends.new_zeros(1), ends[:-1]])
    chunk_sizes = torch.div(ends - starts, 2 * cp_size, rounding_mode="floor")
    first_starts = starts + cp_rank * chunk_sizes
    first_ends = first_starts + chunk_sizes
    second_starts = ends - (cp_rank + 1) * chunk_sizes
    second_ends = ends - cp_rank * chunk_sizes

    # Convert the bounds to Python ints once instead of indexing tensors per segment.
    bounds = zip(
        first_starts.tolist(),
        first_ends.tolist(),
        second_starts.tolist(),
        second_ends.tolist(),
    )
    all_indices = []
    for head_start, head_end, tail_start, tail_end in bounds:
        all_indices.append(torch.arange(head_start, head_end))
        all_indices.append(torch.arange(tail_start, tail_end))
    return torch.cat(all_indices).to(device)
def pad_input(input, raw_lens, padded_lens, dim=0, pad_val=0):
    """Spread packed segments out so that each one occupies its padded length.

    Args:
        input: Tensor holding the segments back to back along ``dim``.
        raw_lens: Length of every segment inside ``input``.
        padded_lens: Length reserved for every segment in the output.
        dim: Dimension along which the segments are laid out.
        pad_val: Fill value for the positions after each segment.

    Returns:
        A new tensor whose ``dim`` size is ``sum(padded_lens)``; each segment
        starts at its padded offset and the remainder holds ``pad_val``.
    """
    padded_shape = list(input.shape)
    padded_shape[dim] = sum(padded_lens)
    output = input.new_full(padded_shape, pad_val)
    ndim = input.dim()

    def _span(start, length):
        # Full slices everywhere except a [start, start + length) window on ``dim``.
        selector = [slice(None)] * ndim
        selector[dim] = slice(start, start + length)
        return tuple(selector)

    src_offset = 0
    dst_offset = 0
    for raw_len, padded_len in zip(raw_lens, padded_lens):
        output[_span(dst_offset, raw_len)] = input[_span(src_offset, raw_len)]
        src_offset += raw_len
        dst_offset += padded_len
    return output
def unpad_input(input, raw_lens, padded_lens, dim=0):
    """Pack padded segments back to back, dropping the padding after each one.

    Args:
        input: Tensor in which segment ``i`` starts at the sum of the previous
            ``padded_lens`` along ``dim``.
        raw_lens: Number of real entries to keep from every segment.
        padded_lens: Length every segment occupies inside ``input``.
        dim: Dimension along which the segments are laid out.

    Returns:
        A new tensor whose ``dim`` size is ``sum(raw_lens)``.
    """
    packed_shape = list(input.shape)
    packed_shape[dim] = sum(raw_lens)
    output = input.new_zeros(packed_shape)
    ndim = input.dim()

    def _span(start, length):
        # Full slices everywhere except a [start, start + length) window on ``dim``.
        selector = [slice(None)] * ndim
        selector[dim] = slice(start, start + length)
        return tuple(selector)

    src_offset = 0
    dst_offset = 0
    for raw_len, padded_len in zip(raw_lens, padded_lens):
        output[_span(dst_offset, raw_len)] = input[_span(src_offset, raw_len)]
        src_offset += padded_len
        dst_offset += raw_len
    return output
class _SplitForwardGatherBackWardWithMegatronCPTND(torch.autograd.Function):
    """Select this rank's tokens from a packed token stream; scatter gradients back in backward.

    Every sub-sequence is first padded to a multiple of ``pad_multiple``;
    ``get_index`` then picks the positions that belong to this rank.
    """

    @staticmethod
    def forward(ctx, val, seq_dim, actual_seq_len, pad_multiple=None, pad_val=0, cp_group=None):
        """Pad ``val`` per sub-sequence and keep the positions owned by this rank.

        Args:
            val: Packed tensor with the sub-sequences back to back along ``seq_dim``.
            seq_dim: Index of the token dimension.
            actual_seq_len: Cumulative sub-sequence end positions.
            pad_multiple: Padding multiple; defaults to ``2 * cp_size``.
            pad_val: Fill value for padded positions.
            cp_group: Context-parallel process group.

        Returns:
            The rows of the padded tensor selected for this rank.
        """
        cp_rank = torch.distributed.get_rank(group=cp_group)
        cp_size = torch.distributed.get_world_size(group=cp_group)
        if pad_multiple is None:
            pad_multiple = 2 * cp_size
        raw_lens, padded_lens = build_padded_lens_from_cu_seqlens(
            actual_seq_len=actual_seq_len,
            cp_size=cp_size,
            pad_multiple=pad_multiple,
        )
        padded_val = pad_input(val, raw_lens, padded_lens, seq_dim, pad_val=pad_val)
        # as_tensor converts dtype/device without the copy-construct warning
        # that torch.tensor() emits when handed an existing tensor.
        padded_cu_seqlens = torch.cumsum(
            torch.as_tensor(
                padded_lens,
                device=actual_seq_len.device,
                dtype=actual_seq_len.dtype,
            ),
            dim=0
        )
        index = get_index(padded_cu_seqlens.cpu(), cp_rank, cp_size)
        ctx.seq_dim = seq_dim
        ctx.cp_group = cp_group
        ctx.input_shape = val.shape
        ctx.raw_lens = raw_lens
        ctx.padded_lens = padded_lens
        ctx.padded_shape = padded_val.shape
        ctx.save_for_backward(index)
        out = torch.index_select(padded_val, seq_dim, index)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Place the local gradient at its padded positions, then strip the padding."""
        (index,) = ctx.saved_tensors
        seq_dim = ctx.seq_dim
        input_shape = ctx.input_shape
        raw_lens = ctx.raw_lens
        padded_lens = ctx.padded_lens
        padded_shape = ctx.padded_shape
        # Positions not owned by this rank keep a zero gradient.
        grad_padded = grad_output.new_zeros(padded_shape)
        grad_padded.index_add_(seq_dim, index, grad_output)
        grad_val = unpad_input(grad_padded, raw_lens, padded_lens, seq_dim)
        grad_val = grad_val.view(input_shape)
        return grad_val, None, None, None, None, None
def split_forward_gather_backward_with_megatron_cp_tnd(
    input_: torch.Tensor,
    process_group: torch.distributed.ProcessGroup,
    dim: int = 0,
    actual_seq_len: torch.Tensor = None,
    pad_multiple: int = None,
    pad_val: float = 0
) -> torch.Tensor:
    """
    Take this rank's share of a full packed token stream, following the
    chunking used for ring TND context parallelism.

    Thin wrapper around ``_SplitForwardGatherBackWardWithMegatronCPTND``,
    which proceeds as follows:
    - Each sub-sequence has length raw_len, derived from ``actual_seq_len``.
    - It is padded with ``pad_val`` to
      padded_len = ceil(raw_len / pad_multiple) * pad_multiple
      (``pad_multiple`` defaults to 2 * cp_size).
    - Each padded sub-sequence is cut into 2 * cp_size equal chunks; the
      current rank keeps chunk ``rank`` and chunk ``2 * cp_size - rank - 1``,
      which balances the load across the ring.
    Because every sub-sequence is padded before it is split, every rank
    receives the same number of positions.

    Args:
        input_: Full packed tensor.
        process_group: Context-parallel process group.
        dim: Token dimension of ``input_``.
        actual_seq_len: Cumulative sub-sequence end positions.
        pad_multiple: Padding multiple, see above.
        pad_val: Fill value for padded positions.

    Returns:
        The tokens of ``input_`` (plus padding) assigned to this rank.
    """
    split_fn = _SplitForwardGatherBackWardWithMegatronCPTND.apply
    return split_fn(input_, dim, actual_seq_len, pad_multiple, pad_val, process_group)
class _GatherForwardSplitBackwardWithMegatronCPTND(torch.autograd.Function):
    """Rebuild the full packed token stream from all ranks; hand back the local gradient slice in backward."""

    @staticmethod
    def forward(ctx, val, seq_dim, actual_seq_len, pad_multiple=None, cp_group=None):
        """All-gather the per-rank pieces and return the unpadded packed tensor.

        Args:
            val: This rank's piece; all ranks must contribute equally shaped tensors.
            seq_dim: Index of the token dimension.
            actual_seq_len: Cumulative sub-sequence end positions of the full stream.
            pad_multiple: Padding multiple; defaults to ``2 * cp_size``.
            cp_group: Context-parallel process group.

        Returns:
            The full packed tensor with padding removed.
        """
        cp_rank = torch.distributed.get_rank(group=cp_group)
        cp_size = torch.distributed.get_world_size(group=cp_group)
        if pad_multiple is None:
            pad_multiple = 2 * cp_size
        raw_lens, padded_lens = build_padded_lens_from_cu_seqlens(
            actual_seq_len=actual_seq_len,
            cp_size=cp_size,
            pad_multiple=pad_multiple,
        )
        # as_tensor converts dtype/device without the copy-construct warning
        # that torch.tensor() emits when handed an existing tensor.
        padded_cu_seqlens = torch.cumsum(
            torch.as_tensor(
                padded_lens,
                device=actual_seq_len.device,
                dtype=actual_seq_len.dtype,
            ),
            dim=0,
        )
        # Move the offsets to the host once and build every rank's index up
        # front; the local rank's index is reused for backward.
        padded_cu_seqlens_cpu = padded_cu_seqlens.cpu()
        rank_indices = [get_index(padded_cu_seqlens_cpu, rank, cp_size) for rank in range(cp_size)]
        local_index = rank_indices[cp_rank]

        gathered_vals = [torch.empty_like(val) for _ in range(cp_size)]
        torch.distributed.all_gather(gathered_vals, val, group=cp_group)

        padded_shape = list(val.shape)
        padded_shape[seq_dim] = int(sum(padded_lens))
        padded_val = val.new_zeros(padded_shape)
        # Drop every rank's piece at the padded positions that rank owns.
        for rank_index, rank_val in zip(rank_indices, gathered_vals):
            padded_val.index_copy_(seq_dim, rank_index, rank_val)
        out = unpad_input(padded_val, raw_lens, padded_lens, seq_dim)

        ctx.seq_dim = seq_dim
        ctx.cp_group = cp_group
        ctx.raw_lens = raw_lens
        ctx.padded_lens = padded_lens
        ctx.local_index = local_index
        ctx.padded_shape = tuple(padded_shape)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Re-pad the full gradient and select the positions owned by this rank."""
        seq_dim = ctx.seq_dim
        raw_lens = ctx.raw_lens
        padded_lens = ctx.padded_lens
        local_index = ctx.local_index
        grad_padded = pad_input(grad_output, raw_lens, padded_lens, seq_dim)
        grad_val = torch.index_select(grad_padded, seq_dim, local_index)
        return grad_val, None, None, None, None
def compute_token_level_loss(loss_dict):
    """Token level loss function.

    Builds a two-element tensor ``[loss_sum, token_count]``.

    With context parallelism enabled (``args.context_parallel_size > 1``) the
    sum of ``loss_dict['loss']`` and the sum of ``loss_dict['token_nums']``
    are used. Otherwise the token count is the sum of
    ``loss_dict['loss_mask']``; the loss is masked and summed when its
    flattened shape matches the mask, and taken as a single value otherwise.

    Args:
        loss_dict: Dict with ``'loss'`` plus ``'token_nums'`` (context
            parallel) or ``'loss_mask'`` (otherwise).

    Returns:
        ``(loss, local_num_tokens, reporting_loss)``: ``loss`` is the
        two-element tensor with its first entry divided by the
        context-parallel world size, ``local_num_tokens`` is the detached
        integer token count, and ``reporting_loss`` is a detached copy taken
        before that division and all-reduced over the data-parallel group.
    """
    args = get_args()
    use_context_parallel = args.context_parallel_size > 1
    loss = loss_dict['loss']

    if use_context_parallel:
        total_tokens = loss_dict["token_nums"]
        loss_sum = loss.sum().view(1)
        token_count = total_tokens.sum().view(1)
        loss = torch.cat([loss_sum, token_count])
    else:
        loss_mask = loss_dict['loss_mask'].view(-1).float()
        total_tokens = loss_mask.sum()
        flat_loss = loss.view(-1)
        if flat_loss.shape == loss_mask.shape:
            masked_sum = torch.sum(flat_loss * loss_mask)
            loss = torch.cat([masked_sum.view(1), total_tokens.view(1)])
        else:
            loss = torch.cat([loss.view(1), total_tokens.view(1)])

    # Snapshot for logging before the loss entry is rescaled below.
    reporting_loss = loss.clone().detach()
    loss[0] = loss[0] / mpu.get_context_parallel_world_size()
    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
    local_num_tokens = loss[1].clone().detach().to(torch.int)
    return loss, local_num_tokens, reporting_loss