import torch
import torch_npu
from torch_npu.testing.testcase import run_tests, TestCase
from torch_npu.testing.common_distributed import skipIfUnsupportMultiNPU
import torchvision
from torchvision import ops
import torchvision_npu
class TestNms(TestCase):
    """Compare ``torchvision.ops.nms`` on a non-default NPU with the CPU result."""

    @skipIfUnsupportMultiNPU(2)
    def test_nms_multidevice(self):
        """Run NMS on ``npu:1`` and check it keeps the same boxes as on CPU.

        The inputs live on ``npu:1`` (the second device), which is why the
        decorator is given 2 -- presumably it skips the test when fewer
        NPUs are available; confirm against ``skipIfUnsupportMultiNPU``.
        """
        device = 'npu:1'

        # Three nearly coincident boxes in (x1, y1, x2, y2) form.
        boxes = torch.tensor([
            [285.3538, 185.5758, 1193.5110, 851.4551],
            [285.1472, 188.7374, 1192.4984, 851.0669],
            [279.2440, 197.9812, 1189.4746, 849.2019],
        ]).to(device)
        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
        iou_thres = 0.2

        # Reference result: the same inputs copied back to host.
        cpu_res = ops.nms(boxes.cpu(), scores.cpu(), iou_thres)
        npu_res = ops.nms(boxes, scores, iou_thres)

        # Move the device result to host explicitly so the comparison does
        # not rely on the assertion helper accepting NPU-resident tensors.
        self.assertRtolEqual(cpu_res, npu_res.cpu())
# Script entry point: hand control to the torch_npu test runner when this
# file is executed directly rather than imported.
if __name__ == '__main__':
    run_tests()