import os
import unittest
import torch
import torch.nn.functional as F
from mindiesd.compilation import MindieSDBackend, CompilationConfig
class _SingleIOModel(torch.nn.Module):
def forward(self, x):
return x * 2 + 1
class _MultiOutputModel(torch.nn.Module):
def forward(self, x):
return x * 2, x + 1
class _MixedInputModel(torch.nn.Module):
def forward(self, x, scale):
return x * scale
class _VariableLenModel(torch.nn.Module):
"""Model accepting variable-length 2D input [L, D] where L varies."""
def __init__(self, dim):
super().__init__()
self.fc = torch.nn.Linear(dim, dim)
def forward(self, x):
return self.fc(x)
@unittest.skipIf(
os.environ.get("MINDIE_TEST_MODE", "ALL") == "CPU",
"Skip NPU-dependent tests when MINDIE_TEST_MODE is CPU.",
)
class TestAclGraphBackend(unittest.TestCase):
def setUp(self):
self._orig_aclgraph_only = CompilationConfig.aclgraph_only
self._orig_aclgraph_with_compile = CompilationConfig.aclgraph_with_compile
def tearDown(self):
CompilationConfig.aclgraph_only = self._orig_aclgraph_only
CompilationConfig.aclgraph_with_compile = self._orig_aclgraph_with_compile
def test_aclgraph_only_output_consistent(self):
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
model = _SingleIOModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
output_compiled = compiled(x)
output_original = model(x)
self.assertTrue(
torch.allclose(output_original, output_compiled),
"ACLGraph-only output mismatch",
)
def test_aclgraph_with_compile_output_consistent(self):
CompilationConfig.aclgraph_only = False
CompilationConfig.aclgraph_with_compile = True
model = _SingleIOModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
output_compiled = compiled(x)
output_original = model(x)
self.assertTrue(
torch.allclose(output_original, output_compiled),
"ACLGraph-with-compile output mismatch",
)
def test_both_modes_output_agree(self):
model = _SingleIOModel()
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
compiled_only = torch.compile(model, backend=MindieSDBackend())
out_only = compiled_only(x)
CompilationConfig.aclgraph_only = False
CompilationConfig.aclgraph_with_compile = True
compiled_with = torch.compile(model, backend=MindieSDBackend())
out_with = compiled_with(x)
self.assertTrue(
torch.allclose(out_only, out_with),
"Two ACLGraph modes should produce same output",
)
def test_same_shape_reuses_cached_graph(self):
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
model = _SingleIOModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
out1 = compiled(x)
self.assertTrue(torch.allclose(model(x), out1))
out2 = compiled(x)
self.assertTrue(torch.allclose(model(x), out2))
def test_different_shape_triggers_recapture(self):
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
model = _SingleIOModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x1 = torch.randn(4, 4, dtype=torch.float32, device="npu")
out1 = compiled(x1)
self.assertTrue(torch.allclose(model(x1), out1))
x2 = torch.randn(8, 8, dtype=torch.float32, device="npu")
out2 = compiled(x2)
self.assertTrue(torch.allclose(model(x2), out2))
def test_multiple_output_tensors(self):
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
model = _MultiOutputModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
out1, out2 = compiled(x)
ref1, ref2 = model(x)
self.assertTrue(torch.allclose(ref1, out1), "Output 1 mismatch")
self.assertTrue(torch.allclose(ref2, out2), "Output 2 mismatch")
def test_mixed_scalar_tensor_input(self):
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
model = _MixedInputModel()
compiled = torch.compile(model, backend=MindieSDBackend())
x = torch.randn(4, 4, dtype=torch.float32, device="npu")
scale = 2.5
output_compiled = compiled(x, scale)
output_original = model(x, scale)
self.assertTrue(
torch.allclose(output_original, output_compiled),
"Mixed scalar/tensor input output mismatch",
)
def test_padding_shorter_input_output_correct(self):
max_len, dim = 64, 32
model = _VariableLenModel(dim).to("npu")
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
compiled = torch.compile(model, backend=MindieSDBackend())
full_inp = torch.randn(max_len, dim, dtype=torch.float32, device="npu")
_ = compiled(full_inp)
short_len = 20
short_inp = torch.randn(short_len, dim, dtype=torch.float32, device="npu")
padding = (0, 0, 0, max_len - short_len)
padded = F.pad(short_inp, padding)
output_padded = compiled(padded)
output_actual = output_padded[:short_len]
golden = model(short_inp)
self.assertTrue(
torch.allclose(golden, output_actual),
"Padded-path output should match golden after slicing",
)
def test_padding_equal_input_no_extra_copy(self):
max_len, dim = 64, 32
model = _VariableLenModel(dim).to("npu")
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
compiled = torch.compile(model, backend=MindieSDBackend())
full_inp = torch.randn(max_len, dim, dtype=torch.float32, device="npu")
out_captured = compiled(full_inp)
self.assertTrue(torch.allclose(model(full_inp), out_captured))
another = torch.randn(max_len, dim, dtype=torch.float32, device="npu")
out = compiled(another)
self.assertTrue(torch.allclose(model(another), out))
def test_padding_longer_input_triggers_recapture(self):
"""Longer input triggers re-capture rather than a crash."""
capture_len, dim = 32, 16
model = _VariableLenModel(dim).to("npu")
CompilationConfig.aclgraph_only = True
CompilationConfig.aclgraph_with_compile = False
compiled = torch.compile(model, backend=MindieSDBackend())
inp = torch.randn(capture_len, dim, dtype=torch.float32, device="npu")
compiled(inp)
longer = torch.randn(capture_len * 2, dim, dtype=torch.float32, device="npu")
longer_out = compiled(longer)
self.assertEqual(
longer_out.shape[0], capture_len * 2, "Re-captured output should match longer input's first dim"
)
golden = model(longer)
self.assertTrue(
torch.allclose(golden, longer_out, atol=1e-5),
"Re-captured output should match golden",
)
def test_teardown_restores_config(self):
self.assertEqual(
(self._orig_aclgraph_only, self._orig_aclgraph_with_compile),
(False, False),
"Original config should default to (False, False)",
)
if __name__ == "__main__":
unittest.main()