import torch
from mindspeed.te.pytorch.fp8 import fp8_matmul, get_matmul_wise_by_tensor_key
from mindspeed.te.pytorch.module.ops.comm_overlap_ops import CommOverlapOps
class DefaultOps(CommOverlapOps):
    """Tensor-parallel matmul helpers that pair a matmul with a collective.

    Each method issues the collective and the matmul as separate,
    back-to-back calls. All methods return a 3-tuple whose first element
    is the result tensor.
    """

    @staticmethod
    def allgather_matmul(
        input_, weight, bias, fp8_meta, key=None, fp8_enable=False
    ):
        """All-gather ``input_`` along dim 0, then multiply by ``weight``.

        Args:
            input_: Local tensor; its dim 0 is multiplied by
                ``fp8_meta.tp_world_size`` to size the gather buffer.
            weight: Second matmul operand.
            bias: Not used by this method.
            fp8_meta: Provides ``tp_world_size`` and ``tp_group``; also
                forwarded to ``fp8_matmul`` when ``fp8_enable`` is true.
            key: Forwarded to ``get_matmul_wise_by_tensor_key`` or
                ``fp8_matmul``.
            fp8_enable: Selects ``fp8_matmul`` instead of ``torch.matmul``.

        Returns:
            ``(output, total_input, None)`` on the non-FP8 path, otherwise
            the three values returned by ``fp8_matmul``.
        """
        gathered_shape = list(input_.size())
        gathered_shape[0] = gathered_shape[0] * fp8_meta.tp_world_size
        total_input = torch.empty(
            gathered_shape, dtype=input_.dtype, device=input_.device
        )
        # NOTE(review): `_all_gather_base` is a private torch API; confirm the
        # minimum supported torch version before moving to a public variant.
        torch.distributed._all_gather_base(
            total_input,
            input_.contiguous(),
            group=fp8_meta.tp_group,
            async_op=False,
        )

        if fp8_enable:
            output, input_fp8, weight_fp8 = fp8_matmul(
                total_input, weight, fp8_meta, key
            )
            return output, input_fp8, weight_fp8

        # transpose[0] / transpose[1] say whether to transpose the first /
        # second operand before the matmul.
        transpose = get_matmul_wise_by_tensor_key(total_input, key)
        output = torch.matmul(
            total_input.t() if transpose[0] else total_input,
            weight.t() if transpose[1] else weight,
        )
        return output, total_input, None

    @staticmethod
    def matmul_reduce_scatter(
        input_, weight, bias, fp8_meta, key, fp8_enable=False
    ):
        """Multiply ``input_`` by ``weight``, then reduce-scatter along dim 0.

        Args:
            input_: First matmul operand.
            weight: Second matmul operand.
            bias: Not used by this method.
            fp8_meta: Provides ``tp_world_size`` and ``tp_group``; also
                forwarded to ``fp8_matmul`` when ``fp8_enable`` is true.
            key: Forwarded to ``get_matmul_wise_by_tensor_key`` or
                ``fp8_matmul``.
            fp8_enable: Selects ``fp8_matmul`` instead of ``torch.matmul``.

        Returns:
            ``(output, input_, weight)``. On the FP8 path ``input_`` and
            ``weight`` are the second and third values returned by
            ``fp8_matmul``; otherwise they are the arguments as passed in.
        """
        if fp8_enable:
            full_output, input_, weight = fp8_matmul(
                input_, weight, fp8_meta, key
            )
        else:
            transpose = get_matmul_wise_by_tensor_key(input_, key)
            full_output = torch.matmul(
                input_.t() if transpose[0] else input_,
                weight.t() if transpose[1] else weight,
            )

        scattered_shape = list(full_output.size())
        scattered_shape[0] = scattered_shape[0] // fp8_meta.tp_world_size
        # Allocate on the same device as the tensor being scattered instead of
        # the process-wide current CUDA device, so the buffers always match.
        output = torch.empty(
            scattered_shape, dtype=full_output.dtype, device=full_output.device
        )
        # NOTE(review): `_reduce_scatter_base` is a private torch API; confirm
        # the minimum supported torch version before moving to a public variant.
        torch.distributed._reduce_scatter_base(
            output, full_output.contiguous(), group=fp8_meta.tp_group
        )
        return output, input_, weight

    @staticmethod
    def matmul_all_reduce(
        input_, weight, bias, fp8_meta, key=None, fp8_enable=False
    ):
        """Multiply ``input_`` by ``weight``, all-reduce, then add ``bias``.

        Args:
            input_: First matmul operand.
            weight: Second matmul operand; the non-FP8 path always uses
                ``weight.t()`` and does not consult ``key``.
            bias: Added to the result after the all-reduce when not ``None``.
            fp8_meta: Provides ``tp_world_size`` and ``tp_group``; also
                forwarded to ``fp8_matmul`` when ``fp8_enable`` is true.
            key: Forwarded to ``fp8_matmul`` on the FP8 path only.
            fp8_enable: Selects ``fp8_matmul`` instead of ``torch.matmul``.

        Returns:
            ``(output, input_, weight)``. On the FP8 path ``input_`` and
            ``weight`` are the second and third values returned by
            ``fp8_matmul``; otherwise they are the arguments as passed in.
        """
        if fp8_enable:
            output_, input_, weight = fp8_matmul(input_, weight, fp8_meta, key)
        else:
            output_ = torch.matmul(input_, weight.t())

        # The all-reduce is skipped when the group has a single rank.
        if fp8_meta.tp_world_size > 1:
            torch.distributed.all_reduce(output_, group=fp8_meta.tp_group)

        # Bias is applied after the reduction, once per rank.
        if bias is not None:
            output_ = output_ + bias
        return output_, input_, weight