"""
Add validation cases for torch.multiprocessing APIs on NPU:
1. test/test_multiprocessing.py from PyTorch community lacks sufficient API validations, so this file is added.
2. This file validates:
torch.multiprocessing.Array,
torch.multiprocessing.Value,
torch.multiprocessing.Manager,
torch.multiprocessing.Pipe,
torch.multiprocessing.get_start_method,
torch.multiprocessing.set_start_method,
torch.multiprocessing.reductions.init_reductions,
torch.multiprocessing.reductions.reduce_tensor
(extendable).
"""
import os
import sys
import time
import unittest
import torch
import torch.multiprocessing as mp
from torch.testing._internal.common_utils import TestCase, run_tests
import torch_npu
def _worker(queue):
queue.put(mp.get_start_method())
class TestMultiprocessingAPIs(TestCase):
def setUp(self):
self.original_start_method = mp.get_start_method()
def tearDown(self):
if self.original_start_method is not None:
mp.set_start_method(self.original_start_method, force=True)
def test_get_set_start_method(self):
"""Test get_start_method and set_start_method APIs"""
default_method = mp.get_start_method(allow_none=True)
self.assertIn(default_method, [None, 'spawn', 'fork', 'forkserver'])
for method in ['spawn', 'fork', 'forkserver']:
mp.set_start_method(method, force=True)
current_method = mp.get_start_method()
self.assertEqual(current_method, method)
queue = mp.SimpleQueue()
process = mp.Process(target=_worker, args=(queue,))
process.start()
process.join()
self.assertFalse(queue.empty(), "Queue should contain the start method")
child_method = queue.get()
self.assertEqual(child_method, method,
f"Child process should use {method} start method")
self.assertEqual(mp.get_start_method(), method)
def test_value(self):
"""Test Value API"""
def func(v):
v.value += 1
v = mp.Value('i', 0)
self.assertEqual(v.value, 0)
p = mp.Process(target=func, args=(v,))
p.start()
p.join()
self.assertEqual(v.value, 1)
v_float = mp.Value('f', 0.5)
self.assertAlmostEqual(v_float.value, 0.5)
v_double = mp.Value('d', 0.123456789)
self.assertAlmostEqual(v_double.value, 0.123456789)
def test_array(self):
"""Test Array API"""
def func(arr):
for i, _ in enumerate(arr):
arr[i] += 1
arr = mp.Array('i', [0, 1, 2, 3, 4])
self.assertEqual(list(arr), [0, 1, 2, 3, 4])
p = mp.Process(target=func, args=(arr,))
p.start()
p.join()
self.assertEqual(list(arr), [1, 2, 3, 4, 5])
arr_float = mp.Array('f', [0.0, 1.0, 2.0])
self.assertEqual(list(arr_float), [0.0, 1.0, 2.0])
def test_pipe(self):
"""Test Pipe API"""
parent_conn, child_conn = mp.Pipe()
def send_data(conn):
conn.send('Hello from child')
conn.send(42)
conn.send([1, 2, 3])
conn.close()
p = mp.Process(target=send_data, args=(child_conn,))
p.start()
self.assertEqual(parent_conn.recv(), 'Hello from child')
self.assertEqual(parent_conn.recv(), 42)
self.assertEqual(parent_conn.recv(), [1, 2, 3])
p.join()
def send_receive(conn):
msg = conn.recv()
conn.send(msg + ' received')
parent_conn, child_conn = mp.Pipe()
p = mp.Process(target=send_receive, args=(child_conn,))
p.start()
parent_conn.send('Test message')
self.assertEqual(parent_conn.recv(), 'Test message received')
p.join()
def test_manager(self):
"""Test Manager API"""
def modify_list(lst):
lst.append(4)
lst[0] = 100
with mp.Manager() as manager:
shared_list = manager.list([1, 2, 3])
self.assertEqual(list(shared_list), [1, 2, 3])
p = mp.Process(target=modify_list, args=(shared_list,))
p.start()
p.join()
self.assertEqual(list(shared_list), [100, 2, 3, 4])
shared_dict = manager.dict({'a': 1, 'b': 2})
self.assertEqual(dict(shared_dict), {'a': 1, 'b': 2})
shared_value = manager.Value('i', 0)
self.assertEqual(shared_value.value, 0)
shared_namespace = manager.Namespace()
shared_namespace.x = 1
shared_namespace.y = 'test'
self.assertEqual(shared_namespace.x, 1)
self.assertEqual(shared_namespace.y, 'test')
@unittest.skip(
"Skip: pre-existing issue, npu_tensor.cpu() != reconstructed_npu.cpu() "
"after reduce_tensor on ARM CI"
)
def test_reductions(self):
"""Test torch.multiprocessing.reductions.init_reductions and reduce_tensor APIs"""
mp.reductions.init_reductions()
tensor = torch.tensor([1, 2, 3, 4])
reduced = mp.reductions.reduce_tensor(tensor)
self.assertIsInstance(reduced, tuple)
self.assertEqual(len(reduced), 2)
constructor, args = reduced
reconstructed = constructor(*args)
self.assertTrue(torch.equal(tensor, reconstructed))
self.assertEqual(tensor.device, reconstructed.device)
self.assertEqual(tensor.dtype, reconstructed.dtype)
if torch.npu.is_available():
npu_tensor = torch.tensor([1, 2, 3, 4], device='npu:0')
reduced_npu = mp.reductions.reduce_tensor(npu_tensor)
self.assertIsInstance(reduced_npu, tuple)
self.assertEqual(len(reduced_npu), 2)
constructor_npu, args_npu = reduced_npu
reconstructed_npu = constructor_npu(*args_npu)
self.assertTrue(torch.equal(npu_tensor.cpu(), reconstructed_npu.cpu()))
def test_reductions_invalid_input(self):
"""Test reduction APIs with invalid inputs"""
with self.assertRaises(Exception):
mp.reductions.reduce_tensor(None)
with self.assertRaises(Exception):
mp.reductions.reduce_tensor("not a tensor")
mp.reductions.init_reductions()
mp.reductions.init_reductions()
if __name__ == '__main__':
run_tests()