import logging
import os
import unittest

# TNG_LOG_LEVEL is assigned before torch/torchair are imported — presumably so the
# setting is visible at import time. Keep this line ahead of those imports.
os.environ["TNG_LOG_LEVEL"] = "0"

import torch
import torchair
from torchair.ge._ge_graph import GeGraph
from torchair._ge_concrete_graph.fx2ge_converter import GeConcreteGraph
from torchair.core.utils import logger

# Enable DEBUG-level output from torchair's logger.
logger.setLevel(logging.DEBUG)

# One compiler config and NPU backend, shared by every test below.
config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)

# "npu" library fragment; individual tests append their own op schemas to it.
m = torch.library.Library("npu", "FRAGMENT")
m.define("my_inplace_auto_kwargs_str(Tensor(a!) x, Tensor y, *, str alpha='alpha') -> ()")


@torch.library.impl(m, "my_inplace_auto_kwargs_str", "Meta")
def my_inplace_meta(x, y, alpha='alpha'):
    """Meta kernel for ``npu::my_inplace_auto_kwargs_str``.

    The schema declares no outputs (``-> ()``), so the kernel produces nothing.
    """
    return None
class TestCustomOps(unittest.TestCase):
    """GE-graph structure tests for in-place custom NPU ops lowered via auto_functionalize.

    Each test compiles a small function with the torchair NPU backend. It then
    intercepts ``GeConcreteGraph.__call__`` to inspect the generated GE graph
    (op names and their input edges) before the original call runs. The
    unpatched ``__call__`` is saved in ``setUp`` and restored in ``tearDown``, so
    the patch never leaks into other tests.
    """

    def setUp(self) -> None:
        # Save the unpatched __call__; tests replace it via _install_graph_check.
        self.call_bak = GeConcreteGraph.__call__
        return super().setUp()

    def tearDown(self) -> None:
        # Undo whatever monkey-patching the test performed.
        GeConcreteGraph.__call__ = self.call_bak
        return super().tearDown()

    @staticmethod
    def _op_inputs(ge_graph):
        """Return a dict mapping each op name in ``ge_graph`` to that op's input list."""
        return {op_node.name: op_node.input for op_node in ge_graph.op}

    def _install_graph_check(self, check):
        """Patch ``GeConcreteGraph.__call__`` so ``check(ge_graph)`` runs before each call.

        ``check`` receives the ``GeGraph`` stored on the concrete graph's ``_graph``
        attribute. Any assertion it raises aborts the call. The original
        ``__call__`` is invoked afterwards with unchanged arguments, and its
        result is returned. ``tearDown`` reverts the patch.
        """
        original_call = GeConcreteGraph.__call__

        def wrapper(*args, **kwargs):
            # args[0] is the GeConcreteGraph instance the call is bound to.
            assert len(args) > 0
            ge_graph: GeGraph = args[0]._graph
            check(ge_graph)
            return original_call(*args, **kwargs)

        GeConcreteGraph.__call__ = wrapper

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_multi_ops_in_converter(self):
        m.define("multi_ops_in_converter(Tensor(a!) x, Tensor y) -> Tensor")

        @torch.library.impl(m, "multi_ops_in_converter", "Meta")
        def my_inplace_meta(x, y):
            return torch.empty_like(y)

        @torchair.register_fx_node_ge_converter(torch.ops.npu.multi_ops_in_converter.default)
        def converter_multi_ops_in_converter(x, y, meta_outputs=None):
            # The converter emits two chained GE ops: MultiInplace1's outputs feed MultiInplace2.
            tmp = torchair.ge.custom_op(
                "MultiInplace1",
                inputs={
                    "x": x,
                    "y": y,
                },
                outputs=['x', 'z']
            )
            out = torchair.ge.custom_op(
                "MultiInplace2",
                inputs={
                    "x": tmp[0],
                    "p": tmp[1],
                },
                outputs=['x', 'q']
            )
            return out[1]

        def cus_func(x, y):
            return torch.ops.npu.multi_ops_in_converter(x, y)

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            # Check presence first so a missing op fails with a clear message, not a KeyError.
            for op_name in ("TensorMove", "MultiInplace1", "MultiInplace2", "Identity", "NetOutput"):
                self.assertIn(op_name, op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MultiInplace1"])
            self.assertIn("MultiInplace1:0", op_inputs["MultiInplace2"])
            self.assertIn("MultiInplace2:-1", op_inputs["Identity"])
            self.assertIn("TensorMove:0", op_inputs["Identity"])
            self.assertIn("Identity:0", op_inputs["NetOutput"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        with self.assertRaises(RuntimeError) as context:
            compile_func(input1, input2)
        self.assertIn("Assert outputs.empty()", str(context.exception))

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_as_stride(self):
        m.define("my_inplace_auto2(Tensor(a!) x, Tensor y) -> Tensor")

        @torch.library.impl(m, "my_inplace_auto2", "Meta")
        def my_inplace_meta(x, y):
            return torch.empty_like(y)

        @torchair.register_fx_node_ge_converter(torch.ops.npu.my_inplace_auto2.default)
        def converter_npu_add_custom(x, y, meta_outputs=None):
            out = torchair.ge.custom_op(
                "MyInplaceAuto2",
                inputs={
                    "x": x,
                    "y": y,
                },
                outputs=['x', 'z']
            )
            return out[1]

        def cus_func(x, y):
            add0 = torch.add(x, 1)
            sliced = add0[:, 1:]  # the in-place op mutates a view of add0
            o2 = torch.ops.npu.my_inplace_auto2(sliced, y)
            add1 = torch.add(sliced, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            for op_name in ("TensorMove", "AsStrided", "MyInplaceAuto2", "ViewCopy", "StridedSliceV2"):
                self.assertIn(op_name, op_inputs)
            self.assertIn("TensorMove:0", op_inputs["AsStrided"])
            self.assertIn("MyInplaceAuto2:-1", op_inputs["ViewCopy"])
            self.assertIn("TensorMove:0", op_inputs["ViewCopy"])
            self.assertIn("ViewCopy:0", op_inputs["StridedSliceV2"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_not_view(self):
        m.define("my_inplace_auto1(Tensor(a!) x, Tensor y) -> Tensor")

        @torch.library.impl(m, "my_inplace_auto1", "Meta")
        def my_inplace_meta(x, y):
            return torch.empty_like(y)

        @torchair.register_fx_node_ge_converter(torch.ops.npu.my_inplace_auto1.default)
        def converter_npu_add_custom(x, y, meta_outputs=None):
            out = torchair.ge.custom_op(
                "MyInplaceAuto1",
                inputs={
                    "x": x,
                    "y": y,
                },
                outputs=['x', 'z']
            )
            return out[1]

        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto1(add0, y)
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            for op_name in ("TensorMove", "MyInplaceAuto1", "Add_1"):
                self.assertIn(op_name, op_inputs)
            self.assertIn("MyInplaceAuto1:0", op_inputs["Add_1"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_no_output(self):
        m.define("my_inplace_auto_no_output(Tensor(a!) x, Tensor y) -> ()")

        @torch.library.impl(m, "my_inplace_auto_no_output", "Meta")
        def my_inplace_meta(x, y):
            pass

        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto_no_output(add0, y)
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            self.assertIn("TensorMove", op_inputs)
            self.assertIn("MyInplaceAutoNoOutput", op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyInplaceAutoNoOutput"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_two_inplace(self):
        m.define("my_two_inplace(Tensor x, Tensor wkv, Tensor wgate, Tensor(a!) kv_state, Tensor(b!) score_state,"
                 "Tensor ape, Tensor norm_weight, Tensor rope_sin, Tensor rope_cos) -> (Tensor)")

        @torch.library.impl(m, "my_two_inplace", "Meta")
        def my_inplace_meta(x, wkv, wgate, kv_state, score_state, ape, norm_weight, rope_sin, rope_cos):
            return torch.empty_like(x)

        def cus_func(x, wkv, wgate, kv_state, score_state, ape, norm_weight, rope_sin, rope_cos):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_two_inplace(x, wkv, wgate, kv_state, score_state, ape, norm_weight, rope_sin,
                                              rope_cos)
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            # Two mutated arguments (kv_state, score_state) -> two TensorMove copies.
            for op_name in ("TensorMove", "TensorMove_1", "MyTwoInplace"):
                self.assertIn(op_name, op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyTwoInplace"])
            self.assertIn("TensorMove_1:0", op_inputs["MyTwoInplace"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        input3 = torch.ones(2, 1)
        with self.assertRaises(RuntimeError) as context:
            compile_func(input1, input1, input1, input2, input3, input1, input1, input1, input1)
        self.assertIn("Assert outputs.empty()", str(context.exception))

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_infer_symbol_with_auto_functionalize(self):
        import ast
        import json

        m.define("my_op_inplace_z(Tensor(a!) x, Tensor y) -> Tensor z")

        @torch.library.impl(m, "my_op_inplace_z", "Meta")
        def my_op_meta(x, y):
            # z has shape (2 * y.shape[0], y.shape[1] // 2).
            size_y_0 = list(y.shape)[0] * 2
            size_y_1 = list(y.shape)[1] // 2
            return torch.empty((size_y_0, size_y_1), dtype=y.dtype, device=y.device)

        def cus_func(x, y):
            return torch.ops.npu.my_op_inplace_z(x, y)

        def check(ge_graph):
            self.assertIn("MyOpInplaceZ", [op.name for op in ge_graph.op])
            # The expected rendering of the symbolic expression depends on ast.unparse availability.
            s2_out = "2 * s0" if hasattr(ast, 'unparse') else "(2*s0)"
            for op in ge_graph.op:
                if op.name != "MyOpInplaceZ":
                    continue
                inference_rule = json.loads(op.attr["_inference_rule"].s)
                shape_rule = inference_rule["shape"]
                self.assertEqual(shape_rule["inputs"][0][0], "s2")
                self.assertEqual(shape_rule["inputs"][0][1], "s3")
                self.assertEqual(shape_rule["inputs"][1][0], "s0")
                self.assertEqual(shape_rule["inputs"][1][1], "s1")
                # Output 0 is the mutated x and keeps x's symbolic shape; output 1 is z from the meta kernel.
                self.assertEqual(shape_rule["outputs"][0][0], "s2")
                self.assertEqual(shape_rule["outputs"][0][1], "s3")
                self.assertEqual(shape_rule["outputs"][1][0], s2_out)
                self.assertEqual(shape_rule["outputs"][1][1], "Floor(Div(s1, 2))")
                # Presumably GE dtype enum values for int32 (x) and float (z) — confirm against GE headers.
                self.assertEqual(inference_rule["dtype"][0], 3)
                self.assertEqual(inference_rule["dtype"][1], 0)

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=True)
        input1 = torch.ones((4, 4), dtype=torch.int32)
        input2 = torch.ones((4, 4), dtype=torch.float)
        with self.assertRaises(RuntimeError) as context:
            compile_func(input1, input2)
        self.assertIn("Assert outputs.empty()", str(context.exception))

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_kwargs_int_with_input(self):
        m.define("my_inplace_auto_kwargs_int(Tensor(a!) x, Tensor y, *, int alpha=1) -> ()")

        @torch.library.impl(m, "my_inplace_auto_kwargs_int", "Meta")
        def my_inplace_meta(x, y, alpha=1):
            pass

        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto_kwargs_int(add0, y, alpha=2)
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            self.assertIn("TensorMove", op_inputs)
            self.assertIn("MyInplaceAutoKwargsInt", op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyInplaceAutoKwargsInt"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_kwargs_str(self):
        # npu::my_inplace_auto_kwargs_str is defined at module level (shared with the next test).
        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto_kwargs_str(add0, y, alpha='beta')
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            self.assertIn("TensorMove", op_inputs)
            self.assertIn("MyInplaceAutoKwargsStr", op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyInplaceAutoKwargsStr"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_kwargs_str_without_input(self):
        # Same op as test_auto_functionalize_kwargs_str, but relying on the schema default for alpha.
        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto_kwargs_str(add0, y)
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            self.assertIn("TensorMove", op_inputs)
            self.assertIn("MyInplaceAutoKwargsStr", op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyInplaceAutoKwargsStr"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)

    @unittest.skipIf(torch.__version__ < "2.2", "torch._auto_functionalize is unsupported when torch < 2.2")
    def test_auto_functionalize_optional_input(self):
        m.define("my_inplace_auto_option_input(Tensor(a!) x, Tensor y, *, Tensor?z=None, str alpha='alpha') -> ()")

        @torch.library.impl(m, "my_inplace_auto_option_input", "Meta")
        def my_inplace_meta(x, y, z=None, alpha='alpha'):
            pass

        def cus_func(x, y):
            add0 = torch.add(x, 1)
            o2 = torch.ops.npu.my_inplace_auto_option_input(add0, y)  # optional z left as None
            add1 = torch.add(add0, 1)
            return add1, o2

        def check(ge_graph):
            op_inputs = self._op_inputs(ge_graph)
            self.assertIn("TensorMove", op_inputs)
            self.assertIn("MyInplaceAutoOptionInput", op_inputs)
            self.assertIn("TensorMove:0", op_inputs["MyInplaceAutoOptionInput"])

        self._install_graph_check(check)
        compile_func = torch.compile(cus_func, backend=npu_backend, fullgraph=True, dynamic=False)
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 1)
        compile_func(input1, input2)
if __name__ == "__main__":
    # Script entry point: discover and run every TestCase defined in this module.
    unittest.main(module="__main__")