"""
Minimal ST for OP_L0C_COPY_UB_DUAL_DST fusion (SplitN direction).
Kernel layout:
matmul(A, B^T) -> mm: [M, N] on L0C
upper = mm[:, :N/2] # 1st L0C_COPY_UB, N 轴前半
lower = mm[:, N/2:] # 2nd L0C_COPY_UB, N 轴后半
out0 = upper + 1.0 # AIV0 vector chain
out1 = lower + 2.0 # AIV1 vector chain
The two L0C_COPY_UB ops:
- share the same L0C input tensor (mm)
- have identical UB output shape ([M, N/2]) and validShape
- are N-axis adjacent (offsets 0 and N/2)
- feed two independent vector ops, which MixSchedule's CoreScheduler
binds to AIV0 / AIV1 (via existing DualDstProcess preCoreAssign)
When `kTempEnableDualDst = true` 在 schedule_ooo.cpp 里被打开, OoOSchedule
RunDualDstFuse 阶段会把这两个 L0C_COPY_UB 合并成单个 OP_L0C_COPY_UB_DUAL_DST,
保留一个 ALLOC, 删另一个; 双 UB 池 (AIV0/AIV1) 同地址联合分配。开关关时退化为
两条普通 L0C_COPY_UB, 测试结果保持一致 (数值正确性与开关无关)。
Run:
pytest python/tests/st/test_dualdst.py -v
"""
import sys
import os
import torch
import torch_npu
import pypto
import pytest
from numpy.testing import assert_allclose
import numpy as np
FP16 = pypto.DT_FP16
FP32 = pypto.DT_FP32
_DUALDST_JIT = pypto.frontend.jit(
debug_options={"runtime_debug_mode": 0, "compile_debug_mode": 0},
pass_options={"auto_mix_partition": 1}
)
def _matmul_split_n(a_tile, b_tile, half_n):
"""matmul(b_trans=True, FP32 输出) + 沿 N 轴二分; 返回 (upper, lower).
所有 dualdst kernel 的核心三行模式都用它, 避免每个 kernel 重复 mm + 上下半切片。
"""
mm = pypto.matmul(a_tile, b_tile, b_trans=True, out_dtype=FP32)
return mm[:, :half_n], mm[:, half_n:]
def _split_n_prologue_64(a_tensor, b_tensor):
"""3 个 SplitN dualdst kernel 共用的前缀.
cube tile (64,64,128) -> matmul -> 沿 N 二分 -> 设 vec tile (64,64);
返回 (upper, lower) 两份 L0C view。
"""
pypto.set_cube_tile_shapes([64, 64], [64, 64], [128, 128])
upper, lower = _matmul_split_n(a_tensor, b_tensor, 64)
pypto.set_vec_tile_shapes(64, 64)
return upper, lower
@_DUALDST_JIT
def dual_dst_split_n_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out0_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out1_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""matmul -> L0C -> N 轴二分 -> 两条独立 add-scalar vector chain.
单 cube tile 覆盖 M/K/N: 一次 matmul 产出唯一一份 L0C tensor,
它的两个 L0C_COPY_UB consumer 才是 dual_dst 候选对。
"""
upper, lower = _split_n_prologue_64(a_tensor, b_tensor)
out0_tensor[:, :] = upper + 1.0
pypto.set_vec_tile_shapes(64, 64)
out1_tensor[:, :] = lower + 2.0
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_split_n():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
m, k, n = 64, 64, 128
half_n = n // 2
torch.manual_seed(0)
a = torch.rand([m, k], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([n, k], dtype=torch.float16, device=f"npu:{device_id}")
out0 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
out1 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_split_n_kernel(a, b, out0, out1)
mm_golden = torch.matmul(a.to(torch.float32), b.to(torch.float32).T)
golden0 = mm_golden[:, :half_n] + 1.0
golden1 = mm_golden[:, half_n:] + 2.0
assert_allclose(
out0.cpu().to(torch.float32).numpy(),
golden0.cpu().numpy(),
rtol=5e-3, atol=5e-3,
)
assert_allclose(
out1.cpu().to(torch.float32).numpy(),
golden1.cpu().numpy(),
rtol=5e-3, atol=5e-3,
)
@_DUALDST_JIT
def dual_dst_split_m_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out0_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out1_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""matmul -> L0C -> M 轴二分 -> 两条独立 add-scalar vector chain (SplitM 方向)。"""
pypto.set_cube_tile_shapes([128, 128], [64, 64], [64, 64])
mm = pypto.matmul(a_tensor, b_tensor, b_trans=True, out_dtype=FP32)
half_m = mm.shape[0] // 2
upper = mm[:half_m, :]
lower = mm[half_m:, :]
pypto.set_vec_tile_shapes(64, 64)
out0_tensor[:, :] = upper + 1.0
pypto.set_vec_tile_shapes(64, 64)
out1_tensor[:, :] = lower + 2.0
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_split_m():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
m, k, n = 128, 64, 64
half_m = m // 2
torch.manual_seed(0)
a = torch.rand([m, k], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([n, k], dtype=torch.float16, device=f"npu:{device_id}")
out0 = torch.zeros([half_m, n], dtype=torch.float32, device=f"npu:{device_id}")
out1 = torch.zeros([half_m, n], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_split_m_kernel(a, b, out0, out1)
mm_golden = torch.matmul(a.to(torch.float32), b.to(torch.float32).T)
golden0 = mm_golden[:half_m, :] + 1.0
golden1 = mm_golden[half_m:, :] + 2.0
assert_allclose(out0.cpu().to(torch.float32).numpy(), golden0.cpu().numpy(),
rtol=5e-3, atol=5e-3)
assert_allclose(out1.cpu().to(torch.float32).numpy(), golden1.cpu().numpy(),
rtol=5e-3, atol=5e-3)
@_DUALDST_JIT
def dual_dst_chained_ops_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out0_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out1_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""SplitN 同形状,但下游每条 chain 由 add+mul 两 vector op 组成.
验证 dualdst 融合后下游多步依赖链仍正确。
"""
upper, lower = _split_n_prologue_64(a_tensor, b_tensor)
tmp0 = upper + 1.0
pypto.set_vec_tile_shapes(64, 64)
out0_tensor[:, :] = tmp0 * 2.0
pypto.set_vec_tile_shapes(64, 64)
tmp1 = lower * 0.5
pypto.set_vec_tile_shapes(64, 64)
out1_tensor[:, :] = tmp1 + 3.0
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_chained_ops():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
m, k, n = 64, 64, 128
half_n = n // 2
torch.manual_seed(0)
a = torch.rand([m, k], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([n, k], dtype=torch.float16, device=f"npu:{device_id}")
out0 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
out1 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_chained_ops_kernel(a, b, out0, out1)
mm_golden = torch.matmul(a.to(torch.float32), b.to(torch.float32).T)
golden0 = (mm_golden[:, :half_n] + 1.0) * 2.0
golden1 = (mm_golden[:, half_n:] * 0.5) + 3.0
assert_allclose(out0.cpu().to(torch.float32).numpy(), golden0.cpu().numpy(),
rtol=5e-3, atol=5e-3)
assert_allclose(out1.cpu().to(torch.float32).numpy(), golden1.cpu().numpy(),
rtol=5e-3, atol=5e-3)
@_DUALDST_JIT
def dual_dst_asymmetric_scale_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out0_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out1_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""SplitN, 两侧 vector chain 用不同 op (add vs mul).
验证两个 AIV core 实际执行不同语义但共享同一份 L0C 数据。
"""
upper, lower = _split_n_prologue_64(a_tensor, b_tensor)
out0_tensor[:, :] = upper + 7.5
pypto.set_vec_tile_shapes(64, 64)
out1_tensor[:, :] = lower * 1.25
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_asymmetric_scale():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
m, k, n = 64, 64, 128
half_n = n // 2
torch.manual_seed(0)
a = torch.rand([m, k], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([n, k], dtype=torch.float16, device=f"npu:{device_id}")
out0 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
out1 = torch.zeros([m, half_n], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_asymmetric_scale_kernel(a, b, out0, out1)
mm_golden = torch.matmul(a.to(torch.float32), b.to(torch.float32).T)
golden0 = mm_golden[:, :half_n] + 7.5
golden1 = mm_golden[:, half_n:] * 1.25
assert_allclose(out0.cpu().to(torch.float32).numpy(), golden0.cpu().numpy(),
rtol=5e-3, atol=5e-3)
assert_allclose(out1.cpu().to(torch.float32).numpy(), golden1.cpu().numpy(),
rtol=5e-3, atol=5e-3)
N_PAIRS = 8
GAIN_M = 128
GAIN_K = 32
GAIN_HALF_N = 64
GAIN_N = 2 * GAIN_HALF_N * N_PAIRS
@_DUALDST_JIT
def dual_dst_max_gain_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""单次 matmul + N 轴 2*N_PAIRS 等分, 偶/奇下标分别落 AIV0/AIV1, 相邻 (偶,奇) 被 dualdst 合并.
mm shape: [GAIN_M, GAIN_N]; 每个子块宽 GAIN_HALF_N。
偶下标 +1.0 (consumer 落 AIV0), 奇下标 +2.0 (落 AIV1);
相邻 (偶,奇) 子块被 dualdst 合并, 累计 N_PAIRS 对融合。
所有切片边界都是 python int 常量。
"""
pypto.set_cube_tile_shapes([GAIN_M, GAIN_M], [GAIN_K, GAIN_K], [2 * GAIN_HALF_N, 2 * GAIN_HALF_N])
mm = pypto.matmul(a_tensor, b_tensor, b_trans=True, out_dtype=FP32)
for j in range(2 * N_PAIRS):
lo = j * GAIN_HALF_N
hi = (j + 1) * GAIN_HALF_N
seg = mm[:, lo:hi]
pypto.set_vec_tile_shapes(GAIN_M, GAIN_HALF_N)
out_tensor[:, lo:hi] = seg + (1.0 if j % 2 == 0 else 2.0)
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_max_gain():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
torch.manual_seed(0)
a = torch.rand([GAIN_M, GAIN_K], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([GAIN_N, GAIN_K], dtype=torch.float16, device=f"npu:{device_id}")
out = torch.zeros([GAIN_M, GAIN_N], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_max_gain_kernel(a, b, out)
mm_golden = torch.matmul(a.to(torch.float32), b.to(torch.float32).T)
out_golden = mm_golden.clone()
for j in range(2 * N_PAIRS):
lo = j * GAIN_HALF_N
hi = (j + 1) * GAIN_HALF_N
out_golden[:, lo:hi] = mm_golden[:, lo:hi] + (1.0 if j % 2 == 0 else 2.0)
assert_allclose(out.cpu().to(torch.float32).numpy(), out_golden.cpu().numpy(),
rtol=5e-3, atol=5e-3)
N_MM = 12
MM_M = 64
MM_K = 64
MM_N = 128
MM_HALF_N = MM_N // 2
ADDS_CHAIN_DEPTH = 4
@_DUALDST_JIT
def dual_dst_long_chain_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP16),
out0_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out1_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""python 层展开 N_MM 次独立 matmul + SplitN, 每次产 1 对 dualdst pair.
每条 chain 跟 ADDS_CHAIN_DEPTH 步 ADDS 形成长依赖链。
"""
pypto.set_cube_tile_shapes([MM_M, MM_M], [MM_K, MM_K], [MM_N, MM_N])
upper_scalars = [0.5, 1.0, 1.5, 2.0]
lower_scalars = [0.25, 0.75, 1.25, 1.75]
for i in range(N_MM):
a_i = a_tensor[i * MM_M:(i + 1) * MM_M, :]
b_i = b_tensor[i * MM_N:(i + 1) * MM_N, :]
upper, lower = _matmul_split_n(a_i, b_i, MM_HALF_N)
cur = upper
for k, s in enumerate(upper_scalars):
pypto.set_vec_tile_shapes(MM_M, MM_HALF_N)
if k == len(upper_scalars) - 1:
out0_tensor[i * MM_M:(i + 1) * MM_M, :] = cur + s
else:
cur = cur + s
cur = lower
for k, s in enumerate(lower_scalars):
pypto.set_vec_tile_shapes(MM_M, MM_HALF_N)
if k == len(lower_scalars) - 1:
out1_tensor[i * MM_M:(i + 1) * MM_M, :] = cur + s
else:
cur = cur + s
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_long_chain():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
m_tot = N_MM * MM_M
n_tot = N_MM * MM_N
upper_sum = 0.5 + 1.0 + 1.5 + 2.0
lower_sum = 0.25 + 0.75 + 1.25 + 1.75
torch.manual_seed(0)
a = torch.rand([m_tot, MM_K], dtype=torch.float16, device=f"npu:{device_id}")
b = torch.rand([n_tot, MM_K], dtype=torch.float16, device=f"npu:{device_id}")
out0 = torch.zeros([m_tot, MM_HALF_N], dtype=torch.float32, device=f"npu:{device_id}")
out1 = torch.zeros([m_tot, MM_HALF_N], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_long_chain_kernel(a, b, out0, out1)
out0_golden = torch.zeros([m_tot, MM_HALF_N], dtype=torch.float32)
out1_golden = torch.zeros([m_tot, MM_HALF_N], dtype=torch.float32)
for i in range(N_MM):
a_i = a[i * MM_M:(i + 1) * MM_M, :].to(torch.float32)
b_i = b[i * MM_N:(i + 1) * MM_N, :].to(torch.float32)
mm_i = torch.matmul(a_i, b_i.T)
out0_golden[i * MM_M:(i + 1) * MM_M, :] = mm_i[:, :MM_HALF_N] + upper_sum
out1_golden[i * MM_M:(i + 1) * MM_M, :] = mm_i[:, MM_HALF_N:] + lower_sum
assert_allclose(out0.cpu().to(torch.float32).numpy(), out0_golden.numpy(),
rtol=1e-2, atol=1e-2)
assert_allclose(out1.cpu().to(torch.float32).numpy(), out1_golden.numpy(),
rtol=1e-2, atol=1e-2)
N_LINK = 12
LINK_M = 64
LINK_K = 32
LINK_N = 64
LINK_HALF_N = LINK_N // 2
@_DUALDST_JIT
def dual_dst_link_chain_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out_a: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out_b: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""matmul,ADDS,matmul,ADDS 一条数据流上 N_LINK 步串接, 每步 1 对 dualdst pair.
全程 FP32 (dualdst op 不支持随路 dtype cast)。
中间 N_LINK-1 步的两路 (AIV0 upper / AIV1 lower) 落点用 pypto.tensor() 在
kernel 内分配独立 workspace LogicalTensor —— 不暴露成 kernel 参数:
- ws_a_i: AIV0 chain (upper + 0.5) 落点, 同时作为下一步 mm 的 A 输入
- ws_b_i: AIV1 chain (lower + 0.25) 落点 (side output, 不传播)
每个 ws 各自独立, 避开 "同一 GM tensor 多 slice 读写" 的 pypto IR 环路;
最后一步直接写 out_a / out_b。
"""
pypto.set_cube_tile_shapes([LINK_M, LINK_M], [LINK_K, LINK_K], [LINK_N, LINK_N])
ws_a_list = []
ws_b_list = []
for ws_idx in range(N_LINK - 1):
ws_a_list.append(pypto.tensor([LINK_M, LINK_K], FP32, "ws_a_" + str(ws_idx)))
ws_b_list.append(pypto.tensor([LINK_M, LINK_K], FP32, "ws_b_" + str(ws_idx)))
for i in range(N_LINK):
a_i = a_tensor if i == 0 else ws_a_list[i - 1]
b_i = b_tensor[i * LINK_N:(i + 1) * LINK_N, :]
upper, lower = _matmul_split_n(a_i, b_i, LINK_HALF_N)
dst_a = out_a if i == N_LINK - 1 else ws_a_list[i]
dst_b = out_b if i == N_LINK - 1 else ws_b_list[i]
pypto.set_vec_tile_shapes(LINK_M, LINK_HALF_N)
dst_a[:, :] = upper + 0.5
pypto.set_vec_tile_shapes(LINK_M, LINK_HALF_N)
dst_b[:, :] = lower + 0.25
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_link_chain():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
n_tot_b = N_LINK * LINK_N
torch.manual_seed(0)
a = torch.rand([LINK_M, LINK_K], dtype=torch.float32, device=f"npu:{device_id}")
b = torch.rand([n_tot_b, LINK_K], dtype=torch.float32, device=f"npu:{device_id}")
out_a = torch.zeros([LINK_M, LINK_K], dtype=torch.float32, device=f"npu:{device_id}")
out_b = torch.zeros([LINK_M, LINK_K], dtype=torch.float32, device=f"npu:{device_id}")
dual_dst_link_chain_kernel(a, b, out_a, out_b)
prev = a.cpu()
last_upper = None
last_lower = None
for i in range(N_LINK):
b_i = b[i * LINK_N:(i + 1) * LINK_N, :].cpu()
mm = torch.matmul(prev, b_i.T)
last_upper = mm[:, :LINK_HALF_N] + 0.5
last_lower = mm[:, LINK_HALF_N:] + 0.25
prev = last_upper
assert_allclose(out_a.cpu().numpy(), last_upper.numpy(), rtol=1e-3, atol=1e-3)
assert_allclose(out_b.cpu().numpy(), last_lower.numpy(), rtol=1e-3, atol=1e-3)
"""
Mega-scale ST for OP_L0C_COPY_UB_DUAL_DST fusion:
并行 MEGA_N_BRANCH 条分支, 每条分支内串行 MEGA_N_PER_BRANCH 对 dualdst。
Default: 50 分支 × 10 对/分支 = 500 对 dualdst pair。
Bump MEGA_N_BRANCH to 100 -> 100 × 10 = 1000 对。
为什么不真的"数据流串接" (link_chain 风格):
严格 prev->next 串接需要每对独立 workspace 张量, 数量 = N_BRANCH × (N_LINK-1) × 2,
50×9×2 = 900 个 ws 张量参数, pypto frontend 无法 handle 这种 kernel 签名。
退而求其次, 这里串行的体现是 IR op 顺序: 每条分支的 10 个 mm 在 IR 上聚在
一起 (外循环 s 在外, 内循环 i 在内 trace 出 branch_0 的 mm_0..mm_9 接 branch_1
的 mm_0..mm_9...), 形态等价于:
mm_b0_0 -> dual_dst_0 -> ADDS_u/ADDS_l
mm_b0_1 -> dual_dst_1 -> ADDS_u/ADDS_l
...
mm_b0_9 -> dual_dst_9 -> ADDS_u/ADDS_l
mm_b1_0 -> ... (并行分支)
数据流上每个 mm 独立 a/b 输入 (避免 ws 张量爆炸), OoOSchedule 仍能让不同分支
的 cube/vec 阶段并行交叠。
并行性:
MEGA_N_BRANCH 条分支无任何数据依赖, OoOSchedule 可以让 branch_0 的 mm
在 cube 上跑时, branch_1/2/...的下游 vec/ADDS (AIV0/AIV1) 并行执行。
分支越多, cube/vec 流水越饱和, 总时延越接近 max(单分支时延)。
输出布局:
out_a / out_b 沿 M 轴拼 (out_idx = (s * MEGA_N_PER_BRANCH + i) * MEGA_M),
long_chain 用例已验证 M 轴 multi-slice 写入 OK; 避开 pypto N-axis multi-slice
写 toOffset 算成 [0, -HALF] 的已知 bug。
资源 (默认 500 对):
a_tensor [50*10*64, 64] FP32 = 8 MB
b_tensor [50*10*128, 64] FP32 = 16 MB
out_a_tensor [50*10*64, 64] FP32 = 8 MB
out_b_tensor [50*10*64, 64] FP32 = 8 MB
总 GM ≈ 40 MB
1000 对场景翻倍 ≈ 80 MB GM
编译时间提醒:
500 对 -> IR 中约 2500 个 op (500 mm + 1000 L0C_COPY_UB + 1000 ADDS),
pypto trace + pass 可能耗时 2-5 分钟; 1000 对约 5-10 分钟, 属正常。
Run:
pytest python/tests/st/test_dualdst_mega.py -v
"""
MEGA_N_BRANCH = 50
MEGA_N_PER_BRANCH = 10
MEGA_M = 64
MEGA_K = 64
MEGA_HALF_N = 64
MEGA_N = 2 * MEGA_HALF_N
MEGA_TOTAL_MM = MEGA_N_BRANCH * MEGA_N_PER_BRANCH
MEGA_A_M = MEGA_TOTAL_MM * MEGA_M
MEGA_B_N = MEGA_TOTAL_MM * MEGA_N
MEGA_OUT_M = MEGA_TOTAL_MM * MEGA_M
@_DUALDST_JIT
def dual_dst_mega_kernel(
a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out_a_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
out_b_tensor: pypto.Tensor([pypto.STATIC, pypto.STATIC], FP32),
):
"""MEGA_N_BRANCH 条独立并行分支, 每条分支内串行 MEGA_N_PER_BRANCH 对 dualdst.
a_tensor: [MEGA_TOTAL_MM * MEGA_M, MEGA_K] 所有 mm 的 A 沿 M 轴拼
b_tensor: [MEGA_TOTAL_MM * MEGA_N, MEGA_K] 所有 mm 的 B 沿 N 轴拼
out_a_tensor: [MEGA_TOTAL_MM * MEGA_M, MEGA_HALF_N] AIV0 chain 输出沿 M 拼
out_b_tensor: [MEGA_TOTAL_MM * MEGA_M, MEGA_HALF_N] AIV1 chain 输出沿 M 拼
"""
pypto.set_cube_tile_shapes(
[MEGA_M, MEGA_M], [MEGA_K, MEGA_K], [MEGA_N, MEGA_N])
for s in range(MEGA_N_BRANCH):
for i in range(MEGA_N_PER_BRANCH):
global_idx = s * MEGA_N_PER_BRANCH + i
a_i = a_tensor[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :]
b_i = b_tensor[global_idx * MEGA_N:(global_idx + 1) * MEGA_N, :]
upper, lower = _matmul_split_n(a_i, b_i, MEGA_HALF_N)
pypto.set_vec_tile_shapes(MEGA_M, MEGA_HALF_N)
out_a_tensor[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :] = upper + 1.0
pypto.set_vec_tile_shapes(MEGA_M, MEGA_HALF_N)
out_b_tensor[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :] = lower + 2.0
@pytest.mark.soc("950")
@pytest.mark.skip(reason="large test case")
def test_dual_dst_mega():
device_id = int(os.environ.get('TILE_FWK_DEVICE_ID', 0))
torch.npu.set_device(device_id)
torch.manual_seed(0)
a = torch.rand([MEGA_A_M, MEGA_K], dtype=torch.float32,
device=f"npu:{device_id}")
b = torch.rand([MEGA_B_N, MEGA_K], dtype=torch.float32,
device=f"npu:{device_id}")
out_a = torch.zeros([MEGA_OUT_M, MEGA_HALF_N], dtype=torch.float32,
device=f"npu:{device_id}")
out_b = torch.zeros([MEGA_OUT_M, MEGA_HALF_N], dtype=torch.float32,
device=f"npu:{device_id}")
dual_dst_mega_kernel(a, b, out_a, out_b)
out_a_golden = torch.zeros([MEGA_OUT_M, MEGA_HALF_N], dtype=torch.float32)
out_b_golden = torch.zeros([MEGA_OUT_M, MEGA_HALF_N], dtype=torch.float32)
a_cpu = a.cpu()
b_cpu = b.cpu()
for s in range(MEGA_N_BRANCH):
for i in range(MEGA_N_PER_BRANCH):
global_idx = s * MEGA_N_PER_BRANCH + i
a_i = a_cpu[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :]
b_i = b_cpu[global_idx * MEGA_N:(global_idx + 1) * MEGA_N, :]
mm = torch.matmul(a_i, b_i.T)
out_a_golden[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :] = \
mm[:, :MEGA_HALF_N] + 1.0
out_b_golden[global_idx * MEGA_M:(global_idx + 1) * MEGA_M, :] = \
mm[:, MEGA_HALF_N:] + 2.0
assert_allclose(out_a.cpu().numpy(), out_a_golden.numpy(), rtol=1e-3, atol=1e-3)
assert_allclose(out_b.cpu().numpy(), out_b_golden.numpy(), rtol=1e-3, atol=1e-3)
def main():
test_dual_dst_split_n()
test_dual_dst_split_m()
test_dual_dst_chained_ops()
test_dual_dst_asymmetric_scale()
test_dual_dst_max_gain()
test_dual_dst_long_chain()
test_dual_dst_link_chain()
test_dual_dst_mega()
if __name__ == "__main__":
main()