b565ec28创建于 2024年11月22日历史提交
import unittest
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

import torch_npu
from torch_npu.testing.testcase import run_tests
import torchvision_npu
from torchvision_npu.testing.test_deviation_case import TestCase


class TestROIAlign(TestCase):

    def test_roi_align_amp(self):
        torch.manual_seed(1)
        torch.npu.manual_seed(1)
        x = torch.randn((1, 256, 200, 304)).npu().to(torch.float16)
        rois = torch.randn((648, 5)).npu()
        y = torch.ops.torchvision.roi_align(x.to(torch.float32), rois, 0.25, 7, 7, 2, False).to(torch.float16)
        with torch_npu.npu.amp.autocast():
            y_autocast = torch.ops.torchvision.roi_align(x, rois, 0.25, 7, 7, 2, False)
        self.assertEqual(y_autocast.dtype, torch.float16)
        self.assertRtolEqual(y, y_autocast)

    def test_faketensor_amp(self):
        with FakeTensorMode():
            with torch_npu.npu.amp.autocast():
                x = torch.randn((1, 256, 200, 304)).npu().to(torch.float16)
                rois = torch.randn((648, 5)).npu()
                y = torch.ops.torchvision.roi_align(x, rois, 0.25, 7, 7, 2, False)
        self.assertEqual(y.dtype, x.dtype)


if __name__ == '__main__':
    run_tests()