import os
import unittest
import shutil
from pathlib import Path
from unittest.mock import patch
import torch
import torch_npu
class StaticKernelBlacklistTest(unittest.TestCase):
"""Test static kernel blacklist functionality."""
def setUp(self):
self.kernel_build_dir = "./static_kernel_blacklist_test_dir"
if os.path.exists(self.kernel_build_dir):
shutil.rmtree(self.kernel_build_dir)
os.makedirs(self.kernel_build_dir, exist_ok=True)
def tearDown(self):
if os.path.exists(self.kernel_build_dir):
shutil.rmtree(self.kernel_build_dir)
def test_static_kernel_blacklist_for_add_operator(self):
"""
End-to-end test for static kernel blacklist functionality.
This test:
1. Creates a model with add, sub, mul operators
2. Adds 'Add' to the static kernel blacklist
3. Compiles the model with static kernel enabled using npugraph_ex backend
4. Verifies that Add operator JSON is in the blacklist folder
5. Verifies that Add operator JSON is NOT in the selected folder
"""
from npugraph_ex._acl_concrete_graph.static_kernel import (
_set_static_kernel_blacklist,
compile_static_kernel as original_compile_static_kernel
)
class ModelWithAddSubMul(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(16, 16)
def forward(self, x):
add_result = torch.add(x, x)
sub_result = torch.sub(add_result, x)
mul_result = torch.mul(sub_result, 2.0)
return mul_result
_set_static_kernel_blacklist(["Add"])
kernel_build_dir = self.kernel_build_dir
def patched_compile_static_kernel(fx_func, *args, **kwargs):
kwargs['build_dir'] = kernel_build_dir
return original_compile_static_kernel(fx_func, *args, **kwargs)
try:
with patch('npugraph_ex._acl_concrete_graph.acl_graph.compile_static_kernel',
side_effect=patched_compile_static_kernel):
model = ModelWithAddSubMul().npu()
options = {"static_kernel_compile": True}
compiled_model = torch.compile(model, backend="npugraph_ex", options=options,
fullgraph=True, dynamic=False)
input_tensor = torch.randn(4, 16, dtype=torch.float16).npu()
output = compiled_model(input_tensor)
kernel_build_path = Path(self.kernel_build_dir)
self.assertTrue(kernel_build_path.exists(), f"Kernel build directory {self.kernel_build_dir} does not exist")
outputs_dirs = [d for d in kernel_build_path.iterdir()
if d.is_dir() and d.name.endswith("_outputs") and d.name.startswith("ts")]
if outputs_dirs:
outputs_dir = outputs_dirs[0]
blacklist_dirs = [d for d in outputs_dir.iterdir()
if d.is_dir() and "blacklist" in d.name]
selected_dirs = [d for d in outputs_dir.iterdir()
if d.is_dir() and "selected" in d.name]
if blacklist_dirs:
blacklist_dir = blacklist_dirs[0]
blacklist_json_files = list(blacklist_dir.glob("*.json"))
add_in_blacklist = any("Add" in f.name or "add" in f.name.lower()
for f in blacklist_json_files)
if add_in_blacklist:
print(f"[PASS] Add operator JSON found in blacklist directory: {blacklist_dir}")
for f in blacklist_json_files:
if "Add" in f.name or "add" in f.name.lower():
print(f" - Blacklisted JSON: {f.name}")
else:
print(f"[INFO] Blacklist directory exists but no Add JSON found. Files: {[f.name for f in blacklist_json_files]}")
if selected_dirs:
selected_dir = selected_dirs[0]
selected_json_files = list(selected_dir.glob("*.json"))
add_in_selected = any("Add" in f.name or "add" in f.name.lower()
for f in selected_json_files)
self.assertFalse(add_in_selected,
f"Add operator JSON should NOT be in selected directory. "
f"Found Add-related files: {[f.name for f in selected_json_files if 'Add' in f.name or 'add' in f.name.lower()]}")
print(f"[PASS] Add operator JSON NOT found in selected directory: {selected_dir}")
print(f" - Selected JSON files: {[f.name for f in selected_json_files]}")
if selected_dirs:
selected_dir = selected_dirs[0]
selected_json_files = list(selected_dir.glob("*.json"))
sub_in_selected = any("Sub" in f.name or "sub" in f.name.lower()
for f in selected_json_files)
mul_in_selected = any("Mul" in f.name or "mul" in f.name.lower()
for f in selected_json_files)
print(f"[INFO] Sub in selected: {sub_in_selected}")
print(f"[INFO] Mul in selected: {mul_in_selected}")
finally:
_set_static_kernel_blacklist([])
if __name__ == "__main__":
    # Allow running this module directly, e.g. `python <this_file>.py`.
    unittest.main()