"""
Add validation cases for torch.nn.models APIs on NPU:
1. test/test_modules.py from PyTorch community lacks sufficient API validations, so this file is added.
2. This file validates torch.nn.Module, torch.nn.Module.state_dict, torch.nn.ModuleDict, torch.nn.ModuleList (extendable).
"""
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests, TestCase
import torch_npu
device = torch.device("npu:0")
class TestNPUModuleAPIs(TestCase):
def test_module_device_consistency(self):
"""验证Module设备迁移后参数/缓冲区设备一致"""
m = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.BatchNorm1d(20)).to(device)
for p in m.parameters():
self.assertEqual(p.device, device)
for b in m.buffers():
self.assertEqual(b.device, device)
def test_module_state_dict(self):
"""验证state_dict保存/加载后状态一致"""
base = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.BatchNorm1d(20))
m = nn.Module()
m.dict = nn.ModuleDict({"linear": nn.Linear(10, 20), "base": base})
m.list = nn.ModuleList([nn.Linear(20, 30), base])
m.to(device)
sd = m.state_dict()
torch.npu.synchronize()
for v in sd.values():
self.assertEqual(v.device, device)
base2 = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.BatchNorm1d(20))
m2 = nn.Module()
m2.dict = nn.ModuleDict({"linear": nn.Linear(10, 20), "base": base2})
m2.list = nn.ModuleList([nn.Linear(20, 30), base2])
m2.load_state_dict(sd)
m2.to(device)
torch.npu.synchronize()
for (n1, p1), (n2, p2) in zip(m.named_parameters(), m2.named_parameters()):
self.assertTrue(torch.allclose(p1, p2))
def test_moduledict_operations(self):
"""验证ModuleDict增删/索引/遍历"""
m = nn.ModuleDict({"a": nn.Linear(10, 20).to(device)})
self.assertIn("a", m)
m["b"] = nn.Linear(20, 30).to(device)
self.assertEqual(m["b"].weight.device, device)
del m["b"]
self.assertNotIn("b", m)
for sub in m.values():
for p in sub.parameters():
self.assertEqual(p.device, device)
def test_modulelist_operations(self):
"""验证ModuleList索引/新增/删除/遍历"""
m = nn.ModuleList([nn.Linear(20, 30).to(device), nn.BatchNorm1d(30).to(device)])
self.assertEqual(m[0].weight.device, device)
self.assertEqual(m[1].running_mean.device, device)
m.append(nn.Linear(30, 40).to(device))
m.insert(0, nn.Linear(10, 20).to(device))
m.pop(-1)
m.pop(0)
self.assertEqual(len(m), 2)
for sub in m:
for p in sub.parameters():
self.assertEqual(p.device, device)
if __name__ == "__main__":
run_tests()