import logging
import torch
import torch_npu
from mindspeed.te.pytorch.fp8.checkpoint import is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase
from mindspeed.te.pytorch.fp8.constants import FormatEnum, MatmulKey, get_matmul_wise_by_tensor_key
from mindspeed.te.pytorch.fp8.tensor import is_fp8_tensor, is_fp8_tensor_2d, Float8TensorCpu, MXFP8TensorCpu
from mindspeed.te.pytorch.module_typing import FP8Metadata, FP8Tensor
logger = logging.getLogger(__name__)
__all__ = ["is_fp8_tensor_2d"]
def fp8_matmul(
inputs: FP8Tensor,
weight: FP8Tensor,
fp8_meta: FP8Metadata,
key: MatmulKey,
) -> tuple[torch.Tensor, FP8Tensor, FP8Tensor]:
"""
Returns:
output: high accuracy output
input: fp8 inputs
weight: fp8 weight
"""
if key == MatmulKey.forward:
rowwise = not is_fp8_activation_recompute_enabled() or in_fp8_activation_recompute_phase()
else:
rowwise = True
if not is_fp8_tensor(inputs):
inputs = fp8_meta.quantization(key[0], inputs, rowwise=rowwise)
if not is_fp8_tensor(weight):
weight = fp8_meta.quantization(key[1], weight, rowwise=rowwise)
output = inputs.quant_matmul(weight, get_matmul_wise_by_tensor_key(inputs, key), key)
return output, inputs, weight
def fp8_matmul_add(
main_grad: torch.Tensor, inputs: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8Metadata, key: MatmulKey = MatmulKey.dw
):
if not is_fp8_tensor(inputs):
inputs = fp8_meta.quantization(key[0], inputs, rowwise=True)
if not is_fp8_tensor(weight):
weight = fp8_meta.quantization(key[1], weight, rowwise=True)
inputs.quant_matmul_add(main_grad, weight, get_matmul_wise_by_tensor_key(inputs, key))
def te_online_comparison_cpu(inputs, weight, output):
if inputs.fp8_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(
f"TE online comparison only supports e4m3 and e5m2 formats, but fp8_dtype is {inputs.fp8_dtype}"
)
weight_cpu = Float8TensorCpu(torch.tensor([]), torch.float8_e4m3fn, None, torch.float32)
input_cpu = Float8TensorCpu(torch.tensor([]), torch.float8_e4m3fn, None, torch.float32)
weight_cpu.from_float8tensor(weight)
input_cpu.from_float8tensor(inputs)
output_cpu = input_cpu.quant_matmul(weight_cpu).npu()
abs_error = torch.abs(output_cpu - output)
rel_error = abs_error / torch.abs(output_cpu)
max_abs_error = torch.max(abs_error)
max_rel_error = torch.max(rel_error)
logger.info("The error of quant_matmul: ")
logger.info("[%s] Max Absolute Error: %s", output.device, max_abs_error.item())
logger.info("[%s] Max Relative Error: %s", output.device, max_rel_error.item())
if max_rel_error > 0.001:
raise ValueError(f"The error of quant_matmul exceeds tolerance: {max_rel_error.item()}")
def te_online_comparison_bf16(inputs, weight, output):
inputs_bf16 = inputs.data.to(torch.float32).to(torch.bfloat16)
weight_bf16 = weight.data.to(torch.float32).to(torch.bfloat16)
bf16_output = torch.matmul(inputs_bf16, weight_bf16)
abs_error = torch.abs(bf16_output - output)
rel_error = abs_error / torch.abs(bf16_output)
max_abs_error = torch.max(abs_error)
max_rel_error = torch.max(rel_error)
logger.info("The error of quant_matmul: ")
logger.info("[%s] Max Absolute Error: %s", output.device, max_abs_error.item())
logger.info("[%s] Max Relative Error: %s", output.device, max_rel_error.item())
if max_rel_error > 0.001:
raise ValueError(f"The error of quant_matmul exceeds tolerance: {max_rel_error.item()}")
def te_online_comparison_mxfp8_cpu(inputs, weight, transpose, output):
if inputs.fp8_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(
f"TE online comparison only supports e4m3 and e5m2 formats, but fp8_dtype is {inputs.fp8_dtype}"
)
weight_cpu = MXFP8TensorCpu(torch.float8_e4m3fn, torch.Size(), torch.float32)
input_cpu = MXFP8TensorCpu(torch.float8_e4m3fn, torch.Size(), torch.float32)
weight_cpu.from_MXFP8Tensor(weight)
input_cpu.from_MXFP8Tensor(inputs)
output_cpu = input_cpu.quant_matmul(weight_cpu, transpose).npu()
abs_error = torch.abs(output_cpu - output)
rel_error = abs_error / torch.abs(output_cpu)
max_abs_error = torch.max(abs_error)
max_rel_error = torch.max(rel_error)
logger.info("The error of quant_matmul: ")
logger.info("[%s] Max Absolute Error: %s", output.device, max_abs_error.item())
logger.info("[%s] Max Relative Error: %s", output.device, max_rel_error.item())
if max_rel_error > 0.001:
raise ValueError(f"The error of quant_matmul exceeds tolerance: {max_rel_error.item()}")
def padding_bf16_scale(mxfp8_tensor, scale_tensor):
mxfp8_shape = mxfp8_tensor.shape
scale_shape = scale_tensor.shape
new_scale_shape = []
padding_dim = -1
for i, x in enumerate(mxfp8_shape):
if x != scale_shape[i]:
new_scale_shape.append(scale_shape[i] * scale_shape[-1])
padding_dim = i
else:
new_scale_shape.append(scale_shape[i])
scale_tensor = scale_tensor.view(new_scale_shape)
scale_bf16 = scale_tensor.to(torch.float32).to(torch.bfloat16)
scale_bf16 = torch.repeat_interleave(scale_bf16, 32, dim=padding_dim)
if scale_bf16.shape[padding_dim] > mxfp8_tensor.shape[padding_dim]:
scale_bf16 = scale_bf16.narrow(dim=padding_dim, start=0, length=mxfp8_tensor.shape[padding_dim])
return scale_bf16
def te_online_comparison_mxfp8_bf16(inputs, weight, transpose, output):
x1, x_scale = inputs.get_quant_data(transpose[0])
x2, weight_scale = weight.get_quant_data(transpose[1])
x_mxfp8_bf16 = x1.to(torch.float32).to(torch.bfloat16)
weight_mxfp8_bf16 = x2.to(torch.float32).to(torch.bfloat16)
x_scale_bf16 = padding_bf16_scale(x_mxfp8_bf16, x_scale)
weight_scale_bf16 = padding_bf16_scale(weight_mxfp8_bf16, weight_scale)
x_bf16 = torch.div(x_mxfp8_bf16, x_scale_bf16)
weight_bf16 = torch.div(weight_mxfp8_bf16, weight_scale_bf16)
bf16_output = torch.matmul(x_bf16, weight_bf16)
abs_error = torch.abs(bf16_output - output)
rel_error = abs_error / torch.abs(bf16_output)
max_abs_error = torch.max(abs_error)
max_rel_error = torch.max(rel_error)
logger.info("The error of quant_matmul: ")
logger.info("[%s] Max Absolute Error: %s", output.device, max_abs_error.item())
logger.info("[%s] Max Relative Error: %s", output.device, max_rel_error.item())
if max_rel_error > 0.001:
raise ValueError(f"The error of quant_matmul exceeds tolerance: {max_rel_error.item()}")
class Cast2FP8(torch.autograd.Function):
@staticmethod
def forward(ctx, quant_tensor, fp8_format):
if fp8_format == FormatEnum.E4M3.value:
return quant_tensor.to(torch.float8_e4m3fn)
elif fp8_format == FormatEnum.E5M2.value:
return quant_tensor.to(torch.float8_e5m2)
elif fp8_format == FormatEnum.HIF8.value:
return torch_npu.HiFloat8Tensor.to_hifloat8(quant_tensor)
else:
raise ValueError("Only e4m3, e5m2 and hif8 of the fp8 format are supported.")
@staticmethod
def backward(ctx, grad_output):
return grad_output, None, None
def cast_to_fp8(quant_tensor, fp8_format):
return Cast2FP8.apply(quant_tensor, fp8_format)
class Cast2FP8Cpu(torch.autograd.Function):
@staticmethod
def forward(ctx, quant_tensor, fp8_format):
if fp8_format == FormatEnum.E4M3.value:
quantized_tensor = quant_tensor.to(torch.float8_e4m3fn)
elif fp8_format == FormatEnum.E5M2.value:
quantized_tensor = quant_tensor.to(torch.float8_e5m2)
else:
raise ValueError("Only e4m3, e5m2 of the fp8 format are supported on cpu.")
return quantized_tensor
@staticmethod
def backward(ctx, grad_output):
return grad_output, None, None
def cast_to_fp8_cpu(quant_tensor, fp8_format):
return Cast2FP8Cpu.apply(quant_tensor, fp8_format)