import pytest
import torch
import torch_npu
from torch.nn.parameter import Parameter
import mindspeed.megatron_adaptor
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.training.arguments import parse_args, core_transformer_config_from_args
from megatron.training.global_vars import set_args
from megatron.training.initialize import _set_random_seed
from mindspeed.te.pytorch.fp8 import MatmulKey
from mindspeed.te.pytorch.fp8.metadata import FP8Metadata
from mindspeed.te.pytorch.module.linear import TEColumnParallelLinear, TERowParallelLinear
from mindspeed.te.pytorch.module.ops.default_ops import DefaultOps
from mindspeed.te.pytorch.module.ops.mc2_ops import Mc2Ops
from tests_extend.commons import initialize_model_parallel
from tests_extend.unit_tests.common import DistributedTest
from tests_extend.unit_tests.utils import multi_compare
class TestAllgatherMatmul(DistributedTest):
    """Compare DefaultOps / Mc2Ops ``allgather_matmul`` against a gloo-based reference."""

    world_size = 8

    def test_allgather_matmul(self):
        """Run both fused all-gather + matmul ops on NPU and compare with the reference."""
        batch_size = 1
        seq_size = 4096
        input_size = 1024
        dtype = torch.bfloat16
        args = parse_args(None, True)
        args.params_dtype = dtype
        set_args(args)
        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)
        # The reference helper consumes x / w as created here; only the fused ops get .npu() copies.
        x = torch.randn(seq_size, batch_size, input_size, dtype=dtype) * 5
        w = torch.randn(seq_size, input_size, dtype=dtype) * 5
        fp8_meta = FP8Metadata(["inputs", "weight", "grads"])
        output_baseline = self.allgather_matmul(x, w).view(-1)
        output_default = DefaultOps.allgather_matmul(x.npu(), w.npu(), None, fp8_meta, MatmulKey.forward)
        output_mc2 = Mc2Ops.allgather_matmul(x.npu(), w.npu(), None, fp8_meta, MatmulKey.forward)
        # The ops return an indexable result; element 0 is the tensor under comparison.
        output_default = output_default[0].view(-1).cpu()
        output_mc2 = output_mc2[0].view(-1).cpu()
        assert multi_compare(output_mc2, output_baseline, output_default, f"{torch.npu.current_device()}") != "FAIL"

    def allgather_matmul(self, input_, weight):
        """Reference implementation: gather ``input_`` along dim 0, then multiply by ``weight.t()``.

        Args:
            input_: local shard; the gathered result has ``world_size`` times as many rows in dim 0.
            weight: matrix whose transpose is right-multiplied onto the gathered input.

        Returns:
            ``torch.matmul(gathered_input, weight.t())``.
        """
        dim_size = list(input_.size())
        dim_size[0] = dim_size[0] * self.world_size
        total_input = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
        # A dedicated gloo group is used for the reference collective.
        group = torch.distributed.new_group(list(range(self.world_size)), backend="gloo")
        # Public API; the private ``_all_gather_base`` is deprecated and forwards to this call.
        torch.distributed.all_gather_into_tensor(total_input, input_.contiguous(), group=group, async_op=False)
        return torch.matmul(total_input, weight.t())
class TestMatmulReduceScatter(DistributedTest):
    """Compare DefaultOps / Mc2Ops ``matmul_reduce_scatter`` against a gloo-based reference."""

    world_size = 8

    def test_matmul_reduce_scatter(self):
        """Run both fused matmul + reduce-scatter ops on NPU and compare with the reference."""
        batch_size = 1
        seq_size = 4096
        input_size = 1024
        dtype = torch.bfloat16
        args = parse_args(None, True)
        args.params_dtype = dtype
        set_args(args)
        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)
        x = torch.randn(seq_size, batch_size, input_size, dtype=dtype) * 5
        w = torch.randn(seq_size, input_size, dtype=dtype) * 5
        fp8_meta = FP8Metadata(["inputs", "weight", "grads"])
        output_baseline = self.reduce_scatter(x, w).view(-1)
        output_default, _, _ = DefaultOps.matmul_reduce_scatter(x.npu(), w.npu(), None, fp8_meta, MatmulKey.forward)
        output_mc2, _, _ = Mc2Ops.matmul_reduce_scatter(x.npu(), w.npu(), None, fp8_meta, MatmulKey.forward)
        output_default = output_default.view(-1).cpu()
        output_mc2 = output_mc2.view(-1).cpu()
        assert multi_compare(
            output_mc2,
            output_baseline,
            output_default,
            f"{torch.npu.current_device()}",
            "l0"
        ) != "FAIL"

    def reduce_scatter(self, x, w):
        """Reference implementation: ``x @ w.t()`` followed by a sum reduce-scatter along dim 0.

        Args:
            x: local activations.
            w: matrix whose transpose is right-multiplied onto ``x``.

        Returns:
            This rank's ``1 / world_size`` slice (dim 0) of the summed product.
        """
        output_ = torch.matmul(x, w.t())
        dim_size = list(output_.size())
        dim_size[0] = dim_size[0] // self.world_size
        # Allocate on the same device as the source, matching the all-gather reference helper.
        output = torch.empty(dim_size, dtype=output_.dtype, device=output_.device)
        group = torch.distributed.new_group(list(range(self.world_size)), backend="gloo")
        # Public API; the private ``_reduce_scatter_base`` is deprecated and forwards to this call.
        torch.distributed.reduce_scatter_tensor(output, output_.contiguous(), group=group)
        return output
class TestMatmulAllReduce(DistributedTest):
    """Check ``matmul_all_reduce`` of DefaultOps and Mc2Ops against a gloo reference."""

    world_size = 8

    def test_matmul_all_reduce(self):
        """Run both fused ops on NPU and compare them with matmul + gloo all-reduce."""
        dtype = torch.bfloat16
        batch_size, seq_size, input_size = 1, 4096, 1024

        args = parse_args(None, True)
        args.params_dtype = dtype
        set_args(args)
        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)

        activations = torch.randn(seq_size, batch_size, input_size, dtype=dtype) * 5
        weights = torch.randn(seq_size, input_size, dtype=dtype) * 5
        meta = FP8Metadata(["inputs", "weight", "grads"])

        reference = self.matmul_all_reduce(activations, weights).view(-1)
        default_result, _, _ = DefaultOps.matmul_all_reduce(activations.npu(), weights.npu(), None, meta)
        mc2_result, _, _ = Mc2Ops.matmul_all_reduce(activations.npu(), weights.npu(), None, meta)

        default_flat = default_result.view(-1).cpu()
        mc2_flat = mc2_result.view(-1).cpu()
        device_tag = f"{torch.npu.current_device()}"
        verdict = multi_compare(mc2_flat, reference, default_flat, device_tag, "l0")
        assert verdict != "FAIL"

    def matmul_all_reduce(self, x, w):
        """Reference: local ``x @ w.t()`` all-reduced (default op) over a gloo group."""
        product = x.matmul(w.t())
        ranks = list(range(self.world_size))
        reduce_group = torch.distributed.new_group(ranks, backend="gloo")
        torch.distributed.all_reduce(product, group=reduce_group)
        return product
class TestTEColumnParallel(DistributedTest):
    """Compare TEColumnParallelLinear with ColumnParallelLinear, sequence parallelism enabled."""

    world_size = 8

    @pytest.mark.parametrize("use_ascend_mc2", [True, False])
    @pytest.mark.parametrize("limit_args", [
        (torch.bfloat16, 0.005, 0.005)
    ])
    def test_te_column_parallel(self, use_ascend_mc2, limit_args):
        """Forward outputs and weight gradients of both layers must agree within tolerance."""
        dtype, rtol, atol = limit_args
        batch_size, seq_size = 1, 4096
        input_size = output_size = 1024

        args = parse_args(None, True)
        overrides = {
            "params_dtype": dtype,
            "num_attention_heads": 16,
            "hidden_size": 2048,
            "num_layers": 2,
            "gradient_accumulation_fusion": False,
            "tensor_model_parallel_size": self.world_size,
            "sequence_parallel": True,
            "use_ascend_mc2": use_ascend_mc2,
            "transformer_impl": "transformer_engine",
        }
        for attr_name, attr_value in overrides.items():
            setattr(args, attr_name, attr_value)
        set_args(args)
        config = core_transformer_config_from_args(args)

        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)

        inputs = torch.rand(batch_size, seq_size, input_size, requires_grad=True, dtype=dtype).npu()
        teinputs = inputs.clone()

        # Both layers are constructed from the same keyword arguments.
        layer_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
        )
        linear = ColumnParallelLinear(**layer_kwargs)
        telinear = TEColumnParallelLinear(**layer_kwargs)
        # Copy the reference weights so both layers start from identical parameters.
        telinear.weight = Parameter(linear.weight.clone())

        outputs = linear(inputs)
        teoutputs = telinear(teinputs)
        for layer_result in (outputs, teoutputs):
            layer_result[0].sum().backward()

        assert torch.allclose(outputs[0], teoutputs[0], rtol=rtol, atol=atol)
        assert torch.allclose(linear.weight.grad, telinear.weight.grad, rtol=rtol, atol=atol)
class TestTEColumnParallelNoSeq(DistributedTest):
    """Compare TEColumnParallelLinear with ColumnParallelLinear, sequence parallelism disabled."""

    world_size = 8

    @pytest.mark.parametrize("use_ascend_mc2", [False])
    @pytest.mark.parametrize("limit_args", [
        (torch.bfloat16, 0.005, 0.005)
    ])
    def test_te_column_parallel_no_seq(self, use_ascend_mc2, limit_args):
        """Forward outputs and weight gradients of both layers must agree within tolerance."""
        dtype, rtol, atol = limit_args
        batch_size, seq_size = 1, 4096
        input_size = output_size = 1024

        args = parse_args(None, True)
        overrides = {
            "params_dtype": dtype,
            "num_attention_heads": 16,
            "hidden_size": 2048,
            "num_layers": 2,
            "gradient_accumulation_fusion": False,
            "tensor_model_parallel_size": self.world_size,
            "sequence_parallel": False,
            "use_ascend_mc2": use_ascend_mc2,
            "transformer_impl": "transformer_engine",
        }
        for attr_name, attr_value in overrides.items():
            setattr(args, attr_name, attr_value)
        set_args(args)
        config = core_transformer_config_from_args(args)

        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)

        inputs = torch.rand(batch_size, seq_size, input_size, requires_grad=True, dtype=dtype).npu()
        teinputs = inputs.clone()

        # Both layers are constructed from the same keyword arguments.
        layer_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=config.init_method,
            gather_output=False,
            bias=False,
            skip_bias_add=False,
            is_expert=False,
        )
        linear = ColumnParallelLinear(**layer_kwargs)
        telinear = TEColumnParallelLinear(**layer_kwargs)
        # Copy the reference weights so both layers start from identical parameters.
        telinear.weight = Parameter(linear.weight.clone())

        outputs = linear(inputs)
        teoutputs = telinear(teinputs)
        for layer_result in (outputs, teoutputs):
            layer_result[0].sum().backward()

        assert torch.allclose(outputs[0], teoutputs[0], rtol=rtol, atol=atol)
        assert torch.allclose(linear.weight.grad, telinear.weight.grad, rtol=rtol, atol=atol)
class TestTERowParallel(DistributedTest):
    """Compare TERowParallelLinear with RowParallelLinear, sequence parallelism enabled."""

    world_size = 8

    @pytest.mark.parametrize("use_ascend_mc2", [True, False])
    @pytest.mark.parametrize("limit_args", [
        (torch.bfloat16, 0.005, 0.005)
    ])
    def test_te_row_parallel(self, use_ascend_mc2, limit_args):
        """Forward outputs and weight gradients of both layers must agree within tolerance."""
        dtype, rtol, atol = limit_args
        batch_size, seq_size = 1, 4096
        input_size, output_size = 2048, 4096

        args = parse_args(None, True)
        overrides = {
            "params_dtype": dtype,
            "init_method_std": 0.002,
            "num_attention_heads": 16,
            "hidden_size": 2048,
            "num_layers": 2,
            "gradient_accumulation_fusion": False,
            "tensor_model_parallel_size": self.world_size,
            "sequence_parallel": True,
            "use_ascend_mc2": use_ascend_mc2,
            "transformer_impl": "transformer_engine",
        }
        for attr_name, attr_value in overrides.items():
            setattr(args, attr_name, attr_value)
        set_args(args)
        config = core_transformer_config_from_args(args)

        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)

        tp_size = args.tensor_model_parallel_size
        # input_is_parallel=True: each rank feeds only its 1/tp slice of the input features.
        local_features = input_size // tp_size
        inputs = torch.rand(seq_size, batch_size, local_features, requires_grad=True, dtype=dtype).npu()
        teinputs = inputs.clone()

        # Both layers are constructed from the same keyword arguments.
        layer_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=config.init_method,
            bias=False,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
        )
        linear = RowParallelLinear(**layer_kwargs)
        telinear = TERowParallelLinear(**layer_kwargs)
        # Copy the reference weights so both layers start from identical parameters.
        telinear.weight = Parameter(linear.weight.clone())

        outputs = linear(inputs)
        teoutputs = telinear(teinputs)
        # The upstream gradient has seq_size // tp rows in dim 0 here.
        upstream = torch.ones(seq_size // tp_size, batch_size, output_size, dtype=dtype).npu()
        for layer_result in (outputs, teoutputs):
            layer_result[0].backward(upstream)

        assert torch.allclose(outputs[0], teoutputs[0], rtol=rtol, atol=atol)
        assert torch.allclose(linear.weight.grad, telinear.weight.grad, rtol=rtol, atol=atol)
class TestTERowParallelNoSeq(DistributedTest):
    """Compare TERowParallelLinear with RowParallelLinear, sequence parallelism disabled."""

    world_size = 8

    @pytest.mark.parametrize("use_ascend_mc2", [False])
    @pytest.mark.parametrize("limit_args", [
        (torch.bfloat16, 0.005, 0.005)
    ])
    def test_te_row_parallel_no_seq(self, use_ascend_mc2, limit_args):
        """Forward outputs and weight gradients of both layers must agree within tolerance."""
        dtype, rtol, atol = limit_args
        batch_size, seq_size = 1, 4096
        input_size, output_size = 2048, 4096

        args = parse_args(None, True)
        overrides = {
            "params_dtype": dtype,
            "init_method_std": 0.002,
            "num_attention_heads": 16,
            "hidden_size": 2048,
            "num_layers": 2,
            "gradient_accumulation_fusion": False,
            "tensor_model_parallel_size": self.world_size,
            "sequence_parallel": False,
            "use_ascend_mc2": use_ascend_mc2,
            "transformer_impl": "transformer_engine",
        }
        for attr_name, attr_value in overrides.items():
            setattr(args, attr_name, attr_value)
        set_args(args)
        config = core_transformer_config_from_args(args)

        initialize_model_parallel(self.world_size, 1)
        _set_random_seed(seed_=123, data_parallel_random_init=False)

        tp_size = args.tensor_model_parallel_size
        # input_is_parallel=True: each rank feeds only its 1/tp slice of the input features.
        local_features = input_size // tp_size
        inputs = torch.rand(seq_size, batch_size, local_features, requires_grad=True, dtype=dtype).npu()
        teinputs = inputs.clone()

        # Both layers are constructed from the same keyword arguments.
        layer_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=config.init_method,
            bias=False,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
        )
        linear = RowParallelLinear(**layer_kwargs)
        telinear = TERowParallelLinear(**layer_kwargs)
        # Copy the reference weights so both layers start from identical parameters.
        telinear.weight = Parameter(linear.weight.clone())

        outputs = linear(inputs)
        teoutputs = telinear(teinputs)
        # Without sequence parallelism the upstream gradient keeps the full seq_size rows.
        upstream = torch.ones(seq_size, batch_size, output_size, dtype=dtype).npu()
        for layer_result in (outputs, teoutputs):
            layer_result[0].backward(upstream)

        assert torch.allclose(outputs[0], teoutputs[0], rtol=rtol, atol=atol)
        assert torch.allclose(linear.weight.grad, telinear.weight.grad, rtol=rtol, atol=atol)