import os
import time
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.utils._path_manager import PathManager
class TestAsyncSave(TestCase):
    """Tests for ``torch_npu.utils.save_async``.

    Each test saves objects asynchronously into a scratch directory next to
    this file. It then waits for the files to settle, loads them, and compares
    them with copies of the objects taken at save time. The directory is
    created once per class and removed when the class finishes.
    """

    test_save_path = os.path.join(
        os.path.realpath(os.path.dirname(__file__)), "test_save_async")

    @classmethod
    def setUpClass(cls):
        PathManager.make_dir_safety(cls.test_save_path)

    @classmethod
    def tearDownClass(cls):
        PathManager.remove_path_safety(cls.test_save_path)

    def wait_for_save_completion(self, file_path, timeout_sec=60, poll_interval_sec=0.5):
        """Block until an asynchronously written file looks complete.

        The file counts as complete once it exists, is non-empty, and its size
        is unchanged across one ``poll_interval_sec`` interval. This is a
        heuristic: a writer that stalls for longer than one interval could be
        reported as finished early.

        Args:
            file_path: Path of the file being written.
            timeout_sec: Maximum total time to wait, in seconds.
            poll_interval_sec: Delay between size checks, in seconds.

        Returns:
            True if the file looked complete before the timeout, else False.
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout_sec
        while time.monotonic() < deadline:
            try:
                size_before = os.path.getsize(file_path)
                time.sleep(poll_interval_sec)
                size_after = os.path.getsize(file_path)
            except OSError:
                # Not created yet (or briefly missing): keep polling.
                time.sleep(poll_interval_sec)
                continue
            # An empty file has most likely only been opened by the writer, so a
            # stable size of 0 must not be mistaken for a finished save.
            if size_after > 0 and size_before == size_after:
                return True
        return False

    def _load_saved(self, file_path):
        """Wait for the asynchronous save of ``file_path`` to settle, then load it.

        Fails the current test if the file never looks complete within the
        default timeout of ``wait_for_save_completion``.
        """
        if not self.wait_for_save_completion(file_path):
            self.fail(f"{file_path} is not exist!")
        return torch.load(file_path, weights_only=False)

    def test_save_async_tensor(self):
        save_tensor = torch.rand(1024, dtype=torch.float32).npu()
        async_save_path = os.path.join(self.test_save_path, "async_save_tensor.pt")
        torch_npu.utils.save_async(save_tensor, async_save_path)
        self.assertEqual(self._load_saved(async_save_path), save_tensor)

    def test_save_async(self):
        # Reference losses for the first SGD steps on this fixed input; the
        # number of training steps is derived from this list.
        expected_losses = [1.6099495, 1.6099086, 1.6098710]
        observed_losses = []
        # Copies of the state at each save, compared later with the files on disk.
        model_snapshots = []
        checkpoint_snapshots = []

        model = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 20),
            nn.ReLU(),
            nn.Linear(20, 5),
            nn.ReLU()
        ).npu()
        input_data = torch.ones(6400, 100).npu()
        labels = torch.arange(5).repeat(1280).npu()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1)

        for step in range(len(expected_losses)):
            outputs = model(input_data)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Detach so the stored value does not keep this step's autograd graph alive.
            observed_losses.append(loss.detach())

            checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()
            }
            # Deep copies freeze this step's values: state_dict() tensors share
            # storage with the parameters that the next optimizer.step() updates.
            checkpoint_snapshots.append(copy.deepcopy(checkpoint))
            model_snapshots.append(copy.deepcopy(model))

            checkpoint_async_path = os.path.join(self.test_save_path, f"checkpoint_async_{step}.path")
            model_async_path = os.path.join(self.test_save_path, f"model_async_{step}.path")
            # NOTE(review): `model=` presumably lets save_async coordinate with
            # later in-place updates of `model`; confirm against torch_npu docs.
            torch_npu.utils.save_async(checkpoint, checkpoint_async_path, model=model)
            torch_npu.utils.save_async(model, model_async_path, model=model)

        per_step = zip(expected_losses, observed_losses, checkpoint_snapshots, model_snapshots)
        for step, (expected_loss, observed_loss, checkpoint_sync, model_sync) in enumerate(per_step):
            self.assertEqual(expected_loss, observed_loss.item())

            checkpoint_async = self._load_saved(
                os.path.join(self.test_save_path, f"checkpoint_async_{step}.path"))
            self.assertEqual(checkpoint_sync, checkpoint_async, prec=2e-3)

            model_async = self._load_saved(
                os.path.join(self.test_save_path, f"model_async_{step}.path"))
            state_dict_sync = model_sync.state_dict()
            state_dict_async = model_async.state_dict()
            self.assertEqual(sorted(state_dict_sync), sorted(state_dict_async))
            for key in state_dict_async:
                self.assertEqual(state_dict_async[key], state_dict_sync[key], prec=2e-3)
if __name__ == "__main__":
    # Select NPU device 0 for every test in this module, then hand control to
    # the torch_npu test runner.
    target_device_index = 0
    torch.npu.set_device(target_device_index)
    run_tests()