import ast
import os
import shutil
import unittest
from pathlib import Path
from typing import Dict, List, Tuple
# These defaults are installed before ``import torch`` / ``import inductor_npu_ext``
# below, so they are already present in the environment when those modules load.
# ``setdefault`` keeps any value the caller has already exported.
# NOTE(review): presumably TORCH_COMPILE_DEBUG=1 makes inductor dump into
# ./torch_compile_debug (the directory this file scans for asc_graph.py), and
# TORCHINDUCTOR_NPU_EXT_DEBUG=cpu selects a host-only mode -- confirm.
os.environ.setdefault("TORCH_COMPILE_DEBUG", "1")
os.environ.setdefault("TORCHINDUCTOR_NPU_EXT_DEBUG", "cpu")
os.environ.setdefault("TORCHINDUCTOR_FORCE_DISABLE_CACHES", "1")
import torch
from torch._inductor import config
import inductor_npu_ext
# Optional hook: only call the stub installer when this build of
# inductor_npu_ext actually exposes it.
# NOTE(review): presumably replaces device-only pieces so the tests can run
# on a host without an NPU -- confirm against inductor_npu_ext.
if hasattr(inductor_npu_ext, "_stub_debugging_host_only"):
    inductor_npu_ext._stub_debugging_host_only()
_BINARY_OPS = {"Add", "Sub", "Mul", "Div", "TrueDiv", "FloorDiv", "Mod", "Maximum",
"Minimum", "Pow", "Ge", "Gt", "Le", "Lt", "Eq", "Ne",
"BitwiseAnd", "BitwiseOr", "BitwiseXor", "LogicalAnd", "LogicalOr"}
def _eval_int_list(node):
if isinstance(node, ast.List):
out = []
for e in node.elts:
if isinstance(e, ast.Constant):
out.append(e.value)
elif isinstance(e, ast.UnaryOp) and isinstance(e.op, ast.USub) and isinstance(e.operand, ast.Constant):
out.append(-e.operand.value)
elif isinstance(e, ast.Name):
out.append(e.id)
else:
out.append(ast.unparse(e))
return out
return None
def _parse_asc_graph(path: Path) -> Dict[str, dict]:
src = path.read_text()
tree = ast.parse(src)
ops: Dict[str, dict] = {}
for node in ast.walk(tree):
if not isinstance(node, ast.Assign) or len(node.targets) != 1:
continue
target = node.targets[0]
if isinstance(target, ast.Name) and isinstance(node.value, ast.Call):
func = node.value.func
if (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute)
and func.value.attr == "ops"):
ops[target.id] = {"type": func.attr, "inputs": {}, "size": None, "strides": None, "axis": None}
continue
if not isinstance(target, ast.Attribute):
continue
chain = []
cur = target
while isinstance(cur, ast.Attribute):
chain.append(cur.attr)
cur = cur.value
if not isinstance(cur, ast.Name):
continue
op_name = cur.id
chain.reverse()
if op_name not in ops:
continue
if len(chain) == 1 and chain[0] in ("x", "x1", "x2", "x3"):
val = node.value
if isinstance(val, ast.Attribute) and val.attr == "y" and isinstance(val.value, ast.Name):
ops[op_name]["inputs"][chain[0]] = val.value.id
continue
if len(chain) == 2 and chain[0] == "y":
if chain[1] == "size":
ops[op_name]["size"] = _eval_int_list(node.value)
elif chain[1] == "strides":
ops[op_name]["strides"] = _eval_int_list(node.value)
elif chain[1] == "axis":
ops[op_name]["axis"] = _eval_int_list(node.value)
return ops
def _collect_asc_graphs() -> List[Tuple[Path, Dict[str, dict]]]:
root = Path.cwd() / "torch_compile_debug"
if not root.exists():
return []
return [(p, _parse_asc_graph(p)) for p in sorted(root.rglob("asc_graph.py"))]
def _contig_stride(size):
"""跟 DenseLoop.contiguous_stride 对齐:size=1 维 stride=0;其余维 stride=
product(后面所有维 size)。"""
stride = []
mult = 1
for s in reversed(size):
if isinstance(s, int) and s == 1:
stride.append(0)
else:
stride.append(mult)
if isinstance(s, int):
mult *= s
stride.reverse()
return stride
class TestInductorNpuExt(unittest.TestCase):
    """Checks on the ``asc_graph.py`` / ``benchmark.py`` files produced when
    functions are run through ``torch.compile`` in this test environment.

    Each test compiles a small function, then statically inspects the files
    found under ``./torch_compile_debug`` (see ``_collect_asc_graphs``).
    """

    @staticmethod
    def _remove_debug_dirs() -> None:
        """Delete the per-run output directories under the current directory."""
        for d in (Path.cwd() / "torch_compile_debug", Path.cwd() / ".npu_kernels_root"):
            if d.exists():
                shutil.rmtree(d)

    def setUp(self) -> None:
        """Reset dynamo, enable inductor tracing and start from clean dirs."""
        import torch._dynamo
        torch._dynamo.reset()
        # Restore the caller's trace setting afterwards so this module does
        # not leak a changed global config into other test modules.
        previous_trace_enabled = config.trace.enabled
        self.addCleanup(setattr, config.trace, "enabled", previous_trace_enabled)
        config.trace.enabled = True
        self._remove_debug_dirs()

    def tearDown(self) -> None:
        """Remove everything the test generated."""
        self._remove_debug_dirs()

    def _assert_contig_non_load(self, graphs):
        """Every node other than Load / Data / Output / Workspace / Scalar
        must have ``strides == _contig_stride(size)``.

        Nodes without a recorded size or strides are skipped.
        """
        issues = []
        for path, ops in graphs:
            for name, info in ops.items():
                if info["type"] in ("Load", "Data", "Output", "Workspace", "Scalar"):
                    continue
                size, stride = info["size"], info["strides"]
                if size is None or stride is None:
                    continue
                expect = _contig_stride(size)
                if list(stride) != list(expect):
                    issues.append(f"[{path.parent.name}] {name}({info['type']}) "
                                  f"size={size} stride={stride} != contig {expect}")
        self.assertEqual(issues, [], "non-Load nodes 必须 contiguous_stride:\n "
                         + "\n ".join(issues))

    def _assert_binary_axis_consistent(self, graphs):
        """For binary ops, the axis lists of the x1 and x2 producers must be
        identical, including order.

        Only checked when both producers are known ops with a recorded axis.
        """
        issues = []
        for path, ops in graphs:
            for name, info in ops.items():
                if info["type"] not in _BINARY_OPS:
                    continue
                src_axes = []
                for slot in ("x1", "x2"):
                    src_name = info["inputs"].get(slot)
                    if src_name and src_name in ops and ops[src_name]["axis"] is not None:
                        src_axes.append((slot, src_name, ops[src_name]["axis"]))
                if len(src_axes) >= 2 and src_axes[0][2] != src_axes[1][2]:
                    issues.append(f"[{path.parent.name}] {name}({info['type']}) "
                                  f"两路输入 axis 不一致: "
                                  f"{src_axes[0][0]}<-{src_axes[0][1]} {src_axes[0][2]} vs "
                                  f"{src_axes[1][0]}<-{src_axes[1][1]} {src_axes[1][2]}")
        self.assertEqual(issues, [], "二元 op 输入 axis 必须一致:\n "
                         + "\n ".join(issues))

    def _assert_has_op(self, graphs, op_type, hint=""):
        """Fail unless at least one collected graph contains an op of *op_type*.

        The failure message lists the op types found per kernel directory.
        """
        for _, ops in graphs:
            if any(o["type"] == op_type for o in ops.values()):
                return
        kinds = [(p.parent.name, sorted({o["type"] for o in ops.values()})) for p, ops in graphs]
        self.fail(f"asc_graph 未找到 {op_type} 节点 ({hint});现有 kernels: {kinds}")

    def _run_and_collect(self, func, *args):
        """Call ``func(*args)`` under ``no_grad`` and return the parsed graphs.

        Fails the test when no ``asc_graph.py`` was generated.
        """
        with torch.no_grad():
            func(*args)
        graphs = _collect_asc_graphs()
        self.assertGreater(len(graphs), 0, "未生成任何 asc_graph.py")
        return graphs

    def test_add(self):
        """Smoke test: a compiled element-wise add runs without raising."""
        try:
            @torch.compile
            def func(x, y):
                return x + y
            x = torch.randn(2)
            y = torch.randn(2)
            func(x, y)
        except Exception as e:
            self.fail(f"test_add raised an exception: {e}")

    def test_benchmark_generation(self):
        """A ``benchmark.py`` is generated and has the expected two modes."""
        @torch.compile
        def func(x, y):
            return x + y
        x = torch.randn(2, 2)
        y = torch.randn(2, 2)
        func(x, y)
        benchmark_files = list(Path.cwd().rglob("benchmark.py"))
        self.assertGreater(len(benchmark_files), 0, "Should generate benchmark.py")
        benchmark_path = benchmark_files[0]
        # Remove the file even when an assertion below fails. tearDown runs
        # before cleanups and may already have deleted it, hence missing_ok.
        self.addCleanup(benchmark_path.unlink, missing_ok=True)
        content = benchmark_path.read_text()
        e2e_marker = "if sys.argv[-1] == 'e2e':"
        required_elements = [
            "import sys",
            "import torch",
            "import torch_npu",
            "async_compile_ascendc",
            "torch_npu.profiler.profile",
            "if __name__ == '__main__':",
            e2e_marker,
            "else:",
        ]
        for element in required_elements:
            self.assertIn(element, content, f"benchmark.py should contain {element}")
        # Look for the matching ``else:`` only after the e2e branch header; an
        # unrelated ``else:`` earlier in the file would otherwise make the
        # e2e slice empty.
        e2e_start = content.find(e2e_marker)
        else_start = content.find("else:", e2e_start + len(e2e_marker))
        self.assertNotEqual(else_start, -1,
                            "benchmark.py should have an else: branch after the e2e branch")
        e2e_section = content[e2e_start:else_start]
        self.assertIn("tiling_def, host_impl, device_impl = fuser.codegen(", e2e_section, "e2e mode should open asc_graph.py")
        default_section = content[else_start:]
        self.assertIn("tiling_def", default_section, "Default mode should have tiling_def")
        self.assertIn("host_impl", default_section, "Default mode should have host_impl")
        self.assertIn("device_impl", default_section, "Default mode should have device_impl")
        self.assertNotIn("tiling_def, host_impl, device_impl = fuser.codegen(", default_section, "default mode should not open asc_graph.py")

    def test_view_pure_broadcast(self):
        """Single-dimension 1->N broadcast, no transpose."""
        @torch.compile
        def fn(x, y):
            return x + y
        graphs = self._run_and_collect(
            fn,
            torch.randn(1, 8, 16),
            torch.randn(4, 8, 16),
        )
        self._assert_has_op(graphs, "Broadcast", "pure_broadcast")
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def test_view_multi_dim_broadcast(self):
        """Several dimensions are broadcast 1->N at the same time."""
        @torch.compile
        def fn(x, y):
            return x * y
        graphs = self._run_and_collect(
            fn,
            torch.randn(1, 1, 5),
            torch.randn(3, 4, 5),
        )
        self._assert_has_op(graphs, "Broadcast", "multi_dim_broadcast")
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def test_view_pure_transpose(self):
        """A plain permute makes src.axis order differ from dst.axis (no
        broadcast)."""
        @torch.compile
        def fn(x, y):
            return x.permute(1, 0, 2) + y
        graphs = self._run_and_collect(
            fn,
            torch.randn(4, 8, 16),
            torch.randn(8, 4, 16),
        )
        self._assert_has_op(graphs, "Transpose", "pure_transpose")
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def test_view_transpose_then_broadcast(self):
        """Same pattern as main21: permute + broadcast in a single step.

        Historically the broadcast was dropped and the transpose output had
        non-contiguous strides; this case guards against that regression.
        """
        @torch.compile
        def fn(x, y):
            return y * x.permute(1, 0, 2, 3)
        graphs = self._run_and_collect(
            fn,
            torch.randn(64, 32, 5, 1),
            torch.randn(32, 64, 5, 56),
        )
        self._assert_has_op(graphs, "Transpose", "transpose_then_broadcast")
        self._assert_has_op(graphs, "Broadcast", "transpose_then_broadcast")
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def test_view_unsqueeze_broadcast(self):
        """unsqueeze (implicitly inserts a size-1 dim) + broadcast: a
        different path from transpose, but the Broadcast node must still be
        contiguous."""
        @torch.compile
        def fn(x, y):
            return x.unsqueeze(-1) + y
        graphs = self._run_and_collect(
            fn,
            torch.randn(4, 8),
            torch.randn(4, 8, 16),
        )
        self._assert_has_op(graphs, "Broadcast", "unsqueeze_broadcast")
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def _lower_and_check(self, fn, args, expect_op=None):
        """Compile and run *fn* from a clean state, then validate its graphs.

        Dynamo state and the output directories are reset first so that the
        collected graphs belong to this call only (this helper is invoked
        several times per test through ``subTest``).

        Args:
            fn: Function to wrap with ``torch.compile``.
            args: Positional arguments passed to the compiled function.
            expect_op: Op type that must appear in some graph, or ``None``
                to skip that check.
        """
        import torch._dynamo
        torch._dynamo.reset()
        self._remove_debug_dirs()
        graphs = self._run_and_collect(torch.compile(fn), *args)
        if expect_op is not None:
            self._assert_has_op(graphs, expect_op)
        self._assert_contig_non_load(graphs)
        self._assert_binary_axis_consistent(graphs)

    def test_lowering_pointwise(self):
        """Each whitelisted pointwise operator lowers to its matching ascir
        op."""
        r = torch.randn
        cases = [
            ("Add", lambda a, b: a + b, [r(8, 16), r(8, 16)]),
            ("Sub", lambda a, b: a - b, [r(8, 16), r(8, 16)]),
            ("Mul", lambda a, b: a * b, [r(8, 16), r(8, 16)]),
            ("TrueDiv", lambda a, b: a / (b.abs() + 1.0), [r(8, 16), r(8, 16)]),
            ("Pow", lambda a, b: a.abs() ** (b.abs() + 1.0), [r(8, 16), r(8, 16)]),
            ("Sqrt", lambda a: torch.sqrt(a.abs() + 1.0), [r(8, 16)]),
            ("Rsqrt", lambda a: torch.rsqrt(a.abs() + 1.0), [r(8, 16)]),
            ("Abs", lambda a: a.abs() + 1.0, [r(8, 16)]),
            ("Exp", lambda a: torch.exp(a * 0.1), [r(8, 16)]),
            ("Sigmoid", lambda a: torch.sigmoid(a), [r(8, 16)]),
            ("Relu", lambda a: torch.relu(a) + 1.0, [r(8, 16)]),
            ("Neg", lambda a: -a + 1.0, [r(8, 16)]),
            ("Sign", lambda a: torch.sgn(a) + 1.0, [r(8, 16)]),
            ("Log1p", lambda a: torch.log1p(a.abs() + 1.0), [r(8, 16)]),
            (None, lambda a: torch.nn.functional.silu(a), [r(8, 16)]),
            (None, lambda a, b: torch.remainder(a, b.abs() + 1.0), [r(8, 16), r(8, 16)]),
            (None, lambda a, b: torch.floor_divide(a.abs(), b.abs() + 1.0), [r(8, 16), r(8, 16)]),
        ]
        # The case index keeps the three "compound" sub-tests distinguishable
        # in failure reports.
        for idx, (expect_op, fn, args) in enumerate(cases):
            with self.subTest(case=idx, op=expect_op or "compound"):
                self._lower_and_check(fn, args, expect_op)

    def test_lowering_compare(self):
        """Comparison ops produce bool output.

        Guards two settings: support_out_dtypes allowing bool/uint8, and
        convert_element_type accepting bool input; if either is missing the
        op falls back to eager. ``.to(float32)`` converts the bool result
        back, mimicking typical usage.
        """
        r = torch.randn
        cases = [
            ("Ge", lambda a, b: (a >= b).to(torch.float32), [r(8, 16), r(8, 16)]),
            ("Le", lambda a, b: (a <= b).to(torch.float32), [r(8, 16), r(8, 16)]),
            ("Gt", lambda a, b: (a > b).to(torch.float32), [r(8, 16), r(8, 16)]),
            ("Lt", lambda a, b: (a < b).to(torch.float32), [r(8, 16), r(8, 16)]),
            ("Eq", lambda a, b: (a == b).to(torch.float32), [r(8, 16), r(8, 16)]),
            ("Ne", lambda a, b: (a != b).to(torch.float32), [r(8, 16), r(8, 16)]),
        ]
        for expect_op, fn, args in cases:
            with self.subTest(op=expect_op):
                self._lower_and_check(fn, args, expect_op)

    def test_lowering_reduce_and_convert(self):
        """Whitelisted reductions (sum/mean/max/min) and dtype conversion can
        be lowered."""
        r = torch.randn
        cases = [
            ("Sum", lambda a: a.sum(-1), [r(8, 16)]),
            ("Sum", lambda a: a.mean(-1), [r(8, 16)]),
            ("Max", lambda a: torch.max(a), [r(8, 16)]),
            ("Min", lambda a: torch.min(a), [r(8, 16)]),
            ("Cast", lambda a: (a + 1.0).to(torch.float16), [r(8, 16)]),
        ]
        # The case index separates the two "Sum" sub-tests (sum vs. mean).
        for idx, (expect_op, fn, args) in enumerate(cases):
            with self.subTest(case=idx, op=expect_op):
                self._lower_and_check(fn, args, expect_op)

    def test_lowering_view_ops(self):
        """Whitelisted view operators (permute/unsqueeze/squeeze/select/
        slice/expand/transpose) only fuse when combined with a compute op.

        Only the validity of the generated graph is checked: a view may be
        reinterpreted into the load's strides, so a standalone view node is
        not required.
        """
        r = torch.randn
        cases = [
            ("permute", lambda a, b: a.permute(1, 0) + b, [r(8, 16), r(16, 8)]),
            ("unsqueeze", lambda a, b: a.unsqueeze(-1) + b, [r(8, 16), r(8, 16, 4)]),
            ("squeeze", lambda a, b: a.squeeze(1) + b, [r(8, 1, 16), r(8, 16)]),
            ("select", lambda a, b: a.select(0, 0) + b, [r(4, 8, 16), r(8, 16)]),
            ("slice", lambda a, b: a[:, 1:3] + b, [r(8, 16), r(8, 2)]),
            ("expand", lambda a, b: a.expand(8, 16) * b, [r(1, 16), r(8, 16)]),
            ("transpose", lambda a, b: a.transpose(0, 1) + b, [r(8, 16), r(16, 8)]),
        ]
        for name, fn, args in cases:
            with self.subTest(view=name):
                self._lower_and_check(fn, args, expect_op=None)

    def test_soc_gating(self):
        """Guards the SoC gating logic of ``_LoweringGuard.support(since=...)``.

        Gating rule: registration is skipped (the op falls back on the
        current SoC) when current_soc is known, ``since`` is given and
        current_soc < since; otherwise the op is registered.

        The unit tests run in cpu mode (current_soc=None), where gating is
        inactive, so ``lowering.common.current_soc`` is patched directly to
        simulate each SoC tier. A probe op outside the whitelist
        (aten.atan) is used to observe registration, which avoids touching
        the real whitelist and needs no real device.
        """
        from inductor_npu_ext.lowering import common as lc
        from inductor_npu_ext.lowering.common import float_dtypes
        from inductor_npu_ext.common import Soc
        probe = torch.ops.aten.atan.default
        saved = lc.current_soc

        def reg(soc, since):
            # Start every attempt from "probe not registered".
            lc.current_soc = soc
            lc._LoweringGuard._data.pop(probe, None)
            lc._LoweringGuard.support(probe, float_dtypes(), since=since)
            return lc._LoweringGuard.has(probe)

        try:
            self.assertFalse(reg(Soc.A2, Soc.A5), "A2 < A5 应跳过 since=A5 注册")
            self.assertFalse(reg(Soc.A3, Soc.A5), "A3 < A5 应跳过")
            self.assertTrue(reg(Soc.A5, Soc.A5), "A5 >= A5 应注册")
            self.assertTrue(reg(Soc.A2, None), "since=None 不 gating,应注册")
            self.assertTrue(reg(None, Soc.A5), "current_soc=None 不 gating,应注册")
        finally:
            # Undo the patch and drop the probe registration.
            lc.current_soc = saved
            lc._LoweringGuard._data.pop(probe, None)
if __name__ == "__main__":
unittest.main()