import threading
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
import torch_npu
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
class TestPinMemory(TestCase):
    """Checks that the calling thread's current NPU device stays unchanged.

    The current device is compared before and after two events: another
    thread selecting a different device, and the allocation of a pinned
    host tensor.
    """

    # NOTE(review): presumably skips the test unless at least 2 NPUs are
    # available -- confirm against torch_npu.testing.common_distributed.
    @skipIfUnsupportMultiNPU(2)
    def test_pin_memory(self):
        """Device 1 must remain current after a worker thread and a pinned alloc."""
        torch.npu.set_device(1)

        # Exceptions raised inside a thread are not re-raised by join(), so
        # they are collected here and surfaced explicitly. Otherwise a worker
        # that failed to select device 0 would let the test pass without the
        # cross-thread scenario ever having been set up.
        worker_errors = []

        def worker_function():
            try:
                torch.npu.set_device(0)
            except Exception as exc:  # re-surfaced in the main thread below
                worker_errors.append(exc)

        t = threading.Thread(target=worker_function)
        t.start()
        t.join()
        if worker_errors:
            self.fail(
                f"worker thread could not select device 0: {worker_errors[0]!r}"
            )

        # The worker's set_device(0) must not have changed this thread's device.
        device = torch.npu.current_device()
        self.assertEqual(device, 1)

        # The tensor stays bound to a local name so the pinned buffer is still
        # referenced while the current device is checked again.
        pinmemory_tensor = torch.empty(32, pin_memory=True)
        device = torch.npu.current_device()
        self.assertEqual(device, 1)
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    run_tests()