import os
import pytest
import torch
import torch_npu
import mindspeed.megatron_adaptor
from megatron.core.enums import Fp8Recipe
from megatron.training.arguments import parse_args, core_transformer_config_from_args
from megatron.training.global_vars import set_args
from megatron.training.initialize import _set_random_seed
from mindspeed.core.transformer.moe import grouped_gemm_util as gg
from mindspeed.core.transformer.moe.grouped_matmul_util import TensorwiseGMMFunction
from mindspeed.te.pytorch.fp8 import cast_to_fp8, cast_to_fp8_cpu
from mindspeed.te.pytorch.fp8.constants import Format, FormatEnum
from mindspeed.te.pytorch.fp8.fp8 import fp8_autocast
from mindspeed.te.pytorch.fp8.recipes import TensorwiseMatMul, MXFP8MatMul, Float8BlockMatMul
from mindspeed.te.pytorch.module.linear import TEColumnParallelLinear, TERowParallelLinear
from tests_extend.commons import initialize_model_parallel
from tests_extend.unit_tests.common import DistributedTest
WORLD_SIZE = 8
class ColumnModel(torch.nn.Module):
def __init__(self, config, input_size, output_size, sp=False):
super().__init__()
word_size = torch.distributed.get_world_size()
assert output_size % word_size == 0, 'output size can not div with word size.'
self.linear = TEColumnParallelLinear(
input_size=input_size,
output_size=output_size,
config=config,
init_method=config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False
)
def forward(self, x):
return self.linear(x)
class RowModel(torch.nn.Module):
def __init__(self, config, input_size, output_size, sp=False):
super().__init__()
word_size = torch.distributed.get_world_size()
assert output_size % word_size == 0, 'output size can not div with word size.'
self.linear = TERowParallelLinear(
input_size=input_size,
output_size=output_size,
config=config,
init_method=config.init_method,
input_is_parallel=True,
bias=False,
skip_bias_add=False,
is_expert=False
)
def forward(self, x):
return self.linear(x)
FORMAT_MAP = {
'e4m3': FormatEnum.E4M3,
'e5m2': FormatEnum.E5M2,
'hif8': FormatEnum.HIF8
}
def get_config_by_recipe(fp8_recipe, fp8_format, config):
from mindspeed.te.pytorch.fp8.recipes import Float8CurrentScaling, TEDelayedScaling, MXFP8BlockScaling, \
Float8BlockScaling
if fp8_format == "e4m3":
fp8_format = Format.E4M3
elif fp8_format == "hybrid":
fp8_format = Format.HYBRID
elif fp8_format == 'hif8':
fp8_format = Format.HIF8
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
if fp8_recipe == Fp8Recipe.delayed:
fp8_recipe = TEDelayedScaling(
config=config,
fp8_format=fp8_format,
)
elif fp8_recipe == Fp8Recipe.tensorwise:
fp8_recipe = Float8CurrentScaling(
config=config,
fp8_format=fp8_format
)
elif fp8_recipe == Fp8Recipe.mxfp8:
fp8_recipe = MXFP8BlockScaling(
fp8_format=fp8_format
)
elif fp8_recipe == Fp8Recipe.blockwise:
fp8_recipe = Float8BlockScaling(
fp8_format=fp8_format
)
return fp8_recipe
@pytest.mark.skip(reason='not support for current version')
@pytest.mark.parametrize("fp8_args", [
("tensorwise", 'e4m3'),
("delayed", 'e4m3'),
("blockwise", 'e4m3'),
('tensorwise', 'hif8'),
('mxfp8', 'e4m3')
])
class TestFP8Model(DistributedTest):
def test_fp8_column_model(self, fp8_args):
iteration_num = 10
os.environ['HCCL_DETERMINISTIC'] = 'True'
input_size = 8192
output_size = 8192
args = parse_args(None, True)
args.params_dtype = torch.bfloat16
args.num_attention_heads = 16
args.hidden_size = 1024
args.num_layers = 2
args.tensor_model_parallel_size = 2
args.sequence_parallel = True
args.gradient_accumulation_fusion = False
set_args(args)
config = core_transformer_config_from_args(args)
recipe_config = get_config_by_recipe(*fp8_args, config)
initialize_model_parallel(WORLD_SIZE, 1)
_set_random_seed(seed_=123, data_parallel_random_init=False)
model = ColumnModel(config, input_size, output_size).npu()
for _ in range(iteration_num):
inputs = torch.randn([2048, input_size], requires_grad=True, dtype=torch.bfloat16,
device='npu:{}'.format(torch.npu.current_device()))
output = model(inputs)[0]
output.sum().backward()
baseline = output.clone()
baseline_wgrad = model.linear.weight.grad.clone()
baseline_dgrad = inputs.grad.clone()
model.zero_grad()
inputs.grad = None
fp8_context = fp8_autocast(enabled=True, fp8_recipe=recipe_config)
with fp8_context:
output = model(inputs)[0]
output.sum().backward()
fp8 = output.clone()
fp8_wgrad = model.linear.weight.grad.clone()
fp8_dgrad = inputs.grad.clone()
model.zero_grad()
torch.cuda.synchronize()
assert torch.allclose(baseline, fp8, atol=0.005, rtol=0.005)
assert torch.allclose(baseline_wgrad, fp8_wgrad, atol=0.005, rtol=0.005)
assert torch.allclose(baseline_dgrad, fp8_dgrad, atol=0.005, rtol=0.005)
def test_fp8_row_model(self, fp8_args):
iteration_num = 10
os.environ['HCCL_DETERMINISTIC'] = 'True'
input_size = 8192
output_size = 8192
args = parse_args(None, True)
args.params_dtype = torch.bfloat16
args.num_attention_heads = 16
args.hidden_size = 2048
args.num_layers = 2
args.tensor_model_parallel_size = 2
args.sequence_parallel = True
args.gradient_accumulation_fusion = False
set_args(args)
config = core_transformer_config_from_args(args)
recipe_config = get_config_by_recipe(*fp8_args, config)
initialize_model_parallel(WORLD_SIZE, 1)
_set_random_seed(seed_=123, data_parallel_random_init=False)
model = RowModel(config, input_size, output_size).npu()
for _ in range(iteration_num):
inputs = torch.randn([input_size // args.tensor_model_parallel_size, 1024, ], requires_grad=True,
device='npu:{}'.format(torch.npu.current_device()), dtype=torch.bfloat16)
output = model(inputs)[0]
output.sum().backward()
baseline = output.clone()
baseline_wgrad = model.linear.weight.grad.clone()
baseline_dgrad = inputs.grad.clone()
model.zero_grad()
inputs.grad = None
fp8_context = fp8_autocast(enabled=True, fp8_recipe=recipe_config)
with fp8_context:
output = model(inputs)[0]
output.sum().backward()
fp8 = output.clone()
fp8_wgrad = model.linear.weight.grad.clone()
fp8_dgrad = inputs.grad.clone()
model.zero_grad()
torch.cuda.synchronize()
assert torch.allclose(baseline, fp8, atol=0.005, rtol=0.005)
assert torch.allclose(baseline_wgrad, fp8_wgrad, atol=0.005, rtol=0.005)
assert torch.allclose(baseline_dgrad, fp8_dgrad, atol=0.005, rtol=0.005)
@pytest.mark.skip(reason='not support for current version')
@pytest.mark.parametrize("fp8_args", [
('tensorwise', 'e4m3'),
('delayed', 'e4m3'),
('blockwise', 'e4m3')
])
class TestFP8Cast(DistributedTest):
def test_fp8_cast(self, fp8_args):
initialize_model_parallel(WORLD_SIZE, 1)
recipe, format_str = fp8_args
fp8_format = FORMAT_MAP[format_str].value
if fp8_format == FormatEnum.E4M3.value:
max_val = 448.0
elif fp8_format == FormatEnum.E5M2.value:
max_val = 57344.0
input_size = (128, 256)
rand_tensor = torch.rand(*input_size, dtype=torch.bfloat16)
scaled_tensor = rand_tensor * 2 * max_val - max_val
inputs_npu = scaled_tensor.clone().npu()
inputs_cpu = scaled_tensor.clone().cpu()
fp8_npu = cast_to_fp8(inputs_npu, fp8_format)
fp8_cpu = cast_to_fp8_cpu(inputs_cpu, fp8_format)
fp8_cpu = fp8_cpu.npu()
abs_error = torch.abs(fp8_cpu.to(torch.float32) - fp8_npu.to(torch.float32))
rel_error = abs_error / torch.abs(fp8_cpu.to(torch.float32))
max_abs_error = torch.max(abs_error)
max_rel_error = torch.max(rel_error)
if max_rel_error > 0.0:
raise ValueError(f"The error of cast exceeds tolerance: {max_rel_error.item()}")
@pytest.mark.skip(reason='not support for current version')
@pytest.mark.parametrize("fp8_args", [
('tensorwise', 'e4m3'),
('delayed', 'e4m3'),
('blockwise', 'e4m3')
])
class TestFP8WithCpu(DistributedTest):
def test_fp8_quantatmul(self, fp8_args):
args = parse_args(None, True)
args.params_dtype = torch.bfloat16
args.num_attention_heads = 16
args.hidden_size = 1024
args.num_layers = 2
args.tensor_model_parallel_size = 2
args.sequence_parallel = True
args.gradient_accumulation_fusion = False
set_args(args)
initialize_model_parallel(WORLD_SIZE, 1)
config = core_transformer_config_from_args(args)
recipe_config = get_config_by_recipe(*fp8_args, config)
input_size1 = (256, 128)
input_size2 = (128, 256)
bf16_tensor_x1 = torch.randn(*input_size1, dtype=torch.bfloat16)
bf16_tensor_x2 = torch.randn(*input_size2, dtype=torch.bfloat16)
inputs_npu_x1 = bf16_tensor_x1.clone().npu()
inputs_npu_x2 = bf16_tensor_x2.clone().npu()
recipe_npu_x1 = recipe_config.recipe('inputs', recipe_config, input_size1)
recipe_npu_x2 = recipe_config.recipe('weight', recipe_config, input_size2)
fp8_npu_x1 = recipe_npu_x1.quantization(inputs_npu_x1)
fp8_npu_x2 = recipe_npu_x2.quantization(inputs_npu_x2)
output_npu = fp8_npu_x1.quant_matmul(fp8_npu_x2)
from mindspeed.te.pytorch.fp8.tensor import Float8TensorCpu
fp8_cpu_x1 = Float8TensorCpu(
torch.tensor([]),
torch.float8_e4m3fn,
None,
torch.float32
)
fp8_cpu_x2 = Float8TensorCpu(
torch.tensor([]),
torch.float8_e4m3fn,
None,
torch.float32
)
fp8_cpu_x1.from_float8tensor(fp8_npu_x1)
fp8_cpu_x2.from_float8tensor(fp8_npu_x2)
output_cpu = fp8_cpu_x1.quant_matmul(fp8_cpu_x2)
fp8_cpu_data_x1 = fp8_cpu_x1.data.npu().to(torch.float32)
fp8_cpu_data_x2 = fp8_cpu_x2.data.npu().to(torch.float32)
assert torch.allclose(fp8_npu_x1.data.to(torch.float32), fp8_cpu_data_x1, atol=0.0, rtol=0.0)
assert torch.allclose(fp8_npu_x2.data.to(torch.float32), fp8_cpu_data_x2, atol=0.0, rtol=0.0)
assert torch.allclose(output_npu, output_cpu.npu(), atol=0.001, rtol=0.001)
@pytest.mark.skip(reason='not support for current version')
@pytest.mark.parametrize("gmm_quant_func", [
TensorwiseGMMFunction
])
class TestFP8GMM(DistributedTest):
iter_times = 10
num_local_experts = 16
world_size = 8
hidden_size = 56
gemm_fusion = False
def test_gmm_forward(self, gmm_quant_func):
initialize_model_parallel(self.world_size, 1)
for _ in range(self.iter_times):
weight = torch.rand((56, 256), dtype=torch.bfloat16, requires_grad=True).npu()
x = torch.rand((134, 56), dtype=torch.bfloat16, requires_grad=True).npu()
tokens_per_expert = torch.tensor([4288, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ],
dtype=torch.int64).npu()
weight = weight.view(self.num_local_experts, self.hidden_size, -1)
normal_x, normal_weight = x.clone(), weight.clone()
quant_output = gmm_quant_func.apply(x, weight, None, tokens_per_expert)
quant_bk_dx, quant_bk_dw = torch.autograd.grad(outputs=quant_output.sum(), inputs=(x, weight))
cp_output, cp_bk_dx, cp_bk_dw = self.compare(normal_weight, normal_x, tokens_per_expert)
assert torch.allclose(quant_output, cp_output, atol=0.005, rtol=0.005)
assert torch.allclose(quant_bk_dx, cp_bk_dx, atol=0.005, rtol=0.005)
assert torch.allclose(quant_bk_dw, cp_bk_dw, atol=0.005, rtol=0.005)
def compare(self, weight, x, tokens_per_expert):
output = gg.ops.gmm(x, weight, tokens_per_expert, trans_b=False,
gemm_fusion=self.gemm_fusion, original_weight=weight)
dx, dw = torch.autograd.grad(outputs=output.sum(), inputs=(x, weight))
return output, dx, dw
@pytest.mark.skip(reason='not support for current version')
@pytest.mark.parametrize("fp8_args", [
(TensorwiseMatMul, 'e4m3'),
(TensorwiseMatMul, 'hybrid'),
(TensorwiseMatMul, 'hif8'),
(MXFP8MatMul, 'e4m3'),
(MXFP8MatMul, 'hybrid'),
(Float8BlockMatMul, 'e4m3'),
(Float8BlockMatMul, 'hybrid'),
])
class TestFP8MM(DistributedTest):
iter_times = 10
world_size = 8
def test_mm(self, fp8_args):
mm_func, _format = fp8_args
initialize_model_parallel(self.world_size, 1)
_args = parse_args(None, True)
_args.fp8 = _format
set_args(_args)
for _ in range(self.iter_times):
weight = torch.rand((256, 56), dtype=torch.bfloat16, requires_grad=True).npu()
x = torch.rand((134, 56), dtype=torch.bfloat16, requires_grad=True).npu()
normal_x, normal_weight = x.clone(), weight.clone()
quant_output = mm_func.apply(x, weight)
quant_bk_dx, quant_bk_dw = torch.autograd.grad(outputs=quant_output.sum(), inputs=(x, weight))
cp_output, cp_bk_dx, cp_bk_dw = self.compare(normal_weight, normal_x)
assert torch.allclose(quant_output, cp_output, atol=0.005, rtol=0.005)
assert torch.allclose(quant_bk_dx, cp_bk_dx, atol=0.005, rtol=0.005)
assert torch.allclose(quant_bk_dw, cp_bk_dw, atol=0.005, rtol=0.005)
def compare(self, weight, x):
output = torch.matmul(x, weight.t())
dx, dw = torch.autograd.grad(outputs=output.sum(), inputs=(x, weight))
return output, dx, dw