"""
Test cases for torch_npu.npu_mm_all_reduce_base API
用例说明(共20个用例):
非量化场景(x1和x2类型一致):
- 用例1: x1=float16, x2=float16, 2维输入
- 用例2: x1=bfloat16, x2=bfloat16, 2维输入
- 用例3: x1=float16, x2=float16, 3维输入
- 用例4: x1=bfloat16, x2=bfloat16, 3维输入
伪量化场景(x1=float16/bfloat16, x2=int8, 需要antiquant_scale/offset):
- 用例5: x1=float16, x2=int8, perchannel伪量化
- 用例6: x1=bfloat16, x2=int8, perchannel伪量化
- 用例7: x1=float16, x2=int8, pertensor伪量化
- 用例8: x1=bfloat16, x2=int8, pertensor伪量化
全量化场景(x1=int8, x2=int8, 需要dequant_scale):
- 用例9: x1=int8, x2=int8, dequant_scale=int64
- 用例10: x1=int8, x2=int8, dequant_scale=bfloat16
- 用例11: x1=int8, x2=int8, dequant_scale=float32, pertoken_scale=float32
- 用例12: x1=int8, x2=int8, dequant_scale=bfloat16, pertoken_scale=float32
扩展场景(使用x1_dtype/x2_dtype参数支持float8/hifloat8):
- 用例13: x1=float8_e4m3fn, x2=float8_e4m3fn
- 用例14: x1=float8_e5m2, x2=float8_e5m2
- 用例15: x1=hifloat8, x2=hifloat8
- 用例16: x1=float8_e4m3fn, x2=float8_e5m2
- 用例17: x1=float8_e5m2, x2=float8_e4m3fn
- 用例18: x1=hifloat8, x2=float8_e4m3fn
- 用例19: x1=float8_e4m3fn, x2=hifloat8
- 用例20: x1=hifloat8, x2=float8_e5m2
约束说明:
- x1支持2维或3维,x2必须是2维
- 非量化场景:x1和x2数据类型需一致
- 伪量化场景:x1为float16/bfloat16,x2为int8
- 全量化场景:x1和x2都为int8
- 支持1、2、4、8卡,仅支持hccs链路all mesh组网
"""
import os
import unittest
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor, SupportedDevices
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class TestMmAllReduceBase(TestCase):
    """Multi-process HCCL tests for ``torch_npu.npu_mm_all_reduce_base``.

    Every test follows the same recipe:

    1. build one ``(x1, x2)`` pair per rank on the host,
    2. compute a float32 host reference of ``sum over ranks of x1 @ x2``
       (plus the de/anti-quantisation of the scenario under test),
    3. spawn one process per rank, run the fused operator there and ship the
       result back through a queue,
    4. compare every rank's output with the reference.

    The comparison uses the base ``TestCase.assertEqual`` with its default
    tolerance.
    """

    # Seconds the parent waits on the result queue before it re-checks
    # whether a worker process has died.
    _RESULT_POLL_SECONDS = 1.0

    # ------------------------------------------------------------------
    # Worker-side plumbing (runs inside the spawned processes)
    # ------------------------------------------------------------------
    @classmethod
    def _init_dist_hccl(cls, rank, world_size):
        """Select NPU ``rank`` and initialise the default HCCL process group.

        Returns the ``torch.distributed`` module so callers can reach
        ``distributed_c10d`` and ``barrier`` through the returned object.
        """
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '50000'
        os.environ['HCCL_WHITELIST_DISABLE'] = '1'
        torch_npu.npu.set_device(rank)
        dist.init_process_group(backend='hccl', world_size=world_size, rank=rank)
        return dist

    @staticmethod
    def _get_hcom_name(pg, rank):
        """Return the HCCL communicator name of the default group for ``rank``."""
        group = pg.distributed_c10d._get_default_group()
        # Releases after 2.0.1 expose the communicator through the backend
        # object; 2.0.1 and earlier expose it directly on the group.
        if torch.__version__ > '2.0.1':
            return group._get_backend(torch.device('npu')).get_hccl_comm_name(rank)
        return group.get_hccl_comm_name(rank)

    @classmethod
    def _run_rank(cls, rank, world_size, init_pg, c2p, x1, x2, tensor_kwargs=None, **options):
        """Run the operator on one rank and report ``(rank, output)``.

        Args:
            rank: index of this worker; also used as the NPU device index.
            world_size: number of participating ranks.
            init_pg: callable ``(rank, world_size)`` that sets up the group.
            c2p: child-to-parent queue receiving ``(rank, cpu_output)``.
            x1, x2: host tensors for this rank.
            tensor_kwargs: optional ``{keyword: host tensor}``; each tensor is
                moved to the NPU and passed under that keyword.
            **options: non-tensor keyword arguments forwarded unchanged.
        """
        pg = init_pg(rank, world_size)
        hcom_name = cls._get_hcom_name(pg, rank)
        x1 = x1.npu()
        x2 = x2.npu()
        for keyword, tensor in (tensor_kwargs or {}).items():
            options[keyword] = tensor.npu()
        out = torch_npu.npu_mm_all_reduce_base(x1, x2, hcom_name, reduce_op='sum', bias=None,
                                               comm_turn=0, **options)
        c2p.put((rank, out.cpu()))
        # Keep every rank alive until all of them have finished the collective.
        pg.barrier()

    @classmethod
    def _test_npu_mm_all_reduce_base(cls, rank, input_list):
        """Worker entry point for the non-quantised scenario."""
        x1, x2, world_size, init_pg, c2p = input_list
        cls._run_rank(rank, world_size, init_pg, c2p, x1, x2)

    @classmethod
    def _test_npu_mm_all_reduce_base_quant(cls, rank, input_list):
        """Worker entry point for the pseudo-quant (anti-quant) scenario."""
        x1, x2, scale, offset, world_size, init_pg, c2p = input_list
        cls._run_rank(rank, world_size, init_pg, c2p, x1, x2,
                      tensor_kwargs={'antiquant_scale': scale, 'antiquant_offset': offset})

    @classmethod
    def _test_npu_mm_all_reduce_base_dequant(cls, rank, input_list):
        """Worker entry point for the full-quant scenario (``dequant_scale`` only)."""
        x1, x2, dequant_scale, world_size, init_pg, c2p = input_list
        cls._run_rank(rank, world_size, init_pg, c2p, x1, x2,
                      tensor_kwargs={'dequant_scale': dequant_scale})

    @classmethod
    def _test_npu_mm_all_reduce_base_dequant_pertoken(cls, rank, input_list):
        """Worker entry point for full-quant with an additional ``pertoken_scale``."""
        x1, x2, dequant_scale, pertoken_scale, world_size, init_pg, c2p = input_list
        cls._run_rank(rank, world_size, init_pg, c2p, x1, x2,
                      tensor_kwargs={'dequant_scale': dequant_scale,
                                     'pertoken_scale': pertoken_scale})

    @classmethod
    def _test_npu_mm_all_reduce_base_fp8(cls, rank, input_list):
        """Worker entry point for the float8 / hifloat8 scenario."""
        x1, x2, dequant_scale, world_size, x1_dtype, x2_dtype, output_dtype, init_pg, c2p = input_list
        cls._run_rank(rank, world_size, init_pg, c2p, x1, x2,
                      tensor_kwargs={'dequant_scale': dequant_scale},
                      x1_dtype=x1_dtype, x2_dtype=x2_dtype, y_dtype=output_dtype)

    # ------------------------------------------------------------------
    # Parent-side drivers
    # ------------------------------------------------------------------
    def _spawn_and_verify(self, f, expt_out_list, world_size, rank_payload):
        """Spawn one worker per rank, collect their outputs and check them.

        Args:
            f: worker entry point, called as ``f(rank, payload + [queue])``.
            expt_out_list: expected output per rank.
            world_size: number of workers to start.
            rank_payload: callable ``rank -> list`` producing the worker's
                argument list without the trailing result queue.

        The queue is drained *before* the workers are joined: a child that
        has put data on a ``multiprocessing.Queue`` may not exit until that
        data has been consumed, so joining first can deadlock.
        """
        import queue  # local import: only needed for the ``queue.Empty`` type

        ctx = mp.get_context('spawn')
        c2p = ctx.Queue(world_size)
        workers = []
        for rank in range(world_size):
            worker = ctx.Process(target=f, args=(rank, rank_payload(rank) + [c2p]))
            worker.start()
            workers.append(worker)

        results = {}
        try:
            while len(results) < world_size:
                try:
                    rank, output = c2p.get(timeout=self._RESULT_POLL_SECONDS)
                except queue.Empty:
                    # Nothing arrived yet: make sure no worker has crashed,
                    # otherwise we would wait forever for its result.
                    crashed = [idx for idx, worker in enumerate(workers)
                               if worker.exitcode not in (None, 0)]
                    if crashed:
                        self.fail("worker rank(s) {} exited abnormally before reporting a result"
                                  .format(crashed))
                    continue
                results[rank] = output
        finally:
            all_reported = len(results) == world_size
            for worker in workers:
                # After a failure the surviving ranks may be stuck in a
                # collective waiting for the dead one; do not wait on them.
                if not all_reported and worker.is_alive():
                    worker.terminate()
                worker.join()

        for rank in range(world_size):
            self.assertEqual(results[rank], expt_out_list[rank],
                             "rank {} output does not match the expected result".format(rank))

    def _test_multiprocess(self, f, init_pg, input_list):
        """Run and verify the non-quantised scenario across all ranks."""
        expt_out_list, x1_list, x2_list, world_size = input_list
        self._spawn_and_verify(
            f, expt_out_list, world_size,
            lambda rank: [x1_list[rank], x2_list[rank], world_size, init_pg])

    def _test_multiprocess_quant(self, f, init_pg, input_list):
        """Run and verify the pseudo-quant scenario across all ranks."""
        expt_out_list, x1_list, x2_list, scale, offset, world_size = input_list
        self._spawn_and_verify(
            f, expt_out_list, world_size,
            lambda rank: [x1_list[rank], x2_list[rank], scale, offset, world_size, init_pg])

    def _test_multiprocess_dequant(self, f, init_pg, input_list):
        """Run and verify the full-quant scenario across all ranks."""
        expt_out_list, x1_list, x2_list, dequant_scale, world_size = input_list
        self._spawn_and_verify(
            f, expt_out_list, world_size,
            lambda rank: [x1_list[rank], x2_list[rank], dequant_scale, world_size, init_pg])

    def _test_multiprocess_dequant_pertoken(self, f, init_pg, input_list):
        """Run and verify the full-quant + per-token scenario across all ranks."""
        expt_out_list, x1_list, x2_list, dequant_scale, pertoken_scale, world_size = input_list
        self._spawn_and_verify(
            f, expt_out_list, world_size,
            lambda rank: [x1_list[rank], x2_list[rank], dequant_scale, pertoken_scale,
                          world_size, init_pg])

    def _test_multiprocess_fp8(self, f, init_pg, input_list):
        """Run and verify the float8 / hifloat8 scenario across all ranks."""
        expt_out_list, x1_list, x2_list, dequant_scale, world_size, x1_dtype, x2_dtype, output_dtype = input_list
        self._spawn_and_verify(
            f, expt_out_list, world_size,
            lambda rank: [x1_list[rank], x2_list[rank], dequant_scale, world_size,
                          x1_dtype, x2_dtype, output_dtype, init_pg])

    # ------------------------------------------------------------------
    # Host references
    # ------------------------------------------------------------------
    @staticmethod
    def _accumulate(partials):
        """Sum an iterable of tensors; returns ``None`` for an empty iterable."""
        total = None
        for part in partials:
            total = part if total is None else torch.add(total, part)
        return total

    def _construct_excepted_result(self, x1_list, x2_list, world_size):
        """Reference for 2-D ``x1``: float32 matmul per rank, summed, cast to x1's dtype."""
        total = self._accumulate(
            torch.matmul(x1_list[i].to(torch.float), x2_list[i].to(torch.float))
            for i in range(world_size))
        return [total.to(x1_list[0].dtype) for _ in range(world_size)]

    def _construct_excepted_result_3d(self, x1_list, x2_list, world_size):
        """Reference for 3-D ``x1``: flatten (b, s) for the matmul, then restore it."""
        def rank_product(i):
            b, s, k = x1_list[i].shape
            flat = x1_list[i].reshape(b * s, k).to(torch.float)
            return torch.matmul(flat, x2_list[i].to(torch.float)).reshape(b, s, -1)

        total = self._accumulate(rank_product(i) for i in range(world_size))
        return [total.to(x1_list[0].dtype) for _ in range(world_size)]

    def _construct_excepted_result_quant(self, x1_list, x2_list, scale, offset, world_size):
        """Reference for pseudo-quant: ``x1 @ ((x2 + offset) * scale)`` summed over ranks."""
        def rank_product(i):
            weight = torch.add(x2_list[i].to(torch.float32), offset.to(torch.float32))
            dequant = torch.mul(weight, scale.to(torch.float32))
            return torch.matmul(x1_list[i].to(torch.float32), dequant)

        total = self._accumulate(rank_product(i) for i in range(world_size))
        return [total.to(x1_list[0].dtype) for _ in range(world_size)]

    def _construct_excepted_result_dequant(self, x1_list, x2_list, dequant_scale, world_size):
        """Reference for full-quant: ``(x1 @ x2) * dequant_scale`` summed over ranks.

        NOTE(review): an int64 ``dequant_scale`` is used here by numeric value;
        if the operator expects a packed/encoded int64 scale this reference is
        not comparable -- confirm against the operator documentation.
        NOTE(review): the result is always cast to float16, also when
        ``dequant_scale`` is bfloat16 -- confirm the operator's output dtype.
        """
        def rank_product(i):
            out_mm = torch.matmul(x1_list[i].to(torch.float32), x2_list[i].to(torch.float32))
            return torch.mul(out_mm, dequant_scale.to(torch.float32))

        total = self._accumulate(rank_product(i) for i in range(world_size))
        return [total.to(torch.float16) for _ in range(world_size)]

    def _construct_excepted_result_dequant_pertoken(self, x1_list, x2_list, dequant_scale, pertoken_scale, world_size):
        """Reference for full-quant with a per-row ``pertoken_scale`` applied after dequant.

        NOTE(review): the result is always cast to float16 -- confirm the
        operator's output dtype when ``dequant_scale`` is bfloat16.
        """
        def rank_product(i):
            out_mm = torch.matmul(x1_list[i].to(torch.float32), x2_list[i].to(torch.float32))
            scaled = torch.mul(out_mm, dequant_scale.to(torch.float32))
            # pertoken_scale has one entry per row of x1; make it a column.
            return torch.mul(scaled, pertoken_scale.unsqueeze(1).to(torch.float32))

        total = self._accumulate(rank_product(i) for i in range(world_size))
        return [total.to(torch.float16) for _ in range(world_size)]

    def _construct_excepted_result_fp8(self, x1_list, x2_list, dequant_scale, world_size, output_dtype):
        """Reference for float8/hifloat8: float32 matmul, scaled, summed, cast to ``output_dtype``."""
        def rank_product(i):
            out_mm = torch.matmul(x1_list[i].to(torch.float32), x2_list[i].to(torch.float32))
            return torch.mul(out_mm, dequant_scale.to(torch.float32))

        total = self._accumulate(rank_product(i) for i in range(world_size))
        return [total.to(output_dtype) for _ in range(world_size)]

    # ------------------------------------------------------------------
    # Input builders and per-scenario scaffolding
    # ------------------------------------------------------------------
    @staticmethod
    def _uniform(shape, dtype, low=-1, high=1):
        """Return a host tensor of ``shape``/``dtype`` filled from U(low, high)."""
        return torch.empty(shape, dtype=dtype).uniform_(low, high)

    @staticmethod
    def _rand_int8(shape):
        """Return a host int8 tensor with values drawn from [-128, 127)."""
        return torch.randint(-128, 127, shape, dtype=torch.int8)

    @staticmethod
    def _make_rank_inputs(world_size, make_x1, make_x2):
        """Build ``world_size`` independent ``(x1, x2)`` pairs; returns two lists."""
        x1_list, x2_list = [], []
        for _ in range(world_size):
            x1_list.append(make_x1())
            x2_list.append(make_x2())
        return x1_list, x2_list

    @staticmethod
    def _make_fp8_operand(shape, native_dtype=None):
        """Build one 8-bit float operand and the dtype marker to pass for it.

        Args:
            shape: tensor shape.
            native_dtype: a torch float8 dtype, or ``None`` for hifloat8.

        Returns:
            ``(tensor, dtype_marker)``. The marker is ``torch_npu.hifloat8``
            for hifloat8 operands and ``None`` for native torch float8
            tensors, whose dtype the tensor already carries.
            NOTE(review): assumes ``None`` is the accepted "use the tensor's
            own dtype" value for ``x1_dtype`` / ``x2_dtype`` -- confirm.
        """
        if native_dtype is None:
            carrier = torch.randint(-5, 5, shape, dtype=torch.int8)
            return torch_npu.HiFloat8Tensor.to_hifloat8(carrier), torch_npu.hifloat8
        return torch_npu.npu_dtype_cast(torch.randint(-5, 5, shape), native_dtype), None

    def _check_non_quant(self, x1_shape, x2_shape, dtype, build_expected):
        """Non-quantised scenario on 8 ranks with the given shapes and dtype."""
        world_size = 8
        x1_list, x2_list = self._make_rank_inputs(
            world_size,
            lambda: self._uniform(x1_shape, dtype),
            lambda: self._uniform(x2_shape, dtype))
        expt_out_list = build_expected(x1_list, x2_list, world_size)
        self._test_multiprocess(TestMmAllReduceBase._test_npu_mm_all_reduce_base,
                                TestMmAllReduceBase._init_dist_hccl,
                                [expt_out_list, x1_list, x2_list, world_size])

    def _check_antiquant(self, dtype, per_channel):
        """Pseudo-quant scenario on 8 ranks; scale/offset are per-channel or per-tensor."""
        world_size = 8
        m, k, n = 1, 256, 256
        scale_numel = n if per_channel else 1
        scale = self._uniform((scale_numel,), dtype, 0.001, 0.01)
        offset = self._uniform((scale_numel,), dtype, -1, 1)
        x1_list, x2_list = self._make_rank_inputs(
            world_size,
            lambda: self._uniform((m, k), dtype),
            lambda: self._rand_int8((k, n)))
        expt_out_list = self._construct_excepted_result_quant(x1_list, x2_list, scale, offset, world_size)
        self._test_multiprocess_quant(TestMmAllReduceBase._test_npu_mm_all_reduce_base_quant,
                                      TestMmAllReduceBase._init_dist_hccl,
                                      [expt_out_list, x1_list, x2_list, scale, offset, world_size])

    def _check_dequant(self, make_scale):
        """Full-quant scenario on 2 ranks; ``make_scale(n)`` builds the dequant scale."""
        world_size = 2
        m, k, n = 1, 256, 256
        scale = make_scale(n)
        x1_list, x2_list = self._make_rank_inputs(
            world_size,
            lambda: self._rand_int8((m, k)),
            lambda: self._rand_int8((k, n)))
        expt_out_list = self._construct_excepted_result_dequant(x1_list, x2_list, scale, world_size)
        self._test_multiprocess_dequant(TestMmAllReduceBase._test_npu_mm_all_reduce_base_dequant,
                                        TestMmAllReduceBase._init_dist_hccl,
                                        [expt_out_list, x1_list, x2_list, scale, world_size])

    def _check_dequant_pertoken(self, scale_dtype):
        """Full-quant + float32 per-token scale on 2 ranks."""
        world_size = 2
        m, k, n = 1, 256, 256
        scale = self._uniform((n,), scale_dtype, 0.001, 0.01)
        pertoken_scale = self._uniform((m,), torch.float32, 0.001, 0.01)
        x1_list, x2_list = self._make_rank_inputs(
            world_size,
            lambda: self._rand_int8((m, k)),
            lambda: self._rand_int8((k, n)))
        expt_out_list = self._construct_excepted_result_dequant_pertoken(
            x1_list, x2_list, scale, pertoken_scale, world_size)
        self._test_multiprocess_dequant_pertoken(
            TestMmAllReduceBase._test_npu_mm_all_reduce_base_dequant_pertoken,
            TestMmAllReduceBase._init_dist_hccl,
            [expt_out_list, x1_list, x2_list, scale, pertoken_scale, world_size])

    def _check_fp8(self, x1_native_dtype, x2_native_dtype):
        """float8/hifloat8 scenario on 2 ranks; ``None`` selects hifloat8 for an operand.

        Every rank receives a copy of the same ``x1`` and ``x2``.
        """
        world_size = 2
        m, k, n = 16, 256, 256
        x1, x1_dtype = self._make_fp8_operand((m, k), x1_native_dtype)
        x2, x2_dtype = self._make_fp8_operand((k, n), x2_native_dtype)
        dequant_scale = self._uniform((n,), torch.float32, 0.001, 0.01)
        x1_list = [x1.clone() for _ in range(world_size)]
        x2_list = [x2.clone() for _ in range(world_size)]
        expt_out_list = self._construct_excepted_result_fp8(
            x1_list, x2_list, dequant_scale, world_size, torch.bfloat16)
        self._test_multiprocess_fp8(TestMmAllReduceBase._test_npu_mm_all_reduce_base_fp8,
                                    TestMmAllReduceBase._init_dist_hccl,
                                    [expt_out_list, x1_list, x2_list, dequant_scale, world_size,
                                     x1_dtype, x2_dtype, torch.bfloat16])

    # ------------------------------------------------------------------
    # Non-quantised cases (x1 and x2 share a dtype)
    # ------------------------------------------------------------------
    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_fp16_fp16_2d(self):
        """Case 1: non-quantised, x1=float16, x2=float16, 2-D x1."""
        self._check_non_quant((128, 512), (512, 256), torch.float16,
                              self._construct_excepted_result)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_bf16_bf16_2d(self):
        """Case 2: non-quantised, x1=bfloat16, x2=bfloat16, 2-D x1."""
        self._check_non_quant((128, 512), (512, 256), torch.bfloat16,
                              self._construct_excepted_result)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_fp16_fp16_3d(self):
        """Case 3: non-quantised, x1=float16, x2=float16, 3-D x1."""
        self._check_non_quant((2, 64, 512), (512, 256), torch.float16,
                              self._construct_excepted_result_3d)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_bf16_bf16_3d(self):
        """Case 4: non-quantised, x1=bfloat16, x2=bfloat16, 3-D x1."""
        self._check_non_quant((2, 64, 512), (512, 256), torch.bfloat16,
                              self._construct_excepted_result_3d)

    # ------------------------------------------------------------------
    # Pseudo-quant cases (x1 float16/bfloat16, x2 int8)
    # ------------------------------------------------------------------
    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_fp16_int8_perchannel(self):
        """Case 5: pseudo-quant, x1=float16, x2=int8, per-channel scale/offset."""
        self._check_antiquant(torch.float16, per_channel=True)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_bf16_int8_perchannel(self):
        """Case 6: pseudo-quant, x1=bfloat16, x2=int8, per-channel scale/offset."""
        self._check_antiquant(torch.bfloat16, per_channel=True)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_fp16_int8_pertensor(self):
        """Case 7: pseudo-quant, x1=float16, x2=int8, per-tensor scale/offset."""
        self._check_antiquant(torch.float16, per_channel=False)

    @skipIfUnsupportMultiNPU(8)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_bf16_int8_pertensor(self):
        """Case 8: pseudo-quant, x1=bfloat16, x2=int8, per-tensor scale/offset."""
        self._check_antiquant(torch.bfloat16, per_channel=False)

    # ------------------------------------------------------------------
    # Full-quant cases (x1 int8, x2 int8)
    # ------------------------------------------------------------------
    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_int8_int8_dequant_int64(self):
        """Case 9: full-quant, x1=int8, x2=int8, dequant_scale=int64."""
        self._check_dequant(lambda n: torch.randint(1, 100, (n,), dtype=torch.int64))

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_int8_int8_dequant_bf16(self):
        """Case 10: full-quant, x1=int8, x2=int8, dequant_scale=bfloat16."""
        self._check_dequant(lambda n: self._uniform((n,), torch.bfloat16, 0.001, 0.01))

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_int8_int8_dequant_fp32_pertoken_fp32(self):
        """Case 11: full-quant, dequant_scale=float32, pertoken_scale=float32."""
        self._check_dequant_pertoken(torch.float32)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend910B', 'Ascend950'])
    def test_npu_mm_all_reduce_base_int8_int8_dequant_bf16_pertoken_fp32(self):
        """Case 12: full-quant, dequant_scale=bfloat16, pertoken_scale=float32."""
        self._check_dequant_pertoken(torch.bfloat16)

    # ------------------------------------------------------------------
    # Extended cases (float8 / hifloat8 via x1_dtype / x2_dtype)
    # ------------------------------------------------------------------
    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_fp8_e4m3fn_fp8_e4m3fn(self):
        """Case 13: x1=float8_e4m3fn, x2=float8_e4m3fn."""
        self._check_fp8(torch.float8_e4m3fn, torch.float8_e4m3fn)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_fp8_e5m2_fp8_e5m2(self):
        """Case 14: x1=float8_e5m2, x2=float8_e5m2."""
        self._check_fp8(torch.float8_e5m2, torch.float8_e5m2)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_hifloat8_hifloat8(self):
        """Case 15: x1=hifloat8, x2=hifloat8."""
        self._check_fp8(None, None)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_fp8_e4m3fn_fp8_e5m2(self):
        """Case 16: x1=float8_e4m3fn, x2=float8_e5m2."""
        self._check_fp8(torch.float8_e4m3fn, torch.float8_e5m2)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_fp8_e5m2_fp8_e4m3fn(self):
        """Case 17: x1=float8_e5m2, x2=float8_e4m3fn."""
        self._check_fp8(torch.float8_e5m2, torch.float8_e4m3fn)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_hifloat8_fp8_e4m3fn(self):
        """Case 18: x1=hifloat8, x2=float8_e4m3fn."""
        self._check_fp8(None, torch.float8_e4m3fn)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_fp8_e4m3fn_hifloat8(self):
        """Case 19: x1=float8_e4m3fn, x2=hifloat8."""
        self._check_fp8(torch.float8_e4m3fn, None)

    @skipIfUnsupportMultiNPU(2)
    @SupportedDevices(['Ascend950'])
    def test_npu_mm_all_reduce_base_hifloat8_fp8_e5m2(self):
        """Case 20: x1=hifloat8, x2=float8_e5m2."""
        self._check_fp8(None, torch.float8_e5m2)
# Run the suite only when this file is executed directly.
if __name__ == '__main__':
    run_tests()