import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
class TesBatchNms(TestCase):
def test_batch_nms_shape_format(self):
boxes = torch.randn(8, 4, 1, 4).npu()
scores = torch.randn(8, 4, 1).npu()
boxes_fp16 = boxes.half()
scores_fp16 = scores.half()
nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = torch_npu.npu_batch_nms(
boxes, scores, 0.3, 0.5, 4, 4
)
boxes1, scores1, classes1, num1 = torch_npu.npu_batch_nms(
boxes_fp16, scores_fp16, 0.3, 0.5, 4, 4
)
expedt_nmsed_classes = torch.tensor(
[
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000],
],
dtype=torch.float32,
)
self.assertRtolEqual(expedt_nmsed_classes, nmsed_classes.cpu())
self.assertRtolEqual(expedt_nmsed_classes.half(), classes1.cpu())
if __name__ == "__main__":
run_tests()