import dataclasses
import logging
import os
import unittest
from typing import List
from pathlib import Path
import torch
import torch_npu
import torchair
from torch._subclasses.fake_tensor import FakeTensorMode
from torchair.configs.compiler_config import CompilerConfig
from torchair.core.utils import logger
# Module-level test configuration, executed once at import time.
# Emit Dynamo logs at INFO level to make compilation behaviour visible in test output.
torch._logging.set_logs(dynamo=logging.INFO)
# Seed the host RNG and the NPU RNGs so the torch.randn inputs below are reproducible.
torch.manual_seed(7)
torch.npu.manual_seed_all(7)
# Lower torchair's logger to DEBUG: several tests capture and inspect DEBUG messages
# through `self.assertLogs(logger, level="DEBUG")`.
logger.setLevel(logging.DEBUG)
def find_op(gm, op_default):
    """Report whether ``gm``'s FX graph calls ``op_default``.

    Args:
        gm: an object exposing ``graph.nodes`` (e.g. a ``torch.fx.GraphModule``).
        op_default: the call target to look for (e.g. an ``OpOverload``).

    Returns:
        True as soon as a ``call_function`` node targeting ``op_default`` is seen,
        otherwise False.
    """
    return any(
        node.op == "call_function" and node.target == op_default
        for node in gm.graph.nodes
    )
def create_optimize_wrapper(assert_func, original_func=None):
    """Wrap an FX optimizer so ``assert_func`` inspects every graph it optimizes.

    Args:
        assert_func: callable invoked with the graph module after the wrapped
            optimizer has run on it, before the optimizer's result is returned.
        original_func: optimizer to wrap. Defaults to the
            ``torchair.npu_fx_compiler._optimize_fx`` installed at call time.

    Returns:
        A drop-in replacement for the optimizer that returns its result unchanged.
    """
    import functools

    if original_func is None:
        original_func = torchair.npu_fx_compiler._optimize_fx

    @functools.wraps(original_func)
    def wrapper(gm, *args, **kwargs):
        # Pass every extra positional/keyword argument through untouched so the
        # wrapper keeps working if the optimizer's signature grows beyond
        # (gm, config, observer).
        ret = original_func(gm, *args, **kwargs)
        assert_func(gm)
        return ret

    return wrapper
class AclgraphTest(unittest.TestCase):
def setUp(self) -> None:
    """Prepare shared torchair state before every test.

    Remembers the current ``_optimize_fx`` so ``tearDown`` can undo any test's
    monkeypatch. Calls ``patch_dynamo()`` (presumably defined elsewhere in this
    file; confirm) when ``torch.npu`` lacks ``fake_record_stream``. Resets the
    stream/event graph counter to -1.
    """
    fx_compiler = torchair.npu_fx_compiler
    self.optimize_fx_bak = fx_compiler._optimize_fx

    needs_dynamo_patch = not hasattr(torch.npu, "fake_record_stream")
    if needs_dynamo_patch:
        patch_dynamo()

    graph_counter = torchair._acl_concrete_graph.replace_stream_event.GraphCounter
    graph_counter.set_graph_id(-1)
    super().setUp()
def tearDown(self) -> None:
    """Put back the ``_optimize_fx`` saved in ``setUp``, then run base-class teardown."""
    saved_optimizer = self.optimize_fx_bak
    if saved_optimizer is not None:
        torchair.npu_fx_compiler._optimize_fx = saved_optimizer
    super().tearDown()
def test_a_aclgraph_memory_state_setting(self):
    """A reduce-overhead, dynamic-shape compiled function must keep outputs on ``npu:1``.

    Two inputs of different leading size are run on the non-default device
    ``npu:1``; each result must stay on that device.
    """
    # The test pins tensors to npu:1, so it cannot run on single-device hosts.
    if torch.npu.device_count() < 2:
        self.skipTest("requires at least two NPU devices (uses npu:1)")

    def test_func(x):
        return x + x

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    test_func = torch.compile(test_func, backend=backend, dynamic=True)
    target_device = torch.device("npu:1")

    a = torch.randn(2, 3, device="npu:1")
    res = test_func(a)
    # assertEqual reports both devices on failure (assertTrue(a == b) does not).
    self.assertEqual(res.device, target_device)

    b = torch.randn(4, 3, device="npu:1")
    res = test_func(b)
    self.assertEqual(res.device, target_device)
def test_aclgraph_cache_with_static_kernel(self):
    """Build aclnn static-shape kernels on the first compile, then reuse them from the cache.

    The first ``cache_compile`` run must:
    - populate the cache directory with exactly one
      ``aclnn_static_shape_kernel_outputs/ts*_outputs`` directory holding at least
      one ``.run`` package;
    - call ``static_compile`` exactly once.

    A second, freshly constructed model must load from that cache ("Rebasing" in
    the DEBUG logs) without running static compilation again.
    """

    class CachedAclGraphModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=static_kernel_config)

        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            return self.cached_prompt(
                qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start
            )

        def prompt(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            return self._forward(qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start)

        def _forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            k = k * 1.1
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            # NOTE(review): narrow_start is accepted but the offset is hard-coded to 32.
            add3 = torch.narrow(add3, -1, 32, 32)
            res = add3 * mmm.mean()
            return res

    static_compile_call_count = 0

    def wrapped_static_compile(func):
        def wrapper(*args, **kwargs):
            nonlocal static_compile_call_count
            static_compile_call_count += 1
            return func(*args, **kwargs)
        return wrapper

    # Count static_compile invocations by patching the module attribute. The real
    # implementation is restored via addCleanup, even if an assertion fails, so the
    # counting wrapper cannot leak into later tests.
    static_kernel_module = torchair._acl_concrete_graph.static_kernel
    static_compile_bak = static_kernel_module.static_compile
    static_kernel_module.static_compile = wrapped_static_compile(static_compile_bak)
    self.addCleanup(setattr, static_kernel_module, "static_compile", static_compile_bak)

    static_kernel_config = CompilerConfig()
    static_kernel_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    static_kernel_config.debug.aclgraph.disable_reinplace_input_mutated_ops_pass = True
    static_kernel_config.mode = "reduce-overhead"
    static_kernel_config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
    static_kernel_config.experimental_config.aclgraph._aclnn_static_shape_kernel_build_dir = "./static_kernel"
    mm = CachedAclGraphModel()

    from torchair.core import _torchair
    _torchair.GetSocName()
    _torchair.AclopStartDumpArgs(1, "./static_kernel")
    _torchair.AclopStopDumpArgs(1)

    length = [28, 29, 1]
    length2 = [66, 88, 55]
    lengthq = [33, 44, 55]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    torch._dynamo.mark_static(query_prefill)
    torch._dynamo.mark_static(query)
    torch._dynamo.mark_static(key)
    torch._dynamo.mark_static(value)
    mmc = mm.npu()

    # Start from an empty cache so the first call must compile and populate it.
    from torchair.inference._cache_compiler import CompiledModel, ModelCacheSaver
    prompt_cache_bin = CompiledModel.get_cache_bin(mm.prompt, config=static_kernel_config)
    prompt_cache_dir = os.path.abspath(os.path.dirname(prompt_cache_bin))
    ModelCacheSaver.remove_cache(prompt_cache_dir)
    self.assertFalse(os.path.exists(prompt_cache_dir))

    mmc(query_prefill, query, key, value, scale, lengthq, length, length2, narrow_start)
    self.assertTrue(os.path.exists(prompt_cache_dir))

    prompt_cache_dir_path = Path(prompt_cache_dir)
    outputs_dirs = [
        d for d in prompt_cache_dir_path.iterdir() if d.is_dir() and d.name == "aclnn_static_shape_kernel_outputs"
    ]
    self.assertEqual(len(outputs_dirs), 1)
    ts_outputs_dirs = [
        d
        for d in outputs_dirs[0].iterdir()
        if d.is_dir() and d.name.endswith("_outputs") and d.name.startswith("ts")
    ]
    self.assertEqual(len(ts_outputs_dirs), 1)
    run_pkgs = list(ts_outputs_dirs[0].glob("*.run"))
    self.assertGreaterEqual(len(run_pkgs), 1)
    # assertTrue(count, 1) only checked truthiness (1 was the message); pin the exact count.
    self.assertEqual(static_compile_call_count, 1)
    static_kernel_path_dirs = list(Path("./static_kernel").iterdir())
    self.assertEqual(len(static_kernel_path_dirs), 0)

    # A fresh model must be served from the cache without recompiling static kernels.
    mm2 = CachedAclGraphModel().npu()
    with self.assertLogs(logger, level="DEBUG") as cm:
        mm2(query_prefill, query, key, value, scale, lengthq, length, length2, narrow_start)
    self.assertTrue(
        any("Rebasing" in log for log in cm.output),
        f"Expected DEBUG cache_compile 'Rebasing'not found in logs: {cm.output}",
    )
    self.assertFalse(
        any("static kernel run eager success" in log for log in cm.output),
        f"Not Expected DEBUG 'static kernel run eager success'found in logs: {cm.output}",
    )
    self.assertEqual(static_compile_call_count, 1)
    self.assertTrue(os.path.exists(prompt_cache_dir))
def test_aclgraph_cache_with_static_kernel_multi_model(self):
    """Two ``cache_compile`` entry points (prompt/decode) built with static kernels enabled.

    After one prompt call and one decode call, each entry point's cache directory
    must contain exactly one ``*_outputs`` directory. That directory must have
    exactly one subdirectory whose name contains this process id. That
    subdirectory must hold one ``*_opcompile`` directory and no
    ``*_opcompile_selected`` directory.

    Both entry points trace the same ``_forward``, so their op-compile JSON files
    are expected to match by name.
    """

    @dataclasses.dataclass
    class InputMeta:
        data: torch.Tensor
        is_prompt: bool

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(2, 1)
            self.linear2 = torch.nn.Linear(2, 1)
            for param in self.parameters():
                torch.nn.init.ones_(param)
            self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=static_kernel_config)
            self.cached_decode = torchair.inference.cache_compile(self.decode, config=static_kernel_config)

        def forward(self, x: InputMeta, kv: List[torch.Tensor]):
            if x.is_prompt:
                return self.cached_prompt(x, kv)
            return self.cached_decode(x, kv)

        def _forward(self, x, kv):
            # NOTE(review): linear1 is never used; both terms go through linear2.
            return self.linear2(x.data) + self.linear2(kv[0])

        def prompt(self, x, y):
            return self._forward(x, y)

        def decode(self, x, y):
            return self._forward(x, y)

    static_kernel_config = CompilerConfig()
    static_kernel_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    static_kernel_config.debug.aclgraph.disable_reinplace_input_mutated_ops_pass = True
    static_kernel_config.mode = "reduce-overhead"
    static_kernel_config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
    static_kernel_config.experimental_config.aclgraph._aclnn_static_shape_kernel_build_dir = "./static_kernel_2"
    model = Model().npu()

    # Start both entry points from an empty cache.
    from torchair.inference._cache_compiler import CompiledModel, ModelCacheSaver
    prompt_cache_bin = CompiledModel.get_cache_bin(model.prompt, config=static_kernel_config)
    prompt_cache_dir = os.path.abspath(os.path.dirname(prompt_cache_bin))
    ModelCacheSaver.remove_cache(prompt_cache_dir)
    decode_cache_bin = CompiledModel.get_cache_bin(model.decode, config=static_kernel_config)
    decode_cache_dir = os.path.abspath(os.path.dirname(decode_cache_bin))
    ModelCacheSaver.remove_cache(decode_cache_dir)

    x = InputMeta(data=torch.randn(2, 2).npu(), is_prompt=True)
    kv = [torch.randn(2, 2).npu()]
    model(x, kv)  # prompt entry point
    x.is_prompt = False
    model(x, kv)  # decode entry point

    pid = os.getpid()
    prompt_cache_dir_path = Path(prompt_cache_dir)
    prompt_out_path = [d for d in prompt_cache_dir_path.iterdir() if d.is_dir() and d.name.endswith("_outputs")]
    self.assertEqual(len(prompt_out_path), 1)
    prompt_ts_path = [d for d in prompt_out_path[0].iterdir() if d.is_dir() and str(pid) in d.name]
    self.assertEqual(len(prompt_ts_path), 1)
    decode_cache_dir_path = Path(decode_cache_dir)
    decode_out_path = [d for d in decode_cache_dir_path.iterdir() if d.is_dir() and d.name.endswith("_outputs")]
    self.assertEqual(len(decode_out_path), 1)
    decode_ts_path = [d for d in decode_out_path[0].iterdir() if d.is_dir() and str(pid) in d.name]
    self.assertEqual(len(decode_ts_path), 1)

    first_opcompile_path = [d for d in prompt_ts_path[0].iterdir() if d.is_dir() and d.name.endswith("_opcompile")]
    second_opcompile_path = [d for d in decode_ts_path[0].iterdir() if d.is_dir() and d.name.endswith("_opcompile")]
    self.assertEqual(len(first_opcompile_path), 1)
    self.assertEqual(len(second_opcompile_path), 1)
    first_opcompile_selected_path = [
        d for d in prompt_ts_path[0].iterdir() if d.is_dir() and d.name.endswith("_opcompile_selected")
    ]
    second_opcompile_selected_path = [
        d for d in decode_ts_path[0].iterdir() if d.is_dir() and d.name.endswith("_opcompile_selected")
    ]
    self.assertEqual(len(first_opcompile_selected_path), 0)
    self.assertEqual(len(second_opcompile_selected_path), 0)

    # The op-compile artifacts are regular JSON files, so filter with is_file().
    # Compare names rather than Path objects because the two entry points write
    # into different directories.
    first_json = [d.name for d in first_opcompile_path[0].iterdir() if d.is_file() and d.name.endswith(".json")]
    second_json = [d.name for d in second_opcompile_path[0].iterdir() if d.is_file() and d.name.endswith(".json")]
    self.assertEqual(len(first_json), len(second_json))
    from collections import Counter
    self.assertEqual(Counter(first_json), Counter(second_json))
def test_aclgraph_cache_recapture_with_ops_update(self):
    """A cached AclGraph must be recaptured when a mutated input changes identity.

    ``model1`` populates an empty cache. ``model2`` then loads it:
    - its first call uses new sequence-length lists;
    - its second call passes ``key`` (a clone with separate storage) instead of ``key_``.

    The DEBUG logs must report a recapture and a "2 th" updated node, but no
    "3 th" updated node.
    """

    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=tng_config)

        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            return self.cached_prompt(
                qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start
            )

        def prompt(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            return self._forward(qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start)

        def _forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start):
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the `k` input; keep_inference_input_mutations is enabled below.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            add3 = torch.narrow(add3, -1, 32, 32)
            res = add3 * mmm.mean()
            return res

    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    tng_config.experimental_config.keep_inference_input_mutations = True
    tng_config.debug.aclgraph.disable_mempool_reuse_in_same_fx = True
    model1 = RecaptureModel().npu()
    length = [28, 29, 1]
    length2 = [66, 88, 55]
    lengthq = [33, 44, 55]
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    # Separate storage with the same initial values; passing it later changes the k input.
    key = key_.clone()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)

    # Start from an empty cache so model1's call populates it.
    from torchair.inference._cache_compiler import CompiledModel, ModelCacheSaver
    prompt_cache_bin = CompiledModel.get_cache_bin(model1.prompt, config=tng_config)
    prompt_cache_dir = os.path.abspath(os.path.dirname(prompt_cache_bin))
    ModelCacheSaver.remove_cache(prompt_cache_dir)
    self.assertFalse(os.path.exists(prompt_cache_dir))
    model1(query_prefill_, query_, key_, value_, scale, lengthq, length, length2, narrow_start)
    self.assertTrue(os.path.exists(prompt_cache_dir))

    model2 = RecaptureModel().npu()
    with self.assertLogs(logger, level="DEBUG") as cm:
        model2(query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start)
        model2(query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start)
    self.assertTrue(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph needs to be recaptured'not found in logs: {cm.output}",
    )
    # The failure messages below name the same node number as the checked log text.
    self.assertTrue(
        any("Record the 2 th updated node" in log for log in cm.output),
        f"Expected DEBUG 'Record the 2 th updated node' not found in logs: {cm.output}",
    )
    self.assertFalse(
        any("Record the 3 th updated node" in log for log in cm.output),
        f"Not expected DEBUG 'Record the 3 th updated node' found in logs: {cm.output}",
    )
def assert_pattern_pass(self, graph_after, check_exist):
    """Assert whether the fused add_rms_norm ops are present in ``graph_after``.

    Args:
        graph_after: FX graph module produced by the optimizer (must expose ``graph.nodes``).
        check_exist: when truthy, both ``npu_add_rms_norm_cast`` and
            ``npu_add_rms_norm_dynamic_quant`` must appear as ``call_function``
            targets; otherwise neither may appear.
    """
    called_targets = [node.target for node in graph_after.graph.nodes if node.op == "call_function"]
    has_dynamic_quant = any(
        target == torch.ops.npu.npu_add_rms_norm_dynamic_quant.default for target in called_targets
    )
    has_cast = any(target == torch.ops.npu.npu_add_rms_norm_cast.default for target in called_targets)

    if check_exist:
        verify = self.assertTrue
        cast_msg = "npu_add_rms_norm_cast should exist in the graph after fusion"
        dq_msg = "npu_add_rms_norm_dynamic_quant should exist in the graph after fusion"
    else:
        verify = self.assertFalse
        cast_msg = "npu_add_rms_norm_cast should not exist in the graph after fusion"
        dq_msg = "npu_add_rms_norm_dynamic_quant should not exist in the graph after fusion"
    verify(has_cast, cast_msg)
    verify(has_dynamic_quant, dq_msg)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_aclgraph(self):
    """Reduce-overhead mode: fusion ops must appear and results must match eager.

    The wrapped FX optimizer requires both ``npu_add_rms_norm_cast`` and
    ``npu_add_rms_norm_dynamic_quant`` in the optimized graph. The compiled
    outputs are then compared element-wise with eager execution.
    """

    class DsModel(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            # add_rms_norm feeding a smooth-scaled dynamic_quant.
            norm_a, _, x_out_a = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_a, scale_a = torch_npu.npu_dynamic_quant(norm_a, smooth_scales=smooth_scales)
            # add_rms_norm feeding a reshaped dynamic_quant and a float32 cast.
            norm_b, _, x_out_b = torch_npu.npu_add_rms_norm(x1, x2, weight)
            last_dim = norm_b.size(-1)
            quant_b, scale_b = torch_npu.npu_dynamic_quant(norm_b.view(-1, last_dim), smooth_scales=smooth_scales)
            _, _, hidden = norm_b.shape
            norm_b_fp32 = norm_b.view(-1, hidden).to(torch.float32)
            # add_rms_norm feeding a flattened dynamic_quant without smooth scales.
            norm_c, _, x_out_c = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_c, scale_c = torch_npu.npu_dynamic_quant(norm_c.flatten(0, 1))
            return (
                quant_a, x_out_a, scale_a, norm_b_fp32, x_out_b,
                quant_b, scale_b, x_out_c, quant_c, scale_c.view(-1, 1),
            )

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, True))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    inputs = (x1, x2, gamma, smooth_scale)

    compiled_outputs = compiled_model(*inputs)
    eager_outputs = eager_model(*inputs)
    self.assertEqual(len(compiled_outputs), len(eager_outputs))
    for got, want in zip(compiled_outputs, eager_outputs):
        self.assertTrue(torch.allclose(got, want, atol=1e-5))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_aclgraph_with_epsilon(self):
    """Same fusion check as the aclgraph test, with explicit add_rms_norm epsilons.

    Two different epsilons (1e-05 and 1e-06) are passed as forward arguments.
    Both fused ops must still be present, and outputs must match eager execution.
    """

    class DsModel(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales, epsilon1, epsilon2):
            norm_a, _, _ = torch_npu.npu_add_rms_norm(x1, x2, weight, epsilon1)
            quant_a, scale_a = torch_npu.npu_dynamic_quant(norm_a, smooth_scales=smooth_scales)
            norm_b, _, x_out_b = torch_npu.npu_add_rms_norm(x1, x2, weight, epsilon2)
            last_dim = norm_b.size(-1)
            quant_b, scale_b = torch_npu.npu_dynamic_quant(norm_b.view(-1, last_dim), smooth_scales=smooth_scales)
            norm_b_fp32 = norm_b.to(torch.float32)
            norm_c, _, x_out_c = torch_npu.npu_add_rms_norm(x1, x2, weight, epsilon1)
            quant_c, scale_c = torch_npu.npu_dynamic_quant(norm_c.flatten(0, 1))
            return (
                quant_a, scale_a, norm_b_fp32, x_out_b, quant_b,
                scale_b, x_out_c, quant_c, scale_c.view(-1, 1),
            )

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, True))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    inputs = (x1, x2, gamma, smooth_scale, 1e-05, 1e-06)

    compiled_outputs = compiled_model(*inputs)
    eager_outputs = eager_model(*inputs)
    self.assertEqual(len(compiled_outputs), len(eager_outputs))
    for got, want in zip(compiled_outputs, eager_outputs):
        self.assertTrue(torch.allclose(got, want, atol=1e-5))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_ge(self):
    """max-autotune (GE) mode: fusion ops must appear and results must match eager.

    Uses the same model as the aclgraph variant, compiled in ``max-autotune`` mode.
    """

    class DsModel(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            norm_a, _, x_out_a = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_a, scale_a = torch_npu.npu_dynamic_quant(norm_a, smooth_scales=smooth_scales)
            norm_b, _, x_out_b = torch_npu.npu_add_rms_norm(x1, x2, weight)
            last_dim = norm_b.size(-1)
            quant_b, scale_b = torch_npu.npu_dynamic_quant(norm_b.view(-1, last_dim), smooth_scales=smooth_scales)
            _, _, hidden = norm_b.shape
            norm_b_fp32 = norm_b.view(-1, hidden).to(torch.float32)
            norm_c, _, x_out_c = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_c, scale_c = torch_npu.npu_dynamic_quant(norm_c.flatten(0, 1))
            return (
                quant_a, x_out_a, scale_a, norm_b_fp32, x_out_b,
                quant_b, scale_b, x_out_c, quant_c, scale_c.view(-1, 1),
            )

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, True))
    config = torchair.CompilerConfig()
    config.mode = "max-autotune"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    inputs = (x1, x2, gamma, smooth_scale)

    compiled_outputs = compiled_model(*inputs)
    eager_outputs = eager_model(*inputs)
    self.assertEqual(len(compiled_outputs), len(eager_outputs))
    for got, want in zip(compiled_outputs, eager_outputs):
        self.assertTrue(torch.allclose(got, want, atol=1e-5))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_for_aclgraph(self):
    """With ``pattern_fusion_pass`` disabled, reduce-overhead mode must not fuse.

    The wrapped FX optimizer asserts that neither fused add_rms_norm op is present.
    """

    class DsModel(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            norm_a, _, x_out_a = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_a, scale_a = torch_npu.npu_dynamic_quant(norm_a, smooth_scales=smooth_scales)
            norm_b, _, x_out_b = torch_npu.npu_add_rms_norm(x1, x2, weight)
            last_dim = norm_b.size(-1)
            quant_b, scale_b = torch_npu.npu_dynamic_quant(norm_b.view(-1, last_dim), smooth_scales=smooth_scales)
            _, _, hidden = norm_b.shape
            norm_b_fp32 = norm_b.view(-1, hidden).to(torch.float32)
            norm_c, _, x_out_c = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_c, scale_c = torch_npu.npu_dynamic_quant(norm_c.flatten(0, 1))
            return (
                quant_a, x_out_a, scale_a, norm_b_fp32, x_out_b,
                quant_b, scale_b, x_out_c, quant_c, scale_c.view(-1, 1),
            )

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, False))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.pattern_fusion_pass = False
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel(), backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    compiled_model(x1, x2, gamma, smooth_scale)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_for_ge(self):
    """With ``pattern_fusion_pass`` disabled, max-autotune (GE) mode must not fuse.

    The wrapped FX optimizer asserts that neither fused add_rms_norm op is present.
    """

    class DsModel(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            norm_a, _, x_out_a = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_a, scale_a = torch_npu.npu_dynamic_quant(norm_a, smooth_scales=smooth_scales)
            norm_b, _, x_out_b = torch_npu.npu_add_rms_norm(x1, x2, weight)
            last_dim = norm_b.size(-1)
            quant_b, scale_b = torch_npu.npu_dynamic_quant(norm_b.view(-1, last_dim), smooth_scales=smooth_scales)
            _, _, hidden = norm_b.shape
            norm_b_fp32 = norm_b.view(-1, hidden).to(torch.float32)
            norm_c, _, x_out_c = torch_npu.npu_add_rms_norm(x1, x2, weight)
            quant_c, scale_c = torch_npu.npu_dynamic_quant(norm_c.flatten(0, 1))
            return (
                quant_a, x_out_a, scale_a, norm_b_fp32, x_out_b,
                quant_b, scale_b, x_out_c, quant_c, scale_c.view(-1, 1),
            )

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, False))
    config = torchair.CompilerConfig()
    config.mode = "max-autotune"
    config.experimental_config.pattern_fusion_pass = False
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel(), backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    compiled_model(x1, x2, gamma, smooth_scale)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_aclgraph_with_multistream(self):
    """Fusion must still fire when each fusible chain stays on a single stream.

    The cast chain runs entirely inside a secondary stream scope. The
    dynamic_quant chain stays on the default stream. Both fused ops must
    appear, and outputs must match eager execution.
    """

    class DsModel2(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            main_norm, _, main_x_out = torch_npu.npu_add_rms_norm(x1, x2, weight)
            with torchair.scope.npu_stream_switch('2', 3):
                side_norm, _, side_x_out = torch_npu.npu_add_rms_norm(x1, x2, weight)
                _, _, hidden = side_norm.shape
                side_norm = side_norm.view(-1, hidden).to(torch.float32)
            quant, quant_scale = torch_npu.npu_dynamic_quant(main_norm, smooth_scales=smooth_scales)
            return main_x_out, quant, quant_scale, side_norm, side_x_out

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, True))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel2()
    compiled_model = torch.compile(eager_model, backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    inputs = (x1, x2, gamma, smooth_scale)

    compiled_outputs = compiled_model(*inputs)
    eager_outputs = eager_model(*inputs)
    self.assertEqual(len(compiled_outputs), len(eager_outputs))
    for got, want in zip(compiled_outputs, eager_outputs):
        self.assertTrue(torch.allclose(got, want, atol=1e-5))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_cast_with_subgraph_in_diff_stream(self):
    """The cast must not be fused when it runs on a different stream from its add_rms_norm.

    add_rms_norm runs on the default stream. The view + float32 cast run inside
    a secondary stream scope after an explicit wait, so neither fused op may appear.
    """

    class DsModel2(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            normed, _, x_out = torch_npu.npu_add_rms_norm(x1, x2, weight)
            with torchair.scope.npu_stream_switch('2', 3):
                torchair.ops.wait([normed])
                _, _, hidden = normed.shape
                normed = normed.view(-1, hidden).to(torch.float32)
            return normed, x_out

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, False))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel2(), backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    compiled_model(x1, x2, gamma, smooth_scale)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_for_dynamicquant_with_subgraph_in_diff_stream(self):
    """dynamic_quant must not be fused when it runs on a different stream from its add_rms_norm.

    add_rms_norm runs on the default stream. dynamic_quant runs inside a
    secondary stream scope after an explicit wait, so neither fused op may appear.
    """

    class DsModel2(torch.nn.Module):
        def forward(self, x1, x2, weight, smooth_scales):
            normed, _, x_out = torch_npu.npu_add_rms_norm(x1, x2, weight)
            with torchair.scope.npu_stream_switch('2', 3):
                torchair.ops.wait([normed])
                quant, quant_scale = torch_npu.npu_dynamic_quant(normed, smooth_scales=smooth_scales)
            return quant, x_out, quant_scale

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(lambda gm: self.assert_pattern_pass(gm, False))
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel2(), backend=backend)

    fp16_npu = dict(dtype=torch.float16, device='npu')
    x1 = torch.randn(1, 2, 3, **fp16_npu)
    x2 = torch.randn(1, 2, 3, **fp16_npu)
    gamma = torch.ones(3, **fp16_npu)
    smooth_scale = torch.ones(3, **fp16_npu)
    compiled_model(x1, x2, gamma, smooth_scale)
def test_aclgraph_userinput_construct_in_share_memory_with_parameter_and_mutated(self):
    """Check AclGraph capture/replay/recapture decisions for a model with a parameter and a mutated input.

    ``RecaptureModel`` is compiled in reduce-overhead mode with
    ``keep_inference_input_mutations`` enabled; its forward mutates ``k`` in place.
    Four calls are made, each checked against torchair DEBUG logs:

    1. After assigning ``linear.weight.data = a``: a first capture is logged
       with ``'activate_num': 4``.
    2. Same inputs again: "no needs to be recaptured" is logged.
    3. ``key`` (a clone) passed instead of ``key_``: a recapture is logged
       with ``'activate_num': 7``.
    4. After assigning ``linear.weight.data = b``: a capture is logged with
       ``'activate_num': 4``, but without the "needs to be recaptured" message.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x):
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the `k` input (keep_inference_input_mutations is enabled below).
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            add3 = torch.narrow(add3, -1, 32, 32)
            # The module parameter enters the graph here via self.linear.
            add3 = add3 @ self.linear(x)
            res = add3 * mmm.mean()
            return res
    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    tng_config.experimental_config.keep_inference_input_mutations = True
    npu_backend = torchair.get_npu_backend(compiler_config=tng_config)
    model1 = RecaptureModel().npu()
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    # Separate storage holding key_'s initial values; passing it in call 3 changes the k input.
    key = key_.clone()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    # Replacement weight tensors assigned to linear.weight.data before calls 1 and 4.
    a = torch.ones(32, 32).npu()
    b = torch.zeros(32, 32).npu()
    compiled_model1 = torch.compile(model1, backend=npu_backend, fullgraph=True, dynamic=True)
    # Call 1: first capture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 4" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 4'not found in logs: {cm.output}",
    )
    # Call 2: identical inputs; no recapture expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res2 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph no needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph no needs to be recaptured'not found in logs: {cm.output}",
    )
    # Call 3: a different tensor for the mutated k input; recapture expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res3 = compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph needs to be recaptured'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 7" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 7'not found in logs: {cm.output}",
    )
    # Drop the earlier outputs before the next call. Presumably this lets their
    # memory be reclaimed, which the 'activate_num' check below appears to
    # depend on (confirm).
    del graph_res2
    del graph_res3
    # Call 4: new weight data; a capture is logged without the "needs to be recaptured" message.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = b
        compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertFalse(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Not Expected DEBUG 'The current AclGraph needs to be recaptured'found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 4" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 4'not found in logs: {cm.output}",
    )
def test_aclgraph_userinput_construct_in_share_memory_with_parameter_and_mutated_clone_input_false(self):
    """Check AclGraph capture/recapture logs for a mutating model with ``clone_input=False``.

    The model below mutates one of its user inputs in place (``k.mul_``) and uses
    a ``Linear`` parameter whose ``.data`` is swapped between calls. It is
    compiled in ``reduce-overhead`` mode with ``dynamic=True`` and called four
    times; after each call the DEBUG records emitted on torchair's ``logger``
    are inspected:

    1. weight = ones, inputs with ``key_``: no captured graph is found, a
       capture happens and ``'activate_num': 3`` is logged.
    2. Same inputs again: the existing AclGraph is reused without recapture.
    3. ``key`` (a separate clone of ``key_``) instead of ``key_``: a recapture
       is logged with ``'activate_num': 6``.
    4. After deleting the earlier results and swapping the weight to zeros: a
       capture is logged with ``'activate_num': 3`` but no "needs to be
       recaptured" message.

    NOTE(review): what ``activate_num`` counts is not visible from this file;
    the expected values presumably depend on which outputs are still
    referenced at each call, so statement order and the ``del`` lines matter.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)
        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x):
            # NOTE: narrow_start is accepted but unused; the narrow below uses literal offsets.
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            # kv lengths here are a literal list, not one of the forward arguments.
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the user input k; the config below enables
            # keep_inference_input_mutations for this.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            add3 = torch.narrow(add3, -1, 32, 32)
            # Uses the Linear parameter whose .data the test swaps between calls.
            add3 = add3 @ self.linear(x)
            res = add3 * mmm.mean()
            return res
    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    # The variant covered by this test: user inputs are not cloned.
    tng_config.debug.aclgraph.clone_input = False
    tng_config.experimental_config.keep_inference_input_mutations = True
    npu_backend = torchair.get_npu_backend(compiler_config=tng_config)
    model1 = RecaptureModel().npu()
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    # Distinct tensor cloned from key_ before any call; passing it instead of key_
    # is the only argument change between the second and third calls.
    key = key_.clone()
    # Keep the attention input shapes static even though the compile uses dynamic=True.
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    # Replacement weights for model1.linear (ones first, zeros for the last call).
    a = torch.ones(32, 32).npu()
    b = torch.zeros(32, 32).npu()
    compiled_model1 = torch.compile(model1, backend=npu_backend, fullgraph=True, dynamic=True)
    # Call 1: first capture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        graph_res1 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 3" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 3'not found in logs: {cm.output}",
    )
    # Call 2: identical arguments, the captured graph is replayed.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res2 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph no needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph no needs to be recaptured'not found in logs: {cm.output}",
    )
    # Call 3: key replaces key_, which is expected to force a recapture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res3 = compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph needs to be recaptured'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 6" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 6'not found in logs: {cm.output}",
    )
    # Results stay referenced up to here; drop them before the last call.
    del graph_res1
    del graph_res2
    del graph_res3
    # Call 4: swap the weight storage; a capture (not a recapture) is expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = b
        compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertFalse(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Not Expected DEBUG 'The current AclGraph needs to be recaptured'found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 3" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 3'not found in logs: {cm.output}",
    )
def test_aclgraph_userinput_construct_in_share_memory_with_parameter_and_mutated_dynamic_false(self):
    """Check AclGraph capture/recapture logs for a mutating model compiled with ``dynamic=False``.

    The model mutates the user input ``k`` in place and uses a ``Linear``
    parameter whose ``.data`` is swapped between calls; ``clone_input`` keeps
    its default. Four calls are made and the DEBUG records on torchair's
    ``logger`` are checked:

    1. weight = ones, inputs with ``key_``: first capture, ``'activate_num': 4``.
    2. Same inputs again: the captured AclGraph is reused without recapture.
    3. ``key`` (a clone of ``key_``) instead of ``key_``: recapture,
       ``'activate_num': 7``.
    4. After deleting earlier results and swapping the weight to zeros: a
       capture with ``'activate_num': 4`` but no "needs to be recaptured" message.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x):
            # NOTE: narrow_start is accepted but unused; the narrow below uses literal offsets.
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the user input k.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            add3 = torch.narrow(add3, -1, 32, 32)
            add3 = add3 @ self.linear(x)
            res = add3 * mmm.mean()
            return res

    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    tng_config.experimental_config.keep_inference_input_mutations = True
    npu_backend = torchair.get_npu_backend(compiler_config=tng_config)
    model1 = RecaptureModel().npu()
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    # Distinct tensor cloned from key_ before any call; swapping it in forces a recapture.
    key = key_.clone()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    a = torch.ones(32, 32).npu()
    b = torch.zeros(32, 32).npu()
    compiled_model1 = torch.compile(model1, backend=npu_backend, fullgraph=True, dynamic=False)

    # Call 1: first capture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        graph_res1 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 4" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 4'not found in logs: {cm.output}",
    )

    # Call 2: identical arguments, the captured graph is replayed.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res2 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph no needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph no needs to be recaptured'not found in logs: {cm.output}",
    )

    # Call 3: key replaces key_, which is expected to force a recapture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res3 = compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph needs to be recaptured'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 7" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 7'not found in logs: {cm.output}",
    )

    # Results stay referenced up to here; drop them before the last call.
    del graph_res1
    del graph_res2
    del graph_res3

    # Call 4: swap the weight storage; a capture (not a recapture) is expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = b
        compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertFalse(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Not Expected DEBUG 'The current AclGraph needs to be recaptured'found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 4" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 4'not found in logs: {cm.output}",
    )
def test_aclgraph_userinput_construct_in_share_memory_with_parameter_and_mutated_clone_input_false_static(self):
    """Check AclGraph capture/recapture logs with ``clone_input=False`` and ``dynamic=False``.

    The model mutates the user input ``k`` in place (``k.mul_``) and uses a
    ``Linear`` parameter whose ``.data`` is swapped between calls. Four calls
    are made and the DEBUG records on torchair's ``logger`` are checked:

    1. weight = ones, inputs with ``key_``: first capture, ``'activate_num': 3``.
    2. Same inputs again: the captured AclGraph is reused without recapture.
    3. ``key`` (a clone of ``key_``) instead of ``key_``: recapture,
       ``'activate_num': 6``.
    4. After deleting earlier results and swapping the weight to zeros: a
       capture with ``'activate_num': 3`` but no "needs to be recaptured" message.

    NOTE(review): what ``activate_num`` counts is not visible from this file;
    the expected values presumably depend on which outputs are still
    referenced at each call, so statement order and the ``del`` lines matter.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)
        def forward(self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x):
            # NOTE: narrow_start is accepted but unused; the narrow below uses literal offsets.
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            # kv lengths here are a literal list, not one of the forward arguments.
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the user input k; the config below enables
            # keep_inference_input_mutations for this.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            add3 = ifa3[0]
            add3 = torch.narrow(add3, -1, 32, 32)
            # Uses the Linear parameter whose .data the test swaps between calls.
            add3 = add3 @ self.linear(x)
            res = add3 * mmm.mean()
            return res
    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    # The variant covered by this test: user inputs are not cloned.
    tng_config.debug.aclgraph.clone_input = False
    tng_config.experimental_config.keep_inference_input_mutations = True
    npu_backend = torchair.get_npu_backend(compiler_config=tng_config)
    model1 = RecaptureModel().npu()
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    # Distinct tensor cloned from key_ before any call; passing it instead of key_
    # is the only argument change between the second and third calls.
    key = key_.clone()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    # Replacement weights for model1.linear (ones first, zeros for the last call).
    a = torch.ones(32, 32).npu()
    b = torch.zeros(32, 32).npu()
    compiled_model1 = torch.compile(model1, backend=npu_backend, fullgraph=True, dynamic=False)
    # Call 1: first capture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        graph_res1 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 3" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 3'not found in logs: {cm.output}",
    )
    # Call 2: identical arguments, the captured graph is replayed.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res2 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph no needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph no needs to be recaptured'not found in logs: {cm.output}",
    )
    # Call 3: key replaces key_, which is expected to force a recapture.
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res3 = compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertTrue(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph needs to be recaptured'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 6" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 6'not found in logs: {cm.output}",
    )
    # Results stay referenced up to here; drop them before the last call.
    del graph_res1
    del graph_res2
    del graph_res3
    # Call 4: swap the weight storage; a capture (not a recapture) is expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = b
        compiled_model1(
            query_prefill_, query_, key, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x
        )
    self.assertFalse(
        any("The current AclGraph needs to be recaptured" in log for log in cm.output),
        f"Not Expected DEBUG 'The current AclGraph needs to be recaptured'found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 3" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 3'not found in logs: {cm.output}",
    )
def test_aclgraph_userinput_construct_in_share_memory_with_multiple_fx(self):
    """Check that the ``is_prompt`` branches of one model are captured as separate AclGraphs.

    Dynamo specializes on the Python bool ``is_prompt``, so the two branches of
    ``forward`` compile to different FX graphs. The test expects:

    1. ``is_prompt=True``: no captured graph found, capture with ``'activate_num': 4``.
    2. ``is_prompt=False``: no captured graph found, capture with ``'activate_num': 2``.
    3. ``is_prompt=True`` again: the first capture is reused without recapturing.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)

        def forward(
            self,
            qp,
            q,
            k,
            v,
            scale,
            actual_seq_lenq,
            actual_seq_len,
            actual_seq_len2,
            narrow_start,
            x,
            is_prompt=True,
        ):
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the user input k.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            # Python-level branch: each value of is_prompt yields a different graph.
            if is_prompt:
                add3 = ifa3[0]
                add3 = torch.narrow(add3, -1, 32, 32)
                add3 = add3 @ self.linear(x)
                res = add3 * mmm.mean()
            else:
                res = ifa3[0]
            return res

    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    tng_config.experimental_config.keep_inference_input_mutations = True
    npu_backend = torchair.get_npu_backend(compiler_config=tng_config)
    model1 = RecaptureModel().npu()
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    a = torch.ones(32, 32).npu()
    compiled_model1 = torch.compile(model1, backend=npu_backend, fullgraph=True, dynamic=True)

    # Call 1 (prompt branch): first capture. graph_res1 stays referenced for the rest of the test.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        graph_res1 = compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x, True
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 4" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 4'not found in logs: {cm.output}",
    )

    # Call 2 (decode branch): a different graph, so a fresh capture is expected.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        compiled_model1(
            query_prefill_,
            query_,
            key_,
            value_,
            scale,
            length_new,
            length2_new,
            lengthq_new,
            narrow_start,
            x,
            False,
        )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("'activate_num': 2" in log for log in cm.output),
        f"Expected DEBUG ''activate_num': 2'not found in logs: {cm.output}",
    )

    # Call 3 (prompt branch again): the first capture is reused.
    with self.assertLogs(logger, level="DEBUG") as cm:
        compiled_model1.linear.weight.data = a
        compiled_model1(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x, True
        )
    self.assertTrue(
        any("The current AclGraph no needs to be recaptured" in log for log in cm.output),
        f"Expected DEBUG 'The current AclGraph no needs to be recaptured'not found in logs: {cm.output}",
    )
def test_aclgraph_static_kernel(self):
    """Compile two structurally identical models with static-shape aclnn kernels enabled.

    The first compilation must log a successful static-kernel reselection; the
    second must log either that static compilation was skipped or which debug
    directory is in use.
    """

    class Model1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(2, 2)
            self.linear2 = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear1(x) + self.linear2(x)

    # Deliberately a separate class with its own forward, not a subclass of Model1.
    class Model2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(2, 2)
            self.linear2 = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear1(x) + self.linear2(x)

    sample = torch.randn(2, 2).npu()
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.aclgraph._aclnn_static_shape_kernel = True

    first = torch.compile(
        Model1().npu(),
        fullgraph=True,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=False,
    )
    with self.assertLogs(logger, level="DEBUG") as first_logs:
        _ = first(sample)
    reselected = any("reselect_static_kernel executed successfully" in line for line in first_logs.output)
    self.assertTrue(
        reselected,
        f"Expected DEBUG 'reselect_static_kernel executed successfully' not found in logs: {first_logs.output}",
    )

    second = torch.compile(
        Model2().npu(),
        fullgraph=True,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=False,
    )
    with self.assertLogs(logger, level="DEBUG") as second_logs:
        _ = second(sample)
    skipped_or_debug_dir = any(
        "Static compilation skipped" in line or "Using debug directory" in line for line in second_logs.output
    )
    self.assertTrue(
        skipped_or_debug_dir,
        f"Expected DEBUG 'Static compilation skipped' or 'Using debug directory' not found in logs: {second_logs.output}",
    )
def test_aclgraph_static_kernel_debug_time(self):
    """Check the per-phase records logged while building static aclnn kernels.

    The first call (float16 input) must log the dump-json, static-compile and
    run-package-install phases. A second call with a float32 input must not log
    any "[static kernel]" phase record.
    """

    class Model1(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(2, 2)
            self.linear2 = torch.nn.Linear(2, 2)

        def forward(self, x):
            return self.linear1(x) + self.linear2(x)

    inp = torch.randn([2, 2], dtype=torch.float16).npu()
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.aclgraph._aclnn_static_shape_kernel = True
    compiled = torch.compile(
        Model1().npu(),
        fullgraph=True,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=True,
    )

    expected_phases = (
        "The operation of [static kernel] execute dump json phase",
        "The operation of [static kernel] static compile phase",
        "The operation of [static kernel] install static kernel run pkgs phase",
    )
    with self.assertLogs(logger, level="DEBUG") as first_logs:
        _ = compiled(inp)
    for phase in expected_phases:
        self.assertTrue(
            any(phase in line for line in first_logs.output),
            f"Expected DEBUG '{phase}' not found in logs: {first_logs.output}",
        )

    # NOTE(review): this capture uses level="INFO", so DEBUG records are not
    # collected. If the phase records above are emitted at DEBUG, the
    # assertFalse below cannot fail; confirm the intended level.
    inp = torch.randn([2, 2], dtype=torch.float32).npu()
    with self.assertLogs(logger, level="INFO") as second_logs:
        _ = compiled(inp)
    self.assertFalse(
        any("The operation of [static kernel]" in line for line in second_logs.output),
        f"Not Expected DEBUG 'The operation of [static kernel]' found in logs: {second_logs.output}",
    )
def test_aclgraph_userinput_construct_in_share_memory_with_cache_compile(self):
    """Check AclGraph capture when prompt/decode graphs are reloaded from the compile cache.

    ``RecaptureModel.forward`` dispatches to two ``cache_compile`` wrappers
    (``prompt`` and ``decode``). A first model instance populates both on-disk
    caches. After ``torch._dynamo.reset()``, a second instance runs both paths
    again, and the test expects the "Rebasing" record plus a fresh AclGraph
    capture in torchair's DEBUG logs. Both cache directories are removed before
    the test runs and again when it finishes, whether it passes or fails.
    """
    class RecaptureModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(32, 32)
            # tng_config is the enclosing-scope variable assigned below, before any instance exists.
            self.cached_prompt = torchair.inference.cache_compile(self.prompt, config=tng_config)
            self.cached_decode = torchair.inference.cache_compile(self.decode, config=tng_config)

        def forward(
            self,
            qp,
            q,
            k,
            v,
            scale,
            actual_seq_lenq,
            actual_seq_len,
            actual_seq_len2,
            narrow_start,
            x,
            is_prompt=True,
        ):
            if is_prompt:
                return self.cached_prompt(
                    qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
                )
            else:
                return self.cached_decode(
                    qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
                )

        def prompt(
            self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
        ):
            return self._forward(
                qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
            )

        def decode(
            self, qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
        ):
            return self._forward(
                qp, q, k, v, scale, actual_seq_lenq, actual_seq_len, actual_seq_len2, narrow_start, x, is_prompt
            )

        def _forward(
            self,
            qp,
            q,
            k,
            v,
            scale,
            actual_seq_lenq,
            actual_seq_len,
            actual_seq_len2,
            narrow_start,
            x,
            is_prompt=True,
        ):
            pfa0, _ = torch_npu.npu_fused_infer_attention_score(
                qp,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths=actual_seq_lenq,
                actual_seq_lengths_kv=actual_seq_len,
            )
            q = q * scale
            ifa1, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            mm1 = ifa1.view([ifa1.shape[-1], -1]).clone()
            q = q + 0.01
            ifa2, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=[66, 166, 266],
            )
            mm2 = ifa2.view([-1, ifa2.shape[-1]]).clone()
            mmm = torch.mm(mm1, mm2) + pfa0.mean()
            # In-place mutation of the user input k.
            k.mul_(1.1)
            v = v / 1.1
            ifa3 = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len2,
            )
            if is_prompt:
                add3 = ifa3[0]
                add3 = torch.narrow(add3, -1, 32, 32)
                add3 = add3 @ self.linear(x)
                res = add3 * mmm.mean()
            else:
                res = ifa3[0]
            return res

    tng_config = torchair.CompilerConfig()
    tng_config.mode = 'reduce-overhead'
    tng_config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    tng_config.experimental_config.keep_inference_input_mutations = True
    length_new = [88, 99, 1]
    length2_new = [40, 50, 60]
    lengthq_new = [99, 50, 10]
    scale = 1 / 0.0078125
    narrow_start = 32
    query_prefill_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    query_ = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value_ = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    torch._dynamo.mark_static(query_prefill_)
    torch._dynamo.mark_static(query_)
    torch._dynamo.mark_static(key_)
    torch._dynamo.mark_static(value_)
    x = torch.randn([32, 32]).npu()
    a = torch.ones(32, 32).npu()
    model1 = RecaptureModel().npu()
    model1.linear.weight.data = a

    from torchair.inference._cache_compiler import CompiledModel, ModelCacheSaver
    # Start from a clean slate: remove any cache left by an earlier run.
    prompt_cache_dir = os.path.abspath(
        os.path.dirname(CompiledModel.get_cache_bin(model1.prompt, config=tng_config))
    )
    ModelCacheSaver.remove_cache(prompt_cache_dir)
    decode_cache_dir = os.path.abspath(
        os.path.dirname(CompiledModel.get_cache_bin(model1.decode, config=tng_config))
    )
    ModelCacheSaver.remove_cache(decode_cache_dir)
    # Do not leave cache directories behind, even if an assertion below fails.
    # remove_cache is already used above on directories that may not exist yet.
    self.addCleanup(ModelCacheSaver.remove_cache, prompt_cache_dir)
    self.addCleanup(ModelCacheSaver.remove_cache, decode_cache_dir)

    # Populate the prompt cache, then the decode cache.
    self.assertFalse(os.path.exists(prompt_cache_dir))
    model1(query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x, True)
    self.assertTrue(os.path.exists(prompt_cache_dir))
    self.assertFalse(os.path.exists(decode_cache_dir))
    model1(
        query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x, False
    )
    self.assertTrue(os.path.exists(decode_cache_dir))

    # A fresh instance after dynamo reset loads from the caches and captures again.
    torch._dynamo.reset()
    model2 = RecaptureModel().npu()
    with self.assertLogs(logger, level="DEBUG") as cm:
        model2.linear.weight.data = a
        model2(
            query_prefill_, query_, key_, value_, scale, length_new, length2_new, lengthq_new, narrow_start, x, True
        )
        model2(
            query_prefill_,
            query_,
            key_,
            value_,
            scale,
            length_new,
            length2_new,
            lengthq_new,
            narrow_start,
            x,
            False,
        )
    self.assertTrue(
        any("Rebasing" in log for log in cm.output),
        f"Expected DEBUG cache_compile 'Rebasing'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("No find captured AclGraph" in log for log in cm.output),
        f"Expected DEBUG 'No find captured AclGraph'not found in logs: {cm.output}",
    )
    self.assertTrue(
        any("After capturing fx_graph" in log for log in cm.output),
        f"Expected DEBUG 'After capturing fx_graph'not found in logs: {cm.output}",
    )
def test_replay_update_stream_same(self):
    """Check that replaying on AclGraph's own update stream logs a parameter-stream update.

    The compiled attention model is first called while a high-priority side
    stream is current. It is then called again with
    ``CapturedGraphUpdateAndReplay._update_stream`` as the current stream, and an
    INFO-or-higher record containing "Update the stream for parameter" is
    expected on torchair's ``logger``.
    """
    class MM(torch.nn.Module):
        def forward(self, q, k, v, scale, actual_seq_len):
            ifa, _ = torch_npu.npu_fused_infer_attention_score(
                q,
                k,
                v,
                num_heads=32,
                input_layout="BNSD",
                scale=scale,
                softmax_lse_flag=False,
                actual_seq_lengths_kv=actual_seq_len,
            )
            return ifa

    length = [28, 29, 1]
    scale = 1 / 0.0078125
    query = torch.randn(3, 32, 1, 128, dtype=torch.float16).npu()
    key = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    value = torch.randn(3, 32, 512, 128, dtype=torch.float16).npu()
    torch._dynamo.mark_static(query)
    torch._dynamo.mark_static(key)
    torch._dynamo.mark_static(value)
    mm = MM()
    compiler_config = torchair.CompilerConfig()
    compiler_config.mode = 'reduce-overhead'
    npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
    mmc = torch.compile(mm, backend=npu_backend, dynamic=True)

    # First call on a separate, high-priority stream.
    replay_stream = torch.npu.Stream(priority=-1)
    with torch.npu.stream(replay_stream):
        _ = mmc(query, key, value, scale, length)
    torch.npu.synchronize()

    # Second call on the stream AclGraph uses internally for parameter updates.
    update_stream = torchair._acl_concrete_graph.acl_graph.CapturedGraphUpdateAndReplay._update_stream
    with self.assertLogs(logger, level="INFO") as cm:
        with torch.npu.stream(update_stream):
            _ = mmc(query, key, value, scale, length)
        torch.npu.synchronize()
    self.assertTrue(
        any("Update the stream for parameter" in log for log in cm.output),
        f"Expected INFO 'Update the stream for parameter' not found in logs: {cm.output}",
    )
@unittest.skipIf(
    'ATB_HOME_PATH' not in os.environ,
    "_npu_paged_attention is unsupported without ATB_HOME_PATH environment variable",
)
def test_aclgraph_update_param_with__npu_paged_attention(self):
    """Check that an AclGraph wrapping ``_npu_paged_attention`` updates ``context_lens`` on replay.

    The compiled model is called twice with different host-side ``context_lens``.
    Each result must match eager execution, and the logs must show that the
    graph was replayed with updated input params.
    """
    class Model(torch.nn.Module):
        def forward(self, query, key_cache, value_cache, block_table, context_lens):
            # The op writes into a preallocated output (out=...), sized to head_size_v below.
            output = torch.zeros_like(query[:, :, :96])
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                num_kv_heads=16,
                num_heads=32,
                scale_value=0.38888,
                block_table=block_table,
                context_lens=context_lens,
                out=output,
            )
            return output + 1

    from torch._dynamo import allow_in_graph
    allow_in_graph(torch_npu._npu_paged_attention)
    model = Model()
    config = CompilerConfig()
    config.mode = 'reduce-overhead'
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(model, fullgraph=True, backend=npu_backend, dynamic=True)
    num_blocks = 64
    num_tokens = 2
    block_size = 128
    kv_heads = 16
    head_size = 288
    num_heads = 32
    head_size_v = 96
    import random
    import numpy as np
    # Seeded, local generators keep the inputs reproducible (the module seeds torch
    # the same way) without touching the global numpy/random state.
    np_rng = np.random.default_rng(7)
    py_rng = random.Random(7)
    query_np = np_rng.uniform(-1, 1, (num_tokens, num_heads, head_size)).astype(np.float16)
    key_cache_np = np_rng.uniform(-1, 1, (num_blocks, block_size, kv_heads, head_size)).astype(np.float16)
    value_cache_np = np_rng.uniform(-1, 1, (num_blocks, block_size, kv_heads, head_size_v)).astype(np.float16)
    max_blocks_per_seq = (1024 + block_size - 1) // block_size
    block_table_np = np.array(
        [[py_rng.randint(0, num_blocks - 1) for _ in range(max_blocks_per_seq)] for _ in range(num_tokens)],
        dtype=np.int32,
    )
    context_lens_np = np.full(num_tokens, 128, dtype=np.int32)
    context_lens_np_new = np.full(num_tokens, 512, dtype=np.int32)
    query = torch.from_numpy(query_np).npu()
    key_cache = torch.from_numpy(key_cache_np).npu()
    value_cache = torch.from_numpy(value_cache_np).npu()
    block_table = torch.from_numpy(block_table_np).npu()
    # context_lens stay on the host (no .npu()).
    context_lens = torch.from_numpy(context_lens_np)
    context_lens_new = torch.from_numpy(context_lens_np_new)
    torch._dynamo.mark_static(query)
    torch._dynamo.mark_static(key_cache)
    torch._dynamo.mark_static(value_cache)
    torch._dynamo.mark_static(block_table)
    eager_res1 = model(query, key_cache, value_cache, block_table, context_lens)
    eager_res2 = model(query, key_cache, value_cache, block_table, context_lens_new)
    with self.assertLogs(logger, level="DEBUG") as cm:
        graph_res1 = compiled_model(query, key_cache, value_cache, block_table, context_lens)
        self.assertTrue(torch.allclose(eager_res1, graph_res1))
        graph_res2 = compiled_model(query, key_cache, value_cache, block_table, context_lens_new)
        self.assertTrue(torch.allclose(eager_res2, graph_res2))
    self.assertTrue(
        any("Replay AclGraph and update input params successfully" in log for log in cm.output),
        f"Expected DEBUG 'Replay AclGraph and update input params successfully'not found in logs: {cm.output}",
    )
def test_aclgraph_scope_with_post_pass(self):
    """Compile a small network whose FX graph gets user stream scopes from custom passes.

    ``post_grad_custom_pre_pass`` opens a ``stream0`` scope before ``aten.sqrt`` and
    closes it after ``aten.add.Tensor``; ``post_grad_custom_post_pass`` opens a
    ``stream1`` scope before ``aten._softmax`` and closes it after
    ``aten.split_with_sizes``. Only successful compilation and one execution are
    checked; output values are not compared.
    """
    class Network(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x, y, z):
            rooted = torch.sqrt(x)
            normalized = torch.softmax(rooted, dim=-1)
            magnitude = torch.abs(normalized)
            head, tail = torch.split(magnitude, split_size_or_sections=[6, 6], dim=0)
            product = torch.matmul(head, y)
            summed = torch.add(tail, product)
            stacked = torch.cat([summed, z], dim=0)
            activated = self.relu(stacked)
            return torch.transpose(activated, 0, 1)

    def _insert_stream_scope(gm, open_before, close_after, stream_name):
        # Insert scope_enter before every `open_before` node and scope_exit after
        # every `close_after` node. The inserted scope nodes match neither target.
        graph = gm.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == open_before:
                with graph.inserting_before(node):
                    graph.call_function(
                        torch.ops.air.scope_enter.default, args=(["_user_stream_label"], [stream_name])
                    )
            if node.target == close_after:
                with graph.inserting_after(node):
                    graph.call_function(torch.ops.air.scope_exit.default, args=())

    def parallel_abs_sub_1(gm, example_inputs, config: torchair.CompilerConfig):
        _insert_stream_scope(gm, torch.ops.aten.sqrt.default, torch.ops.aten.add.Tensor, "stream0")

    def parallel_abs_sub_2(gm, example_inputs, config: torchair.CompilerConfig):
        _insert_stream_scope(
            gm, torch.ops.aten._softmax.default, torch.ops.aten.split_with_sizes.default, "stream1"
        )

    config = CompilerConfig()
    config.mode = "reduce-overhead"
    config.debug.aclgraph.clone_input = False
    config.debug.aclgraph.disable_reinplace_inplaceable_ops_pass = True
    config.post_grad_custom_pre_pass = parallel_abs_sub_1
    config.post_grad_custom_post_pass = parallel_abs_sub_2
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    # torch.randn yields negative entries, so sqrt produces NaNs downstream;
    # harmless here because no output values are inspected.
    lhs = torch.randn(12, 6, dtype=torch.float32).npu()
    weight = torch.randn(6, 6, dtype=torch.float32).npu()
    tail_rows = torch.randn(12, 6, dtype=torch.float32).npu()
    compiled_net = torch.compile(Network(), fullgraph=True, backend=npu_backend, dynamic=False)
    compiled_net(lhs, weight, tail_rows)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_batch_matmul_transpose_for_aclgraph(self):
    """matmul(x1, x2).transpose(1, 0) compiled in reduce-overhead mode must match eager.

    pattern_fusion_pass is left at its default (presumably enabled; compare the
    companion test that sets it to False).
    """
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1, x2).transpose(1, 0)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    # assertLogs also fails the test if torchair emits no DEBUG record.
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_batch_matmul_transpose_for_aclgraph(self):
    """Same matmul(x1, x2).transpose(1, 0) graph with pattern_fusion_pass turned off; must match eager."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1, x2).transpose(1, 0)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.pattern_fusion_pass = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    # assertLogs also fails the test if torchair emits no DEBUG record.
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_batch_matmul_transpose_for_aclgraph_with_multistream(self):
    """matmul on the current stream, transpose on tagged stream '2', synchronized by tagged events.

    Only checks that compilation and one run succeed (and that a DEBUG log is emitted).
    """
    class DsModel2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.event1 = torchair.ops.npu_create_tagged_event(tag="66")
            self.event2 = torchair.ops.npu_create_tagged_event(tag="77")

        def forward(self, x1, x2):
            product = torch.matmul(x1, x2)
            torchair.ops.npu_tagged_event_record(self.event1)
            with torchair.scope.npu_stream_switch('2', 3):
                torchair.ops.npu_tagged_event_wait(self.event1)
                swapped = torch.transpose(product, 1, 0)
                torchair.ops.npu_tagged_event_record(self.event2)
                torchair.ops.npu_record_tagged_stream(swapped, '2')
            torchair.ops.npu_tagged_event_wait(self.event2)
            return swapped

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel2(), backend=backend)
    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        compiled_model(lhs, rhs)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_batch_matmul_transpose_for_aclgraph_KN(self):
    """matmul(x1, x2).transpose(1, 0) with K = 511 must still match eager.

    pattern_fusion_pass is NOT disabled here despite the name; the odd K size
    presumably makes the fusion's shape checks reject the pattern (TODO confirm).
    """
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1, x2).transpose(1, 0)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(64, 4, 511, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 511, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_batch_matmul_transpose_for_aclgraph_view(self):
    """matmul(x1, x2).transpose(0, 1) with remove_noop_ops disabled must match eager."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1, x2).transpose(0, 1)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.remove_noop_ops = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_batch_matmul_transpose_for_aclgraph_view1(self):
    """matmul(x1, x2).transpose(0, 2) (a different output permutation) with remove_noop_ops disabled must match eager."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1, x2).transpose(0, 2)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.remove_noop_ops = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
def assert_addrmsnorm_quant(self, after_gm, expect_fused=True):
    """
    Assert whether add_rms_norm + quantize was fused in ``after_gm``.

    When ``expect_fused`` is True the graph must contain ``npu_add_rms_norm_quant``
    and neither ``npu_add_rms_norm`` nor ``npu_quantize``; when False, the reverse.
    """
    unfused_present = not expect_fused
    expectations = (
        (torch.ops.npu.npu_add_rms_norm_quant.default, expect_fused),
        (torch.ops.npu.npu_add_rms_norm.default, unfused_present),
        (torch.ops.npu.npu_quantize.default, unfused_present),
    )
    for op, should_exist in expectations:
        present = find_op(after_gm, op)
        if not should_exist:
            self.assertFalse(present, f"Not expected operator '{op}' but find")
            continue
        self.assertTrue(present, f"Expected operator '{op}' but not find")
def get_quant_input(self, last_axis, dtype1, dtype2, dtype3):
    """
    Build NPU inputs for the add_rms_norm + quantize pattern.

    Args:
        last_axis: size of the trailing dimension of every returned tensor.
        dtype1: dtype of x1, x2 and gamma.
        dtype2: dtype of scales.
        dtype3: dtype of zero_points.

    Returns:
        (x1, x2, gamma, scales, zero_points): x1 and x2 are random tensors of shape
        (1, 2, last_axis); gamma and scales are ones and zero_points is zeros, each
        of shape (last_axis,).
    """
    activation_shape = (1, 2, last_axis)
    x1 = torch.randn(activation_shape, dtype=dtype1, device='npu')
    x2 = torch.randn(activation_shape, dtype=dtype1, device='npu')
    vector_shape = (last_axis,)
    gamma = torch.ones(vector_shape, dtype=dtype1, device='npu')
    return (
        x1,
        x2,
        gamma,
        torch.ones(vector_shape, dtype=dtype2, device='npu'),
        torch.zeros(vector_shape, dtype=dtype3, device='npu'),
    )
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_addrmsnorm_quant(self):
    """add_rms_norm followed by quantize must be fused, and compiled results must equal eager.

    Covers a reshape prefix with dynamic shapes (fp16/fp32/int32 and bf16 inputs),
    a static compile, and variants that drop the xOut output with and without an
    explicit epsilon. The fusion check runs inside the patched ``_optimize_fx``.
    """
    def f(x1, x2, gamma, scales, zero_points, div_mode=True):
        x1 = x1.reshape([1, -1, 16])
        x2 = x2.reshape([1, -1, 16])
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma, 4e-6)
        yOut = torch_npu.npu_quantize(y, scales, zero_points, torch.qint8, axis=-1, div_mode=div_mode)
        return yOut, xOut

    def f_static(x1, x2, gamma, scales, zero_points):
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma, 1e-6)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.int8, axis=-1)
        return yOut, xOut

    def f_no_xout(x1, x2, gamma, scales, zero_points):
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.qint8, axis=-1)
        return yOut

    def f_no_xout_with_epsilon(x1, x2, gamma, scales, zero_points):
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma, 4e-6)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.int8, axis=-1)
        return yOut

    def assert_matches_eager(eager_fn, compiled_fn, inputs):
        # Run eager first, then compiled, and compare every output bit-exactly.
        expected = eager_fn(*inputs)
        actual = compiled_fn(*inputs)
        if not isinstance(expected, tuple):
            expected, actual = (expected,), (actual,)
        for want, got in zip(expected, actual):
            self.assertTrue(torch.equal(want, got))

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(
        lambda gm: self.assert_addrmsnorm_quant(gm, True)
    )
    npu_config = torchair.CompilerConfig()
    npu_config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=npu_config)

    compiled_f = torch.compile(f, backend=npu_backend, fullgraph=True, dynamic=True)
    assert_matches_eager(f, compiled_f, self.get_quant_input(16, torch.float16, torch.float, torch.int32))
    bf16_inputs = self.get_quant_input(16, torch.bfloat16, torch.bfloat16, torch.bfloat16)
    assert_matches_eager(f, compiled_f, bf16_inputs)

    # f_static deliberately reuses the bf16 inputs of the previous case.
    compiled_static = torch.compile(f_static, backend=npu_backend, fullgraph=True, dynamic=False)
    assert_matches_eager(f_static, compiled_static, bf16_inputs)

    for single_output_fn in (f_no_xout, f_no_xout_with_epsilon):
        compiled_fn = torch.compile(single_output_fn, backend=npu_backend, fullgraph=True, dynamic=False)
        fresh_inputs = self.get_quant_input(16, torch.bfloat16, torch.bfloat16, torch.bfloat16)
        assert_matches_eager(single_output_fn, compiled_fn, fresh_inputs)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_addrmsnorm_quant_mismatched(self):
    """add_rms_norm + quantize must NOT be fused for variants this test treats as unfusable.

    Variants covered:
    - zero_points dtype uint8 / int8;
    - out_dtype int32;
    - ``y`` consumed again after quantize (f_use);
    - div_mode=False;
    - gamma of shape (1, 2, 16);
    - scales/zero_points of shape (1,);
    - a last axis of 3 without a reshape prefix (f_noreshape).

    The patched ``_optimize_fx`` asserts that the optimized graph keeps
    ``npu_add_rms_norm`` and ``npu_quantize`` separate.
    """
    def f(x1, x2, gamma, scales, zero_points, out_dtype=torch.qint8, div_mode=True):
        x1 = x1.reshape([1, -1, 16])
        x2 = x2.reshape([1, -1, 16])
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma)
        yOut = torch_npu.npu_quantize(y, scales, zero_points, out_dtype, axis=-1, div_mode=div_mode)
        return yOut, xOut

    def f_use(x1, x2, gamma, scales, zero_points):
        x1 = x1.reshape([1, -1, 16])
        x2 = x2.reshape([1, -1, 16])
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.qint8, axis=-1)
        yOut = y + yOut
        return yOut, xOut

    def f_noreshape(x1, x2, gamma, scales, zero_points):
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.qint8, axis=-1, div_mode=True)
        return yOut, xOut

    def f_no_xout(x1, x2, gamma, scales, zero_points):
        y, _, xOut = torch_npu.npu_add_rms_norm(x1, x2, gamma, 4e-6)
        yOut = torch_npu.npu_quantize(y, scales, zero_points=zero_points, dtype=torch.int8, axis=-1, div_mode=False)
        return yOut

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(
        lambda gm: self.assert_addrmsnorm_quant(gm, False)
    )
    npu_config = torchair.CompilerConfig()
    npu_config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=npu_config)

    # zero_points dtype variants.
    compile_model = torch.compile(f, backend=npu_backend, fullgraph=True, dynamic=True)
    x1, x2, gamma, scales, zero_points = self.get_quant_input(16, torch.float16, torch.float16, torch.uint8)
    f(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)
    zero_points = torch.zeros(16, dtype=torch.int8, device='npu')
    f(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)

    # out_dtype variant.
    x1, x2, gamma, scales, zero_points = self.get_quant_input(16, torch.bfloat16, torch.bfloat16, torch.bfloat16)
    f(x1, x2, gamma, scales, zero_points, torch.int32)
    compile_model(x1, x2, gamma, scales, zero_points, torch.int32)

    # y has an extra consumer.
    compile_model = torch.compile(f_use, backend=npu_backend, fullgraph=True, dynamic=True)
    f_use(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)

    # div_mode=False.
    compile_model = torch.compile(f, backend=npu_backend, fullgraph=True, dynamic=True)
    f(x1, x2, gamma, scales, zero_points, div_mode=False)
    compile_model(x1, x2, gamma, scales, zero_points, div_mode=False)
    compile_model = torch.compile(f_no_xout, backend=npu_backend, fullgraph=True, dynamic=False)
    f_no_xout(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)

    # The gamma/scales shape cases below run `f` eagerly, so compile `f` too.
    # Compiling `f_no_xout` here (already unfusable via div_mode=False) would mask
    # whether the shape mismatches alone prevent fusion.
    compile_model = torch.compile(f, backend=npu_backend, fullgraph=True, dynamic=True)
    gamma = torch.ones(1, 2, 16, dtype=torch.bfloat16, device='npu')
    scales = torch.ones(16, dtype=torch.bfloat16, device='npu')
    zero_points = torch.zeros(16, dtype=torch.bfloat16, device='npu')
    f(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)
    scales = torch.ones(1, dtype=torch.bfloat16, device='npu')
    zero_points = torch.zeros(1, dtype=torch.bfloat16, device='npu')
    f(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)

    # No reshape prefix, last axis of 3, static and dynamic.
    compile_model = torch.compile(f_noreshape, backend=npu_backend, fullgraph=True, dynamic=False)
    x1, x2, gamma, scales, zero_points = self.get_quant_input(3, torch.bfloat16, torch.bfloat16, torch.bfloat16)
    f_noreshape(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)
    compile_model = torch.compile(f_noreshape, backend=npu_backend, fullgraph=True, dynamic=True)
    f_noreshape(x1, x2, gamma, scales, zero_points)
    compile_model(x1, x2, gamma, scales, zero_points)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_addrmsnorm_quant_with_diff_stream(self):
    """add_rms_norm and quantize on different streams must not be fused.

    add_rms_norm runs on the current stream; quantize runs inside tagged stream '2',
    ordered by tagged events.
    """
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.event1 = torchair.ops.npu_create_tagged_event(tag="22")
            self.event2 = torchair.ops.npu_create_tagged_event(tag="33")

        def forward(self, x1, x2, gamma, scales, zero_points):
            normed, _, residual = torch_npu.npu_add_rms_norm(x1, x2, gamma)
            torchair.ops.npu_tagged_event_record(self.event1)
            with torchair.scope.npu_stream_switch('2', 3):
                torchair.ops.npu_tagged_event_wait(self.event1)
                quantized = torch_npu.npu_quantize(
                    normed, scales, zero_points=zero_points, dtype=torch.qint8, axis=-1
                )
                torchair.ops.npu_tagged_event_record(self.event2)
                torchair.ops.npu_record_tagged_stream(quantized, '2')
            torchair.ops.npu_tagged_event_wait(self.event2)
            return quantized, residual

    torchair.npu_fx_compiler._optimize_fx = create_optimize_wrapper(
        lambda gm: self.assert_addrmsnorm_quant(gm, False)
    )
    model = Model()
    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(model, backend=backend, fullgraph=True, dynamic=True)
    inputs = self.get_quant_input(16, torch.bfloat16, torch.bfloat16, torch.bfloat16)
    compiled_model(*inputs)
@unittest.skipIf(torch.__version__ < "2.7", "pattern_fusion_pass skip_duplicates is unsupported when torch < 2.7")
def test_pattern_pass_addrmsnorm_quant_skip_duplicates(self):
    """Register an add_rms_norm + quantize replacement with skip_duplicates=True, then compile and run.

    torch.compile is lazy, so the compiled function is actually called to make the
    backend, and the pattern pass holding the registration, run at least once.
    """
    def f(x1, x2):
        return x1 + x2

    def search_fn(x1, x2, gamma, scales, zero_points, epsilon, dtype):
        y, _, x_out = torch.ops.npu.npu_add_rms_norm.default(x1, x2, gamma, epsilon)
        y_out = torch.ops.npu.npu_quantize.default(y, scales, zero_points=zero_points, dtype=dtype, axis=-1)
        return y_out, x_out

    def replace_fn(x1, x2, gamma, scales, zero_points, epsilon, _):
        y1, _, x_out = torch.ops.npu.npu_add_rms_norm_quant.default(
            x1, x2, gamma, scales, zero_points, axis=-1, epsilon=epsilon
        )
        return y1, x_out

    # Example inputs for pattern tracing are created as fake tensors.
    fake_mode = FakeTensorMode()
    with fake_mode:
        torchair.register_replacement(
            search_fn=search_fn,
            replace_fn=replace_fn,
            example_inputs=self.get_quant_input(16, torch.bfloat16, torch.bfloat16, torch.bfloat16),
            scalar_workaround={"epsilon": 2e-6, "dtype": 1},
            skip_duplicates=True,
        )
    npu_config = torchair.CompilerConfig()
    npu_config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=npu_config)
    compiled_f = torch.compile(f, backend=npu_backend, fullgraph=True, dynamic=True)
    lhs = torch.randn(2, 16, device='npu')
    rhs = torch.randn(2, 16, device='npu')
    self.assertTrue(torch.allclose(f(lhs, rhs), compiled_f(lhs, rhs)))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph(self):
    """Test transpose+bmm+transpose pattern fusion (perm_x1=[1,0,2])."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            return batched.transpose(0, 1)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(4, 64, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph(self):
    """Test transpose+bmm+transpose with pattern_fusion_pass disabled."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            return batched.transpose(0, 1)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.pattern_fusion_pass = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(4, 64, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_with_multistream(self):
    """Test transpose+bmm+transpose with multistream (event record/wait)."""
    class DsModel2(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.event1 = torchair.ops.npu_create_tagged_event(tag="88")
            self.event2 = torchair.ops.npu_create_tagged_event(tag="99")

        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            torchair.ops.npu_tagged_event_record(self.event1)
            with torchair.scope.npu_stream_switch('2', 3):
                torchair.ops.npu_tagged_event_wait(self.event1)
                swapped = torch.transpose(batched, 0, 1)
                torchair.ops.npu_tagged_event_record(self.event2)
                torchair.ops.npu_record_tagged_stream(swapped, '2')
            torchair.ops.npu_tagged_event_wait(self.event2)
            return swapped

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    compiled_model = torch.compile(DsModel2(), backend=backend)
    lhs = torch.randn(4, 64, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    # Only compilation/execution (and a DEBUG log) are checked here.
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        compiled_model(lhs, rhs)
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_KN(self):
    """Test transpose+bmm+transpose when K/N not aligned (no fusion)."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            return batched.transpose(0, 1)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(4, 64, 511, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 511, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_view(self):
    """Test transpose+bmm+transpose with transpose(1, 0) on output (no-op view case)."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            return batched.transpose(1, 0)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.remove_noop_ops = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(4, 64, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_view1(self):
    """Test transpose+bmm+transpose with transpose(0, 2) on output (unsupported dims, no fusion)."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 1), x2)
            return batched.transpose(0, 2)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    config.experimental_config.remove_noop_ops = False
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    lhs = torch.randn(4, 64, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_view2(self):
    """Test matmul(x1.transpose(0,2), x2).transpose(0, 1); input transpose(0,2) unsupported (dims>1), no fusion."""
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            batched = torch.matmul(x1.transpose(0, 2), x2)
            return batched.transpose(0, 1)

    config = torchair.CompilerConfig()
    config.mode = "reduce-overhead"
    backend = torchair.get_npu_backend(compiler_config=config)
    eager_model = DsModel()
    compiled_model = torch.compile(eager_model, backend=backend)
    # x1.transpose(0, 2) turns (512, 4, 64) into (64, 4, 512), matching x2's batch and K.
    lhs = torch.randn(512, 4, 64, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    expected = eager_model(lhs, rhs)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        actual = compiled_model(lhs, rhs)
    self.assertTrue(torch.allclose(expected, actual))
@unittest.skipIf(torch.__version__ < "2.6", "pattern_fusion_pass is unsupported when torch < 2.6")
def test_close_pattern_pass_transpose_batch_matmul_transpose_for_aclgraph_KB_constraint(self):
    """Test perm_x1=[1,0,2] constraint K*B < 65536: when K*B >= 65536 fusion is skipped, result still correct.

    x1 (4, 64, 1152).transpose(0, 1) gives (B=64, M=4, K=1152), so K*B = 73728 >= 65536.
    """
    class DsModel(torch.nn.Module):
        def forward(self, x1, x2):
            return torch.matmul(x1.transpose(0, 1), x2).transpose(0, 1)

    npu_config = torchair.CompilerConfig()
    npu_config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=npu_config)
    model = DsModel()
    model_compile = torch.compile(model, backend=npu_backend)
    x1 = torch.randn(4, 64, 1152, dtype=torch.float16, device='npu')
    x2 = torch.randn(64, 1152, 128, dtype=torch.float16, device='npu')
    eager_output = model(x1, x2)
    with self.assertLogs(logger, level="DEBUG"), torch.no_grad():
        compile_output = model_compile(x1, x2)
    self.assertTrue(torch.allclose(eager_output, compile_output))
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_event_replace_in_fx(self):
    """replace_stream_event_pass must rewrite torch.npu Stream/Event usage into torchair IR.

    The record/wait nodes must carry the tag ``graph_0_event``. setUp resets the
    graph counter to -1, which presumably makes this first graph ``graph_0``. The
    set_stream node must reference ``graph_0_stream``, and the rewritten graph must
    contain the tagged event/stream and scope enter/exit targets.
    """
    def cus_func(t):
        s = torch.npu.Stream()
        tmp = torch.add(t, 2)
        event = torch.npu.Event()
        event.record()
        with torch.npu.stream(s):
            event.wait(s)
            r = torch.relu(tmp)
            r.record_stream(s)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass
        gm = replace_stream_event_pass(gm)
        event_tag_by_node = {"record": "graph_0_event", "wait": "graph_0_event"}
        targets = []
        for node in gm.graph.nodes:
            targets.append(node.target.__name__ if hasattr(node.target, "__name__") else node.target)
            if node.name in event_tag_by_node:
                self.assertEqual(node.args[0], event_tag_by_node[node.name])
            elif node.name == "set_stream":
                self.assertIn("graph_0_stream", node.args[1])
        print(f"fx_target_list is :{targets}")
        required_ir = (
            'tagged_event_record',
            'tagged_event_wait_on_stream',
            'record_tagged_stream_',
            'scope_enter',
            'scope_exit',
        )
        for ir_name in required_ir:
            self.assertIn(ir_name, targets)
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    sample = torch.randn([3, 3]).to('npu')
    compiled(sample)
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_event_replace_with_set_stream(self):
    """Check scope bookkeeping when torch.npu.set_stream hops across several streams.

    ``cus_func`` switches streams default -> s1 -> s2 -> s3 -> s1 -> default. After
    ``replace_stream_event_pass`` the graph must contain as many ``scope_exit`` as
    ``scope_enter`` nodes, and specific scope exits must land at fixed positions.
    The index/name assertions are tied to the exact statement order in
    ``cus_func``; do not reorder it.
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        s2 = torch.npu.Stream()
        s3 = torch.npu.Stream()
        tmp = torch.add(t, 0)
        torch.npu.set_stream(s1)
        tmp = torch.add(tmp, 1)
        current_s1 = torch.npu.current_stream()
        torch.npu.set_stream(s2)
        tmp = torch.add(tmp, 2)
        torch.npu.set_stream(s3)
        tmp = torch.add(tmp, 3)
        torch.npu.set_stream(current_s1)
        r = torch.add(tmp, 4)
        torch.npu.set_stream(default_stream)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass
        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')
        fx_target_list = []
        fx_node_name_list = []
        for node in gm.graph.nodes:
            fx_node_name_list.append(node.name)
            # Use the op's __name__ when it has one, else the raw target.
            if hasattr(node.target, "__name__"):
                fx_target_list.append(node.target.__name__)
            else:
                fx_target_list.append(node.target)
        self.assertEqual(fx_target_list.count('scope_enter'), fx_target_list.count('scope_exit'))
        # Switching back to s1 (set_stream_3) must be followed by two scope exits.
        self.assertEqual(fx_node_name_list[13], 'set_stream_3')
        self.assertEqual(fx_target_list[13], 'set_stream')
        self.assertEqual(fx_target_list[14], 'scope_exit')
        self.assertEqual(fx_node_name_list[14], 'scope_exit_1')
        self.assertEqual(fx_target_list[15], 'scope_exit')
        self.assertEqual(fx_node_name_list[15], 'scope_exit')
        # The second-to-last node closes the remaining scope.
        self.assertEqual(fx_node_name_list[-2], 'scope_exit_2')
        self.assertEqual(fx_target_list[-2], 'scope_exit')
        return gm

    opt_m = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    i = torch.randn([3, 3]).to('npu')
    opt_m(i)
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_event_replace_without_set_default(self):
    """A graph that ends on a different stream than it started on must be rejected with RuntimeError."""
    def cus_func(t):
        s1 = torch.npu.Stream()
        tmp = torch.add(t, 0)
        torch.npu.set_stream(s1)
        # Deliberately never switch back to the original stream.
        return torch.add(tmp, 1)

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass
        return replace_stream_event_pass(gm)

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    sample = torch.randn([3, 3]).to('npu')
    with self.assertRaises(RuntimeError) as raised:
        compiled(sample)
    expected_message = (
        "When use npugraph_ex, you must make sure at the end of your code set stream to the same stream "
        "as the begin of your code"
    )
    self.assertIn(expected_message, str(raised.exception))
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_event_wait_with_stream(self):
    """Test event.wait(stream) is replaced with tagged_event_wait_on_stream"""
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        e1 = torch.npu.Event()
        torch.npu.set_stream(s1)
        e1.record()
        tmp = torch.add(t, 1)
        torch.npu.set_stream(default_stream)
        e1.wait(s1)
        return torch.add(tmp, 2)

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass
        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        def target_name(node):
            return node.target.__name__ if hasattr(node.target, "__name__") else str(node.target)

        matching_args = [
            node.args for node in gm.graph.nodes if target_name(node) == 'tagged_event_wait_on_stream'
        ]
        self.assertEqual(len(matching_args), 1, "Should have exactly one tagged_event_wait_on_stream node")
        # Expected positional args: (event_tag, stream_id, device_index, device_type, created_inside).
        args = matching_args[0]
        self.assertTrue(args[0].startswith('graph_'), f"event_tag should start with 'graph_', got {args[0]}")
        self.assertIsInstance(args[1], str, "stream_id should be str")
        self.assertIsInstance(args[2], str, "device_index should be str")
        self.assertIsInstance(args[3], str, "device_type should be str")
        self.assertEqual(args[4], True, "created_inside should be True")
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    sample = torch.randn([3, 3]).to('npu')
    compiled(sample)
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_event_record_with_stream(self):
    """``event.record(stream)`` must be rewritten to exactly one ``tagged_event_record_on_stream`` call.

    The event is recorded with an explicit stream argument (``s1``) after the
    current stream has already been switched back. The custom backend runs
    ``replace_stream_event_pass`` and checks the rewritten node's arguments.
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        e1 = torch.npu.Event()
        torch.npu.set_stream(s1)
        tmp = torch.add(t, 1)
        torch.npu.set_stream(default_stream)
        e1.record(s1)
        r = torch.add(tmp, 2)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass

        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        record_calls = []
        for node in gm.graph.nodes:
            target = node.target
            # Targets with a __name__ report it; anything else falls back to str().
            name = target.__name__ if hasattr(target, "__name__") else str(target)
            if name == 'tagged_event_record_on_stream':
                record_calls.append(node.args)
        self.assertEqual(
            len(record_calls), 1, "Should have exactly one tagged_event_record_on_stream node"
        )

        call_args = record_calls[0]
        self.assertTrue(
            call_args[0].startswith('graph_'), f"event_tag should start with 'graph_', got {call_args[0]}"
        )
        for pos, msg in (
            (1, "stream_id should be str"),
            (2, "device_index should be str"),
            (3, "device_type should be str"),
        ):
            self.assertIsInstance(call_args[pos], str, msg)
        self.assertEqual(call_args[4], True, "created_inside should be True")
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    compiled(torch.randn([3, 3]).to('npu'))
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_wait_event(self):
    """``stream.wait_event(event)`` must be rewritten to exactly one ``tagged_event_wait_on_stream`` call.

    ``e1`` is recorded on ``s1``; then ``s2`` becomes current and waits on
    ``e1`` before the second add. The custom backend applies
    ``replace_stream_event_pass`` and validates the rewritten node's arguments.
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        s2 = torch.npu.Stream()
        e1 = torch.npu.Event()
        torch.npu.set_stream(s1)
        e1.record()
        tmp = torch.add(t, 1)
        torch.npu.set_stream(s2)
        s2.wait_event(e1)
        r = torch.add(tmp, 2)
        torch.npu.set_stream(default_stream)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass

        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        def target_name_of(node):
            # Targets with a __name__ report it; anything else falls back to str().
            return node.target.__name__ if hasattr(node.target, "__name__") else str(node.target)

        wait_nodes_args = [
            node.args for node in gm.graph.nodes if target_name_of(node) == 'tagged_event_wait_on_stream'
        ]
        self.assertEqual(len(wait_nodes_args), 1, "Should have exactly one tagged_event_wait_on_stream node")

        wait_args = wait_nodes_args[0]
        event_tag = wait_args[0]
        self.assertTrue(event_tag.startswith('graph_'), f"event_tag should start with 'graph_', got {wait_args[0]}")
        self.assertIsInstance(wait_args[1], str, "stream_id should be str")
        self.assertIsInstance(wait_args[2], str, "device_index should be str")
        self.assertIsInstance(wait_args[3], str, "device_type should be str")
        self.assertEqual(wait_args[4], True, "created_inside should be True")
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    sample = torch.randn([3, 3]).to('npu')
    compiled(sample)
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_wait_stream(self):
    """``stream.wait_stream(other)`` must be rewritten to exactly one ``tagged_stream_wait_stream`` call.

    After computing on ``s1``, stream ``s2`` becomes current and waits on ``s1``.
    The rewritten node is expected to carry six string arguments describing
    both streams (id, device index and device type for each).
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        s2 = torch.npu.Stream()
        torch.npu.set_stream(s1)
        tmp = torch.add(t, 1)
        torch.npu.set_stream(s2)
        s2.wait_stream(s1)
        r = torch.add(tmp, 2)
        torch.npu.set_stream(default_stream)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass

        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        def target_name_of(node):
            # Targets with a __name__ report it; anything else falls back to str().
            return node.target.__name__ if hasattr(node.target, "__name__") else str(node.target)

        sync_nodes_args = [
            node.args for node in gm.graph.nodes if target_name_of(node) == 'tagged_stream_wait_stream'
        ]
        self.assertEqual(len(sync_nodes_args), 1, "Should have exactly one tagged_stream_wait_stream node")

        sync_args = sync_nodes_args[0]
        # Every positional argument, in order, must be a string.
        field_messages = (
            "stream_id should be str",
            "device_index should be str",
            "device_type should be str",
            "other_stream_id should be str",
            "other_device_index should be str",
            "other_device_type should be str",
        )
        for pos, msg in enumerate(field_messages):
            self.assertIsInstance(sync_args[pos], str, msg)
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    compiled(torch.randn([3, 3]).to('npu'))
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_record_event_with_event(self):
    """``stream.record_event(event)`` must be rewritten to exactly one ``tagged_event_record_on_stream`` call.

    The caller supplies its own event ``e1`` to ``s1.record_event``. The custom
    backend runs ``replace_stream_event_pass`` and validates the rewritten node.
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        e1 = torch.npu.Event()
        torch.npu.set_stream(s1)
        s1.record_event(e1)
        tmp = torch.add(t, 1)
        torch.npu.set_stream(default_stream)
        r = torch.add(tmp, 2)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass

        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        record_calls = []
        for node in gm.graph.nodes:
            target = node.target
            # Targets with a __name__ report it; anything else falls back to str().
            name = target.__name__ if hasattr(target, "__name__") else str(target)
            if name == 'tagged_event_record_on_stream':
                record_calls.append(node.args)
        self.assertEqual(
            len(record_calls), 1, "Should have exactly one tagged_event_record_on_stream node"
        )

        call_args = record_calls[0]
        self.assertTrue(
            call_args[0].startswith('graph_'), f"event_tag should start with 'graph_', got {call_args[0]}"
        )
        for pos, msg in (
            (1, "stream_id should be str"),
            (2, "device_index should be str"),
            (3, "device_type should be str"),
        ):
            self.assertIsInstance(call_args[pos], str, msg)
        self.assertEqual(call_args[4], True, "created_inside should be True")
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    compiled(torch.randn([3, 3]).to('npu'))
@unittest.skipIf(torch.__version__ < "2.6", "torch_npu stream api is unsupported when torch < 2.6")
def test_stream_record_event_without_event(self):
    """``stream.record_event()`` with no event must still yield one ``tagged_event_record_on_stream`` call.

    No event is passed, so the pass has to create one itself; its tag is
    expected to start with ``graph_`` and contain ``_record_event``.
    """
    def cus_func(t):
        default_stream = torch.npu.current_stream()
        s1 = torch.npu.Stream()
        torch.npu.set_stream(s1)
        event = s1.record_event()
        tmp = torch.add(t, 1)
        torch.npu.set_stream(default_stream)
        r = torch.add(tmp, 2)
        return r

    def my_backend(gm: torch.fx.GraphModule, example_inputs):
        from torchair._acl_concrete_graph.replace_stream_event import replace_stream_event_pass

        gm = replace_stream_event_pass(gm)
        print(f'after replace graph is : {gm.graph}')

        def target_name_of(node):
            # Targets with a __name__ report it; anything else falls back to str().
            return node.target.__name__ if hasattr(node.target, "__name__") else str(node.target)

        record_calls = [
            node.args for node in gm.graph.nodes if target_name_of(node) == 'tagged_event_record_on_stream'
        ]
        self.assertEqual(
            len(record_calls), 1, "Should have exactly one tagged_event_record_on_stream node"
        )

        call_args = record_calls[0]
        event_tag = call_args[0]
        self.assertTrue(event_tag.startswith('graph_'), f"event_tag should start with 'graph_', got {call_args[0]}")
        self.assertIn('_record_event', event_tag, f"event_tag should contain '_record_event', got {call_args[0]}")
        for pos, msg in (
            (1, "stream_id should be str"),
            (2, "device_index should be str"),
            (3, "device_type should be str"),
        ):
            self.assertIsInstance(call_args[pos], str, msg)
        self.assertEqual(call_args[4], True, "created_inside should be True")
        return gm

    compiled = torch.compile(cus_func, backend=my_backend, fullgraph=True, dynamic=False)
    compiled(torch.randn([3, 3]).to('npu'))
@unittest.skipIf(True, "unsupported until cann support")
def test_aclgraph_with_superkernel(self):
    """An add wrapped in a super-kernel scope must match eager results under AclGraph.

    Compiles with ``reduce-overhead`` mode plus the experimental
    ``_aclnn_static_shape_kernel`` and ``_super_kernel_optimize`` flags, then
    compares the compiled output with eager ``torch.add`` (rtol/atol 1e-3).
    Currently always skipped.
    """
    class Module(torch.nn.Module):
        # No custom __init__ needed: the base constructor is sufficient.
        def forward(self, x, y):
            torch.npu.super_kernel_scope_begin("sk1")
            z = torch.add(x, y)
            torch.npu.super_kernel_scope_end("sk1")
            return z

    npu_config = torchair.CompilerConfig()
    npu_config.mode = "reduce-overhead"
    aclgraph_opts = npu_config.experimental_config.aclgraph
    aclgraph_opts._aclnn_static_shape_kernel = True
    aclgraph_opts._super_kernel_optimize = True
    npu_backend = torchair.get_npu_backend(compiler_config=npu_config)
    compiled_model = torch.compile(Module(), backend=npu_backend)

    lhs = torch.randn(64, 4, 512, dtype=torch.float16, device='npu')
    rhs = torch.randn(64, 512, 128, dtype=torch.float16, device='npu')
    actual = compiled_model(lhs, rhs)
    reference = torch.add(lhs, rhs)
    self.assertTrue(torch.allclose(actual, reference, rtol=1e-3, atol=1e-3))
def test_inherited_global_limit_core(self):
    """Changing a stream's core limit between calls must force an AclGraph recapture.

    The compiled model is invoked twice with different ``set_stream_limit``
    values in two scenarios:

    1. on the current stream;
    2. on a newly created stream made current via ``torch.npu.stream``.

    In each scenario the torchair logger must emit the DEBUG message
    "The current AclGraph needs to be recaptured".
    """
    config = CompilerConfig()
    config.mode = "reduce-overhead"
    aclgraph_backend = torchair.get_npu_backend(compiler_config=config)

    class Model1(torch.nn.Module):
        def forward(self, in1, in2, in3, in4):
            add_result = torch.add(in1, in2)
            mm_result = torch.mm(in3, in4)
            return mm_result, add_result

    recapture_log = "The current AclGraph needs to be recaptured"

    def assert_recaptured(log_lines):
        # Shared check for both scenarios; failure message now has proper spacing.
        self.assertTrue(
            any(recapture_log in line for line in log_lines),
            f"Expected DEBUG '{recapture_log}' not found in logs: {log_lines}",
        )

    x = torch.randn([3, 3], device='npu')
    y = torch.randn([3, 3], device='npu')
    z = torch.randn([3, 3], device='npu')
    w = torch.randn([3, 3], device='npu')
    model1 = torch.compile(Model1(), backend=aclgraph_backend, fullgraph=True, dynamic=False)

    # Scenario 1: limits applied to the stream that is already current.
    # NOTE(review): the (4, 5) limit set here is never reset by this test;
    # confirm it cannot leak into later tests that run on the same stream.
    with self.assertLogs(logger, level="DEBUG") as cm:
        current_stream = torch.npu.current_stream()
        torch.npu.set_stream_limit(current_stream, 2, 3)
        model1(x, y, z, w)
        torch.npu.set_stream_limit(current_stream, 4, 5)
        model1(x, y, z, w)
    assert_recaptured(cm.output)

    # Scenario 2: limits applied to a fresh stream entered through the stream context manager.
    with self.assertLogs(logger, level="DEBUG") as cm:
        stream = torch.npu.Stream()
        with torch.npu.stream(stream):
            torch.npu.set_stream_limit(stream, 2, 3)
            model1(x, y, z, w)
            torch.npu.set_stream_limit(stream, 4, 5)
            model1(x, y, z, w)
    assert_recaptured(cm.output)
def test_clone_output_disable_mem_reuse(self):
    """Output cloning plus disabled in-FX mempool reuse must give a fresh output buffer per call.

    Two calls on the same input must return tensors with different underlying
    storages, and both must equal eager ``torch.add(x, x)``.
    """
    def f(x):
        return torch.add(x, x)

    config = CompilerConfig()
    config.mode = "reduce-overhead"
    aclgraph_debug = config.debug.aclgraph
    aclgraph_debug.disable_mempool_reuse_in_same_fx = True
    aclgraph_debug.enable_output_clone = True
    model = torch.compile(
        f,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=False,
        fullgraph=True,
    )

    x = torch.randn(4, 4, dtype=torch.float32, device='npu')
    expected = torch.add(x, x)
    first, second = model(x), model(x)
    self.assertNotEqual(first.untyped_storage()._cdata, second.untyped_storage()._cdata)
    for out in (first, second):
        self.assertTrue(torch.allclose(out, expected))
def test_clone_output_single_graph(self):
    """With output cloning enabled, repeated calls must not hand back the same output storage.

    Two calls on the same input must return tensors with different storages,
    and both must equal eager ``torch.add(x, x)``.
    """
    def f(x):
        return torch.add(x, x)

    config = CompilerConfig()
    config.mode = "reduce-overhead"
    config.debug.aclgraph.enable_output_clone = True
    model = torch.compile(
        f,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=False,
        fullgraph=True,
    )

    x = torch.randn(4, 4, dtype=torch.float32, device='npu')
    expected = torch.add(x, x)
    first, second = model(x), model(x)
    self.assertNotEqual(first.untyped_storage()._cdata, second.untyped_storage()._cdata)
    for out in (first, second):
        self.assertTrue(torch.allclose(out, expected))
def test_clone_output_disabled(self):
    """With output cloning disabled, repeated calls return outputs sharing one storage.

    Both outputs must still equal eager ``torch.add(x, x)``. The input is the
    same for both calls, so the shared buffer holds the same values.
    """
    def f(x):
        return torch.add(x, x)

    config = CompilerConfig()
    config.mode = "reduce-overhead"
    config.debug.aclgraph.enable_output_clone = False
    model = torch.compile(
        f,
        backend=torchair.get_npu_backend(compiler_config=config),
        dynamic=False,
        fullgraph=True,
    )

    x = torch.randn(4, 4, dtype=torch.float32, device='npu')
    expected = torch.add(x, x)
    first, second = model(x), model(x)
    self.assertEqual(first.untyped_storage()._cdata, second.untyped_storage()._cdata)
    for out in (first, second):
        self.assertTrue(torch.allclose(out, expected))
def patch_dynamo():
    """Monkey-patch TorchDynamo internals for NPU stream/event tracing.

    The patches are installed in this order:

    1. ``UserDefinedClassVariable._in_graph_classes`` is replaced by a cached
       wrapper that adds ``torch.npu.Event`` and ``torch.npu.Stream`` to the set
       returned by the original implementation (that set is mutated in place).
    2. ``torch.npu.fake_record_stream`` is installed, and
       ``TensorVariable.method_record_stream`` is set to call it through a
       ``TorchInGraphFunctionVariable``.
    3. ``VariableBuilder._wrap`` handles ``torch.npu.Event`` values itself:
       - it installs an ``ID_MATCH`` guard;
       - it stores a weak reference to the event;
       - it emits a ``get_user_object_from_id`` proxy node;
       - it returns an ``EventVariable``.
       All other values go to the original ``_wrap``.
    4. ``BuiltinVariable.call_id`` returns a constant ``id(event)`` when its
       first argument is exactly an ``EventVariable``; other calls are delegated.

    Nothing here checks whether the patches are already installed. In this
    file, ``setUp`` only calls it while ``torch.npu`` lacks ``fake_record_stream``.
    """
    import functools

    from torch._dynamo.variables.user_defined import UserDefinedClassVariable

    dynamo_vars = torch._dynamo.variables

    # 1. Report torch.npu.Event / torch.npu.Stream as in-graph classes.
    base_in_graph_classes = UserDefinedClassVariable._in_graph_classes

    @staticmethod
    @functools.lru_cache(None)
    def patched_in_graph_classes():
        classes = base_in_graph_classes()
        classes.add(torch.npu.Event)
        classes.add(torch.npu.Stream)
        return classes

    UserDefinedClassVariable._in_graph_classes = patched_in_graph_classes

    # 2. Route Tensor.record_stream to a traceable placeholder.
    def fake_record_stream(self, s):
        """Stand-in that dynamo traces in place of ``Tensor.record_stream``.

        It does nothing when ``self`` is a FakeTensor. A compile backend is
        meant to replace calls to it with a real implementation in the FX
        graph; if it ever runs on a real tensor it raises ``RuntimeError``.
        """
        if not isinstance(self, torch._subclasses.fake_tensor.FakeTensor):
            raise RuntimeError(
                "tensor.record_stream is not supported on torch.compile! "
                "You should write a pass to replace torch.npu.fake_record_stream to an actual function in FX graph "
                "before aot_autograd."
            )

    torch.npu.fake_record_stream = fake_record_stream

    def method_record_stream(self, s):
        tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx()
        # Looked up at call time, so whatever torch.npu.fake_record_stream is then gets traced.
        placeholder_var = dynamo_vars.TorchInGraphFunctionVariable(torch.npu.fake_record_stream)
        return placeholder_var.call_function(tx, [self, s], {})

    dynamo_vars.tensor.TensorVariable.method_record_stream = method_record_stream

    # 3. Wrap torch.npu.Event objects as EventVariable backed by an id lookup node.
    builder_cls = dynamo_vars.builder.VariableBuilder
    base_wrap = builder_cls._wrap

    def _patch_wrapper(self, value):
        if not isinstance(value, torch.npu.Event):
            return base_wrap(self, value)
        self.install_guards(torch._dynamo.guards.GuardBuilder.ID_MATCH)
        torch._dynamo.utils.store_user_object_weakref(value)
        event_proxy = self.tx.output.create_proxy(
            "call_function",
            torch._dynamo.utils.get_user_object_from_id,
            (id(value),),
            {},
        )
        torch._dynamo.utils.set_example_value(event_proxy.node, value)
        return dynamo_vars.ctx_manager.EventVariable(event_proxy, value, source=self.source)

    builder_cls._wrap = _patch_wrapper

    # 4. Constant-fold id(<EventVariable>).
    builtin_cls = dynamo_vars.builtin.BuiltinVariable
    base_call_id = builtin_cls.call_id

    def _wrap_call_id(self, tx, *args):
        first = args[0]
        if dynamo_vars.builtin.istype(first, dynamo_vars.ctx_manager.EventVariable):
            return dynamo_vars.ConstantVariable.create(id(first.value))
        return base_call_id(self, tx, *args)

    builtin_cls.call_id = _wrap_call_id
if __name__ == '__main__':
    # Executed directly (not imported): hand control to unittest's command-line runner.
    unittest.main()