import unittest
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
import dataflow.plugin.torch.torch_plugin as torch_plugin
import dataflow.data_type as dt
import dataflow.utils.utils as utils
import dataflow as df
class TestTorchPlugin(unittest.TestCase):
    """Unit tests for ``dataflow.plugin.torch.torch_plugin``.

    torch, torch_npu, torchair and cloudpickle are never imported for real:
    each test swaps ``MagicMock`` objects into ``sys.modules`` through
    ``patch.dict`` for the duration of the test.
    """

    # (attribute name on ``dataflow.data_type``, attribute name on ``torch``).
    # Kept as strings so the class body does not touch either module; the
    # tests resolve them with ``getattr`` once the torch mock is installed.
    _DTYPE_PAIRS = (
        ("DT_FLOAT", "float32"),
        ("DT_FLOAT16", "float16"),
        ("DT_BF16", "bfloat16"),
        ("DT_INT8", "int8"),
        ("DT_INT16", "int16"),
        ("DT_UINT16", "uint16"),
        ("DT_UINT8", "uint8"),
        ("DT_INT32", "int32"),
        ("DT_INT64", "int64"),
        ("DT_UINT32", "uint32"),
        ("DT_UINT64", "uint64"),
        ("DT_BOOL", "bool"),
        ("DT_DOUBLE", "float64"),
    )

    def setUp(self):
        """Select device 0 before every test."""
        utils.set_running_device_id(0)
        return super().setUp()

    @staticmethod
    def MockTorchType(torch):
        """Give the mocked ``torch`` module plain-string dtypes and a version.

        Each dtype attribute is set to its own name (``torch.float32`` becomes
        ``"float32"``), which makes dtype values comparable with ``==``.
        """
        for _, torch_name in TestTorchPlugin._DTYPE_PAIRS:
            setattr(torch, torch_name, torch_name)
        torch.__version__ = "2.5.0"

    # ------------------------------------------------------------------
    # Private helpers (not collected as tests: names do not start with
    # ``test``). They must be called from inside a test so that the
    # ``sys.modules`` patches of that test are active.
    # ------------------------------------------------------------------

    def _outputs_fixture(self):
        """Return ``(runtime_context, torch)`` for the ``_prepare_outputs`` tests.

        Builds a ``MetaRunContext`` wrapper, installs the string dtypes on the
        mocked torch and replaces ``torch.Tensor`` with a fresh empty class so
        that mock tensors can be made instances of it.
        """
        import dataflow.flow_func.flowfunc_wrapper as fw
        import dataflow.flow_func as ff

        runtime_context = ff.MetaRunContext(fw.MetaRunContext())

        import torch

        TestTorchPlugin.MockTorchType(torch)
        torch.Tensor = type("MockTensor", (object,), {})
        return runtime_context, torch

    @staticmethod
    def _new_mock_tensor(torch, device_type=None):
        """Return a ``MagicMock`` posing as a float32 ``torch.Tensor`` of shape [2, 3].

        ``device.type`` is only pinned when ``device_type`` is given; otherwise
        it stays an auto-created mock attribute.
        """
        tensor = MagicMock()
        tensor.__class__ = torch.Tensor
        if device_type is not None:
            tensor.device.type = device_type
        tensor.dtype = torch.float32
        tensor.shape = [2, 3]
        return tensor

    @staticmethod
    def _npu_model_env():
        """Prepare the mocked modules for ``npu_model`` tests; return ``torchair``.

        Installs the string dtypes on torch and makes ``cloudpickle.dumps``
        return a 1024-byte buffer.
        """
        import torch
        import torchair
        import cloudpickle

        TestTorchPlugin.MockTorchType(torch)
        cloudpickle.dumps = MagicMock(return_value=bytes(1024))
        return torchair

    def _assert_exported(self, torchair, workspace_dir, dynamic):
        """Check ``torchair.dynamo_export`` was called once with these kwargs."""
        torchair.dynamo_export.assert_called_once()
        _, kwargs = torchair.dynamo_export.call_args
        self.assertEqual(kwargs["export_path"], workspace_dir)
        self.assertEqual(kwargs["dynamic"], dynamic)

    def _run_dynamo_export(self, torchair, shape, dynamic):
        """Call ``_dynamo_export`` for one DT_FLOAT input of ``shape`` and verify it."""
        class_ins = MagicMock()
        class_ins._decorated_class.return_value = MagicMock()
        class_ins._saved_args = ([], {})
        input_desc = MagicMock()
        input_desc._dtype = dt.DT_FLOAT
        input_desc._shape = shape
        workspace_dir = "./test_ws"
        torch_plugin._dynamo_export(class_ins, [input_desc], workspace_dir)
        self._assert_exported(torchair, workspace_dir, dynamic)

    # ------------------------------------------------------------------
    # dtype conversion
    # ------------------------------------------------------------------

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_convert_df_to_torch_tensor_dtype(self):
        """Every dataflow dtype in the table converts to its torch dtype."""
        import torch

        TestTorchPlugin.MockTorchType(torch)
        for df_name, torch_name in self._DTYPE_PAIRS:
            # subTest keeps going after a mismatch so all bad pairs are shown.
            with self.subTest(df_dtype=df_name):
                self.assertEqual(
                    torch_plugin._convert_df_to_torch_tensor_dtype(
                        getattr(dt, df_name)
                    ),
                    getattr(torch, torch_name),
                )

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_convert_torch_to_df_tensor_dtype(self):
        """Every torch dtype in the table converts to its dataflow dtype."""
        import torch

        TestTorchPlugin.MockTorchType(torch)
        for df_name, torch_name in self._DTYPE_PAIRS:
            with self.subTest(torch_dtype=torch_name):
                self.assertEqual(
                    torch_plugin._convert_torch_to_df_tensor_dtype(
                        getattr(torch, torch_name)
                    ),
                    getattr(dt, df_name),
                )

    # ------------------------------------------------------------------
    # _prepare_inputs / _prepare_outputs
    # ------------------------------------------------------------------

    @patch.dict(
        "sys.modules",
        {"torch": MagicMock(), "torch_npu": MagicMock(), "torchair": MagicMock()},
    )
    def test_prepare_inputs(self):
        """One allocated FLOAT16 message yields success and one prepared input."""
        import torch
        import torchair

        TestTorchPlugin.MockTorchType(torch)
        torchair.llm_datadist.create_npu_tensors.return_value = [MagicMock()]

        import dataflow.flow_func.flowfunc_wrapper as fw
        import dataflow.flow_func as ff

        context = ff.MetaRunContext(fw.MetaRunContext())
        msg = context.alloc_tensor_msg([2, 3], dt.DT_FLOAT16)
        ret, input_list = torch_plugin._prepare_inputs([msg._flow_msg], 1)
        self.assertEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)
        self.assertEqual(len(input_list), 1)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_prepare_outputs_invalid(self):
        """A single tensor is rejected when two outputs are expected."""
        runtime_context, torch = self._outputs_fixture()
        ret, _ = torch_plugin._prepare_outputs(
            runtime_context, self._new_mock_tensor(torch), 2
        )
        self.assertNotEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_prepare_outputs_notmatch(self):
        """A two-tensor tuple is rejected when three outputs are expected."""
        runtime_context, torch = self._outputs_fixture()
        tensor = self._new_mock_tensor(torch)
        ret, _ = torch_plugin._prepare_outputs(runtime_context, (tensor, tensor), 3)
        self.assertNotEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_prepare_output_not_torch_tensor(self):
        """An output that is not a ``torch.Tensor`` (here the int 1) is rejected."""
        runtime_context, _ = self._outputs_fixture()
        ret, _ = torch_plugin._prepare_outputs(runtime_context, 1, 1)
        self.assertNotEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_prepare_outputs_single(self):
        """One "npu" tensor succeeds with one descriptor; a "cpu" tensor fails."""
        runtime_context, torch = self._outputs_fixture()
        tensor = self._new_mock_tensor(torch, device_type="npu")
        output_num = 1

        ret, runtime_tensor_descs = torch_plugin._prepare_outputs(
            runtime_context, tensor, output_num
        )
        self.assertEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)
        self.assertEqual(len(runtime_tensor_descs), 1)

        # Same tensor, now reporting a cpu device.
        tensor.device.type = "cpu"
        ret, _ = torch_plugin._prepare_outputs(runtime_context, tensor, output_num)
        self.assertNotEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_prepare_outputs_multi(self):
        """Three "npu" tensors succeed; the same tensors on "cpu" fail."""
        runtime_context, torch = self._outputs_fixture()
        tensor = self._new_mock_tensor(torch, device_type="npu")
        output_num = 3

        ret, runtime_tensor_descs = torch_plugin._prepare_outputs(
            runtime_context, (tensor, tensor, tensor), output_num
        )
        self.assertEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)
        # NOTE(review): three tensors are passed but exactly one descriptor is
        # expected back -- confirm this matches the contract of
        # _prepare_outputs rather than being a copy of the single-output test.
        self.assertEqual(len(runtime_tensor_descs), 1)

        tensor.device.type = "cpu"
        ret, _ = torch_plugin._prepare_outputs(
            runtime_context, (tensor, tensor, tensor), output_num
        )
        self.assertNotEqual(ret, torch_plugin.ff.FLOW_FUNC_SUCCESS)

    # ------------------------------------------------------------------
    # _dynamo_export
    # ------------------------------------------------------------------

    @patch.dict(
        "sys.modules",
        {"torch": MagicMock(), "torch_npu": MagicMock(), "torchair": MagicMock()},
    )
    def test_dynamo_export_static(self):
        """A shape with only non-negative dims is exported with dynamic=False."""
        import torch
        import torchair

        TestTorchPlugin.MockTorchType(torch)
        self._run_dynamo_export(torchair, [1, 2, 3], False)

    @patch.dict(
        "sys.modules",
        {"torch": MagicMock(), "torch_npu": MagicMock(), "torchair": MagicMock()},
    )
    def test_dynamo_export_dynamic(self):
        """A shape containing a negative dim is exported with dynamic=True."""
        import torchair

        # NOTE(review): unlike the static variant, this test does not install
        # the string dtypes on the torch mock -- confirm that is intentional.
        self._run_dynamo_export(torchair, [1, -2, 3], True)

    # ------------------------------------------------------------------
    # (de)serialization
    # ------------------------------------------------------------------

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_serialize_with_torch_tensor(self):
        """Serializing a mocked one-element float32 tensor returns ``bytes``."""
        import torch

        TestTorchPlugin.MockTorchType(torch)
        torch_tensor = MagicMock()
        torch_tensor.device = torch.device("cpu")
        torch_tensor.is_contiguous.return_value = True
        torch_tensor.numpy.return_value = b"mock_data"
        torch_tensor.data_ptr.return_value = 0
        torch_tensor.shape = [1]
        torch_tensor.dtype = torch.float32
        torch_tensor.numel.return_value = 1
        torch_tensor.element_size.return_value = 4
        serialized = torch_plugin._serialize_with_torch_tensor(torch_tensor)
        self.assertIsInstance(serialized, bytes)

    @patch.dict("sys.modules", {"torch": MagicMock()})
    def test_deserialize_with_torch_tensor(self):
        """Deserializing a buffer calls ``torch.frombuffer`` exactly once."""
        import torch

        torch_plugin._deserialize_with_torch_tensor(bytes(1024))
        torch.frombuffer.assert_called_once()

    # ------------------------------------------------------------------
    # npu_model decorator, end to end with mocked backends
    # ------------------------------------------------------------------

    @patch.dict(
        "sys.modules",
        {
            "torch": MagicMock(),
            "torch_npu": MagicMock(),
            "torchair": MagicMock(),
            "cloudpickle": MagicMock(),
        },
    )
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_func_with_npu_model(self, mock_msg_type_register):
        """A function decorated with num_returns=2 builds a two-output node."""
        self._npu_model_env()

        @torch_plugin.npu_model(num_returns=2)
        def func_with_npu_model(in0, in1):
            return in0, in1

        data0 = df.FlowData(name="d0")
        data1 = df.FlowData(name="d1")
        fnode = func_with_npu_model.fnode()
        out = fnode(data0, data1)
        self.assertEqual(len(out), 2)
        # The workspace directory is expected in the current working
        # directory; this test does not remove it afterwards.
        self.assertTrue(Path("./func_with_npu_model_fnode_ws").exists())

    @patch.dict(
        "sys.modules",
        {
            "torch": MagicMock(),
            "torch_npu": MagicMock(),
            "torchair": MagicMock(),
            "cloudpickle": MagicMock(),
        },
    )
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_class_with_npu_model_1(self, mock_msg_type_register):
        """A class decorated with optimize_level=1 builds a node with a FlowOutput."""
        self._npu_model_env()

        @torch_plugin.npu_model(optimize_level=1)
        class TestNpuModel:
            @df.method()
            def sub_1(self, in0, in1):
                return in0 + in1

        data0 = df.FlowData(name="d0")
        data1 = df.FlowData(name="d1")
        fnode = TestNpuModel.fnode()
        out = fnode(data0, data1)
        self.assertIsInstance(out, df.dataflow.FlowOutput)
        # As above, the workspace directory is left in place after the test.
        self.assertTrue(Path("./TestNpuModel_fnode_ws").exists())

    @patch.dict(
        "sys.modules",
        {
            "torch": MagicMock(),
            "torch_npu": MagicMock(),
            "torchair": MagicMock(),
            "cloudpickle": MagicMock(),
        },
    )
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_class_with_npu_model_2_dynamic(self, mock_msg_type_register):
        """optimize_level=2 with [-1, 1024] inputs exports with dynamic=True."""
        torchair = self._npu_model_env()

        @torch_plugin.npu_model(
            optimize_level=2,
            input_descs=[
                df.TensorDesc(dt.DT_FLOAT, [-1, 1024]),
                df.TensorDesc(dt.DT_FLOAT, [-1, 1024]),
            ],
        )
        class TestNpuModelLevel2Dynamic:
            def __init__(self):
                pass

            @df.method()
            def forward(self, in0, in1):
                return in0

        data0 = df.FlowData(name="d0")
        data1 = df.FlowData(name="d1")
        fnode = TestNpuModelLevel2Dynamic.fnode()
        out = fnode(data0, data1)
        self.assertIsInstance(out, df.dataflow.FlowOutput)
        self._assert_exported(
            torchair, "./TestNpuModelLevel2Dynamic_fnode_ws", True
        )

    @patch.dict(
        "sys.modules",
        {
            "torch": MagicMock(),
            "torch_npu": MagicMock(),
            "torchair": MagicMock(),
            "cloudpickle": MagicMock(),
        },
    )
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_class_with_npu_model_2_static(self, mock_msg_type_register):
        """optimize_level=2 with [2, 1024] inputs exports with dynamic=False."""
        torchair = self._npu_model_env()

        @torch_plugin.npu_model(
            optimize_level=2,
            input_descs=[
                df.TensorDesc(dt.DT_FLOAT, [2, 1024]),
                df.TensorDesc(dt.DT_FLOAT, [2, 1024]),
            ],
        )
        class TestNpuModelLevel2Static:
            def __init__(self):
                pass

            @df.method()
            def forward(self, in0, in1):
                return in0

        data0 = df.FlowData(name="d0")
        data1 = df.FlowData(name="d1")
        fnode = TestNpuModelLevel2Static.fnode()
        out = fnode(data0, data1)
        self.assertIsInstance(out, df.dataflow.FlowOutput)
        self._assert_exported(
            torchair, "./TestNpuModelLevel2Static_fnode_ws", False
        )

    # ------------------------------------------------------------------
    # registration and factory dispatch
    # ------------------------------------------------------------------

    @patch.dict("sys.modules", {"torch": MagicMock()})
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_register_torch_tensor(self, mock_msg_type_register):
        """``_register_torch_tensor`` registers exactly one message type."""
        torch_plugin._register_torch_tensor()
        mock_msg_type_register.register.assert_called_once()

    @patch.dict(
        "sys.modules",
        {"torch": MagicMock(), "torch_npu": MagicMock(), "torchair": MagicMock()},
    )
    @patch("dataflow.plugin.torch.torch_plugin.NpuFunctionProcessPoint")
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_make_npu_model_func(
        self, mock_msg_type_register, mock_NpuFunctionProcessPoint
    ):
        """For a plain function, the result is a ``NpuFunctionProcessPoint`` instance."""

        def func_with_invalid_output_num(in0, in1):
            return in0, in1

        result = torch_plugin._make_npu_model(func_with_invalid_output_num, {})
        self.assertEqual(result, mock_NpuFunctionProcessPoint.return_value)

    @patch.dict(
        "sys.modules",
        {"torch": MagicMock(), "torch_npu": MagicMock(), "torchair": MagicMock()},
    )
    @patch("dataflow.plugin.torch.torch_plugin.NpuActorProcessPoint")
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_make_npu_model_class(
        self, mock_msg_type_register, mock_NpuActorProcessPoint
    ):
        """For a class, the result comes from ``NpuActorProcessPoint._df_from_class``."""

        class TestNpuModel:
            def __init__(self):
                pass

        result = torch_plugin._make_npu_model(TestNpuModel, {})
        self.assertEqual(
            result, mock_NpuActorProcessPoint._df_from_class.return_value
        )

    @patch("dataflow.plugin.torch.torch_plugin._make_npu_model")
    @patch("dataflow.plugin.torch.torch_plugin.msg_type_register")
    def test_npu_model(self, mock_msg_type_register, mock_make_npu_model):
        """Using ``npu_model`` as a bare decorator calls ``_make_npu_model`` once."""

        @torch_plugin.npu_model
        def mock_function():
            pass

        mock_make_npu_model.assert_called_once()