import torch_npu
from mindspeed.args_utils import get_full_args as get_args
from mindspeed.te.pytorch.fp8 import MatmulKey
from mindspeed.te.pytorch.fp8.constants import get_matmul_wise_by_tensor_key
from mindspeed.te.pytorch.fp8.tensor.float8_tensor import Float8Tensor2D
from mindspeed.te.pytorch.module_typing import FP8Metadata
from mindspeed.te.pytorch.utils import all_gather_along_dim
class MXFP8Tensor(Float8Tensor2D):
    """Float8Tensor2D variant whose matmuls go through torch_npu quantized kernels.

    Every kernel call made by this class passes ``torch_npu.float8_e8m0fnu``
    as the scale dtype for both operands and ``group_sizes=[1, 1, 32]``.

    NOTE(review): this looks like MX-style block scaling with 32-element
    groups -- confirm against the torch_npu operator documentation.
    """

    def quant_matmul(self, other: 'MXFP8Tensor', is_rowwise, key: MatmulKey):
        """Multiply ``self`` by ``other`` with ``torch_npu.npu_quant_matmul``.

        Args:
            other: Right-hand operand.
            is_rowwise: Indexable pair; element 0 is passed to
                ``self.get_quant_data`` and element 1 to
                ``other.get_quant_data``.
            key: Forwarded to ``release`` on both operands.

        Returns:
            The kernel result after ``self.restore_reshape(other, ...)``.
        """
        lhs, lhs_scale = self.get_quant_data(is_rowwise[0])
        rhs, rhs_scale = other.get_quant_data(is_rowwise[1])

        e8m0 = torch_npu.float8_e8m0fnu
        product = torch_npu.npu_quant_matmul(
            lhs,
            rhs,
            rhs_scale,
            pertoken_scale=lhs_scale,
            output_dtype=self.dtype,
            scale_dtype=e8m0,
            pertoken_scale_dtype=e8m0,
            group_sizes=[1, 1, 32],
        )
        product = self.restore_reshape(other, product)

        # Optional online comparison hooks, switched on through the global
        # args; the helpers are imported only when the matching flag is set.
        cfg = get_args()
        if cfg.te_comparison_with_cpu:
            from mindspeed.te.pytorch.fp8 import te_online_comparison_mxfp8_cpu
            te_online_comparison_mxfp8_cpu(self, other, is_rowwise, product)
        if cfg.te_comparison_with_bf16:
            from mindspeed.te.pytorch.fp8 import te_online_comparison_mxfp8_bf16
            te_online_comparison_mxfp8_bf16(self, other, is_rowwise, product)

        # Release is called only after the comparison hooks have run, since
        # they receive both operands.
        self.release(lhs, lhs_scale, key)
        other.release(rhs, rhs_scale, key)
        return product

    def quant_matmul_add(self, main_grad, other: 'MXFP8Tensor', is_rowwise):
        """Run ``torch_npu.npu_add_quant_matmul_`` on ``main_grad``, ``self`` and ``other``.

        NOTE(review): the trailing underscore suggests ``main_grad`` is
        updated in place -- confirm against the torch_npu documentation.

        Args:
            main_grad: First positional argument of the kernel.
            other: Right-hand operand.
            is_rowwise: Indexable pair; element 0 is passed to
                ``self.get_quant_data`` and element 1 to
                ``other.get_quant_data``.

        Returns:
            None.
        """
        lhs, lhs_scale = self.get_quant_data(is_rowwise[0])
        rhs, rhs_scale = other.get_quant_data(is_rowwise[1])

        e8m0 = torch_npu.float8_e8m0fnu
        torch_npu.npu_add_quant_matmul_(
            main_grad,
            lhs,
            rhs,
            rhs_scale,
            x1_scale=lhs_scale,
            x1_scale_dtype=e8m0,
            x2_scale_dtype=e8m0,
            group_sizes=[1, 1, 32],
        )

        # Unlike quant_matmul, no key is passed to release here.
        self.release(lhs, lhs_scale)
        other.release(rhs, rhs_scale)

    def all_gather_matmul(self, other: 'MXFP8Tensor', bias, fp8_meta: FP8Metadata, key: MatmulKey):
        """Run ``torch_npu.npu_all_gather_quant_mm`` and gather this tensor's row data.

        Args:
            other: Right-hand operand; its quant data is the kernel's ``x2``.
            bias: Forwarded to the kernel as ``bias``.
            fp8_meta: Supplies ``hcom_name`` and ``tp_world_size``.
            key: Selects the row-wise flag for ``other`` and whether the row
                tensor is transposed (when it equals ``MatmulKey.dx``).

        Returns:
            A pair ``(output, gather_out)``: the kernel output viewed as
            ``(-1, self.origin_shape[1], output.shape[1])`` and a new
            ``MXFP8Tensor`` whose row data is the all-gathered row data and
            scale of ``self``.
        """
        # Only the second flag (the one used for ``other``) is needed here.
        _, rhs_rowwise = get_matmul_wise_by_tensor_key(self, key)
        rhs, rhs_scale = other.get_quant_data(rhs_rowwise)

        is_dx = key == MatmulKey.dx
        if is_dx:
            row_data, row_scale = self.row_tensor.t()
        else:
            row_data, row_scale = self.row_tensor

        # all_gather_along_dim returns a pair; only its second element is kept.
        _, row_data = all_gather_along_dim(row_data)
        _, row_scale = all_gather_along_dim(row_scale)

        e8m0 = torch_npu.float8_e8m0fnu
        # The kernel returns three values; only the first one is used.
        output, _, _ = torch_npu.npu_all_gather_quant_mm(
            self.col_tensor.data,
            rhs,
            fp8_meta.hcom_name,
            fp8_meta.tp_world_size,
            bias=bias,
            x1_scale=self.col_tensor.scale,
            x2_scale=rhs_scale,
            quant_scale=None,
            block_size=0,
            comm_turn=0,
            group_sizes=[1, 1, 32],
            amax_output=False,
            y_dtype=self.dtype,
            gather_output=False,
            x1_dtype=None,
            x2_dtype=None,
            x1_scale_dtype=e8m0,
            x2_scale_dtype=e8m0,
        )

        gather_out = MXFP8Tensor(self.fp8_dtype, self.origin_shape, self.device, dtype=self.dtype)
        gather_out.set_row_data(row_data, row_scale, is_dx)

        reshaped = output.view(-1, self.origin_shape[1], output.shape[1])
        return reshaped, gather_out

    def matmul_reduce_scatter(self, other: 'MXFP8Tensor', bias, fp8_meta: FP8Metadata, key: MatmulKey):
        """Run ``torch_npu.npu_quant_mm_reduce_scatter`` with ``reduce_op='sum'``.

        Args:
            other: Right-hand operand.
            bias: Forwarded to the kernel as ``bias``.
            fp8_meta: Supplies ``hcom_name`` and ``tp_world_size``.
            key: Passed to ``get_matmul_wise_by_tensor_key`` to pick the
                row-wise flags for both operands.

        Returns:
            The kernel output viewed as
            ``(-1, self.origin_shape[1], output.shape[1])``.
        """
        lhs_rowwise, rhs_rowwise = get_matmul_wise_by_tensor_key(self, key)
        lhs, lhs_scale = self.get_quant_data(lhs_rowwise)
        rhs, rhs_scale = other.get_quant_data(rhs_rowwise)

        e8m0 = torch_npu.float8_e8m0fnu
        # The kernel returns two values; only the first one is used.
        output, _ = torch_npu.npu_quant_mm_reduce_scatter(
            lhs,
            rhs,
            fp8_meta.hcom_name,
            fp8_meta.tp_world_size,
            bias=bias,
            reduce_op='sum',
            x1_scale=lhs_scale,
            x2_scale=rhs_scale,
            quant_scale=None,
            block_size=0,
            comm_turn=0,
            group_sizes=[1, 1, 32],
            amax_output=False,
            y_dtype=self.dtype,
            x1_dtype=None,
            x2_dtype=None,
            x1_scale_dtype=e8m0,
            x2_scale_dtype=e8m0,
        )

        reshaped = output.view(-1, self.origin_shape[1], output.shape[1])
        return reshaped