import warnings
from functools import lru_cache
import torch
import torch_npu
from megatron.core import parallel_state
from mindspeed.args_utils import get_full_args as get_args
def view_as_n_dim(input_tensor, dim=2):
    """Collapse the leading axes of ``input_tensor`` into a single axis.

    The last ``dim - 1`` sizes of the tensor's shape are kept and everything
    in front of them is folded into one leading axis via ``Tensor.view``.

    Args:
        input_tensor: Tensor to reshape; must be viewable without a copy.
        dim: Target number of dimensions. Must be at least 2.

    Returns:
        ``input_tensor`` itself when it already has ``dim`` dimensions,
        otherwise a view of shape ``(-1, *shape[-(dim - 1):])``.

    Raises:
        AssertionError: If ``dim`` is smaller than 2.
    """
    if dim < 2:
        raise AssertionError("dim should be greater than or equal to 2")

    shape = input_tensor.shape
    if len(shape) == dim:
        # Already at the requested rank: hand back the very same object.
        return input_tensor

    trailing_sizes = shape[-dim + 1:]
    return input_tensor.view(-1, *trailing_sizes)
class QuantDtype:
    """Bundle of the dtypes used for inputs, weights and gradients.

    Attributes set by ``__init__``:
        x: dtype passed for the input operand.
        w: dtype passed for the weight operand.
        grads: dtype passed for gradients.
        mm_kwargs: ``{'x1_dtype': x, 'x2_dtype': w}`` when ``x`` is
            ``torch_npu.hifloat8``, otherwise an empty dict.
        gmm_kwargs: ``{"x_dtype": x, "weight_dtype": w}`` when ``x`` is
            ``torch_npu.hifloat8``, otherwise an empty dict.
    """

    def __init__(self, x: torch.dtype, w: torch.dtype, grads: torch.dtype):
        self.x, self.w, self.grads = x, w, grads

        # Only the hifloat8 input dtype gets explicit dtype keyword arguments.
        # NOTE(review): presumably these dicts are splatted into matmul /
        # grouped-matmul kernel calls elsewhere - confirm against callers.
        is_hif8 = self.x == torch_npu.hifloat8
        self.mm_kwargs = {'x1_dtype': self.x, 'x2_dtype': self.w} if is_hif8 else {}
        self.gmm_kwargs = {"x_dtype": self.x, "weight_dtype": self.w} if is_hif8 else {}
@lru_cache
def get_quant_dtype():
    """Build (and cache) the ``QuantDtype`` selected by ``args.fp8``.

    Returns:
        QuantDtype: ``hifloat8`` for all three slots when ``args.fp8`` is
        ``'hif8'``; ``float8_e4m3fn`` for input/weight and ``float8_e5m2``
        for gradients when it is ``'hybrid'``; ``float8_e4m3fn`` for all
        three slots for any other value.
    """
    fp8_mode = get_args().fp8

    # Each branch only touches the dtype attributes it actually needs.
    if fp8_mode == 'hif8':
        x_dtype = w_dtype = grad_dtype = torch_npu.hifloat8
    elif fp8_mode == 'hybrid':
        x_dtype = w_dtype = torch.float8_e4m3fn
        grad_dtype = torch.float8_e5m2
    else:
        x_dtype = w_dtype = grad_dtype = torch.float8_e4m3fn

    return QuantDtype(x_dtype, w_dtype, grad_dtype)
def get_hccl_comm_name(group, rank):
    """Look up the HCCL communicator name of ``group`` for ``rank``.

    Args:
        group: Process group to query.
        rank: Rank within ``group``; on the newer-torch path it is first
            translated to a global rank.

    Returns:
        Whatever ``get_hccl_comm_name`` of the group (or of its NPU backend)
        returns.
    """
    # NOTE(review): the branch is chosen by comparing torch.__version__ with
    # the string "2.0" - confirm the intended result for exactly "2.0.0".
    if torch.__version__ > "2.0":
        global_rank = torch.distributed.get_global_rank(group, rank)
        npu_backend = group._get_backend(torch.device("npu"))
        return npu_backend.get_hccl_comm_name(global_rank)

    # Older torch: the group object exposes the lookup directly.
    return group.get_hccl_comm_name(rank)
def all_gather_along_dim(input_, async_op=False, axis=0):
    """All-gather ``input_`` across the tensor-model-parallel group.

    Args:
        input_: Local tensor to contribute; made contiguous before the call.
        async_op: Forwarded to the collective.
        axis: Axis of the output buffer whose size is multiplied by the
            tensor-model-parallel world size.

    Returns:
        tuple: ``(handle, output_)`` where ``handle`` is the value returned by
        ``torch.distributed._all_gather_base`` and ``output_`` is the freshly
        allocated receive buffer.
    """
    from megatron.core.parallel_state import (
        get_tensor_model_parallel_group,
        get_tensor_model_parallel_world_size,
    )

    tp_group = get_tensor_model_parallel_group()
    tp_world_size = get_tensor_model_parallel_world_size()

    # Receive buffer: same shape as the input except along `axis`.
    # NOTE(review): only the buffer *shape* is axis-aware here; how the
    # collective lays data out for axis != 0 is not visible in this function.
    gathered_shape = list(input_.size())
    gathered_shape[axis] *= tp_world_size
    output_ = torch.empty(
        gathered_shape,
        dtype=input_.dtype,
        device=torch.npu.current_device(),
        requires_grad=False,
    )

    handle = torch.distributed._all_gather_base(
        output_, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return handle, output_
def gather_split_1d_tensor(tensor, tp_group=None):
    """Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
    model parallel ranks.

    Returns a new Tensor with the gathered data.

    Args:
        tensor: A Tensor or view of this rank's portion of the data.
        tp_group: Process group to gather over. When None, the default group
            is resolved through get_tensor_model_parallel_group_if_none.
    """
    tp_group = get_tensor_model_parallel_group_if_none(tp_group)
    numel_gathered = torch.numel(tensor) * tp_group.size()
    # Allocate on the current NPU, matching the other helpers in this module;
    # the previous torch.cuda.current_device() call presumably only worked
    # when torch.cuda had been redirected to the NPU backend.
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.npu.current_device(),
        requires_grad=False,
    )
    torch.distributed._all_gather_base(gathered, tensor, group=tp_group)
    return gathered
def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True):
    """Return ``tp_group``, falling back to the default tensor-parallel group.

    Args:
        tp_group: Explicit process group, or None to use the default one.
        is_expert: When falling back, pick the expert tensor-parallel group
            instead of the regular tensor-model-parallel group.
        check_initialized: Forwarded to the ``parallel_state`` getter.

    Returns:
        None when ``torch.distributed`` is not initialized (regardless of
        ``tp_group``); otherwise ``tp_group`` if given, else the default
        group. A DeprecationWarning is emitted on rank 0 when the fallback
        is taken.
    """
    if not torch.distributed.is_initialized():
        return None

    if tp_group is not None:
        return tp_group

    # Distributed is known to be initialized here, so the rank query is safe.
    if torch.distributed.get_rank() == 0:
        warnings.warn(
            "Warning: tp_group is None, using default tp group. "
            "Passing tp_group will be mandatory soon",
            DeprecationWarning,
            stacklevel=2,
        )

    if is_expert:
        default_getter = parallel_state.get_expert_tensor_parallel_group
    else:
        default_getter = parallel_state.get_tensor_model_parallel_group
    return default_getter(check_initialized=check_initialized)
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False, tp_group=None):
    """Return this rank's equal-sized 1D slice of ``tensor``.

    The tensor is flattened with ``view(-1)`` and cut into
    ``tp_group.size()`` chunks of ``numel // size`` elements; the chunk at
    index ``tp_group.rank()`` is returned.

    Args:
        tensor: The tensor to split.

    Keyword Args:
        new_buffer (bool): If True, copy the slice into a newly allocated
            tensor. If False, return a view into the existing tensor.
            Default is False.
        tp_group: Process group that defines the number of chunks and this
            rank's index. When None, the default group is resolved through
            get_tensor_model_parallel_group_if_none.
    """
    tp_group = get_tensor_model_parallel_group_if_none(tp_group)

    chunk_numel = torch.numel(tensor) // tp_group.size()
    first = chunk_numel * tp_group.rank()
    last = first + chunk_numel

    if not new_buffer:
        return tensor.view(-1)[first:last]

    # Caller asked for storage that is independent of `tensor`.
    chunk = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.npu.current_device(),
        requires_grad=False,
    )
    chunk.copy_(tensor.view(-1)[first:last])
    return chunk