import numpy as np
import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving
class TestCalAnchorsHeading(TestCase):
    """Compare ``mx_driving.cal_anchors_heading`` on NPU against a CPU reference."""

    @golden_data_cache(__file__)
    def cal_anchors_heading_cpu(self, anchors, origin_pos=None):
        """CPU golden reference for the per-step heading of every anchor trajectory.

        Args:
            anchors: float tensor indexed as (batch, anchors_num, seq_length, xy);
                index 0 of the last axis is x and index 1 is y.
            origin_pos: optional rank-2 tensor, (batch, 2) in this test, holding the
                start point that is prepended to every anchor of a batch item.
                When None, a zero start point is prepended instead.

        Returns:
            numpy array of shape (batch, anchors_num, seq_length) holding
            ``atan2(dy, dx)`` for each step. A step whose displacement is not
            considered valid reuses the previous step's heading (0 at t == 0).

        Raises:
            ValueError: if ``origin_pos`` is given but is not a rank-2 tensor.
        """
        if origin_pos is None:
            start = torch.zeros_like(anchors[:, :, 0:1, :])
        elif len(origin_pos.shape) == 2:
            # (batch, 2) -> (batch, anchors_num, 1, 2) so it can be prepended on the time axis.
            start = origin_pos.unsqueeze(1).unsqueeze(1).repeat(1, anchors.shape[1], 1, 1)
        else:
            raise ValueError(
                f"origin_pos must be a rank-2 tensor, got shape {tuple(origin_pos.shape)}"
            )
        input_add_start = torch.cat((start, anchors), dim=-2)

        # Displacement of each step relative to the previous point (or the start point).
        xy_diff = input_add_start[:, :, 1:, :] - input_add_start[:, :, :-1, :]
        # NOTE(review): the comparison is signed, so a step that moves more than 0.1 only in
        # the negative x/y direction counts as invalid. This reference must match the NPU
        # kernel; confirm whether abs() was intended before changing it.
        heading_valid = torch.logical_or(xy_diff[..., 0] > 0.1, xy_diff[..., 1] > 0.1)
        heading = torch.atan2(xy_diff[..., 1], xy_diff[..., 0])

        # Forward fill along time. It must run sequentially: an invalid step takes the
        # already-filled heading of step t - 1, which itself may have been filled.
        previous = torch.zeros_like(heading[:, :, 0])
        for t in range(heading.shape[2]):
            heading[:, :, t] = torch.where(heading_valid[:, :, t], heading[:, :, t], previous)
            previous = heading[:, :, t]
        return heading.numpy()

    def cal_anchors_heading_npu(self, anchors, origin_pos=None):
        """Run ``mx_driving.cal_anchors_heading`` on NPU copies of the inputs.

        Returns the result copied back to the CPU as a numpy array.
        """
        anchors = anchors.npu()
        origin_pos = None if origin_pos is None else origin_pos.npu()
        heading = mx_driving.cal_anchors_heading(anchors, origin_pos)
        return heading.cpu().numpy()

    @golden_data_cache(__file__)
    def gen_data(self, batch_size, anchors_num, seq_length):
        """Draw random float32 inputs uniformly from [-5, 5).

        Returns:
            ``(anchors, origin_pos)``. ``anchors`` has shape
            (batch_size, anchors_num, seq_length, 2) and ``origin_pos`` has shape
            (batch_size, 2).
        """
        anchors = np.random.uniform(-5, 5, (batch_size, anchors_num, seq_length, 2))
        origin_pos = np.random.uniform(-5, 5, (batch_size, 2))
        return torch.from_numpy(anchors).float(), torch.from_numpy(origin_pos).float()

    def one_case(self, batch_size, anchors_num, seq_length, none_origin_pos=False):
        """Generate one input set, run the CPU reference and the NPU op, and compare them.

        When ``none_origin_pos`` is truthy, ``origin_pos`` is passed as None to both
        implementations so the zero-start-point path is exercised.
        """
        anchors, origin_pos = self.gen_data(batch_size, anchors_num, seq_length)
        origin_pos = None if none_origin_pos else origin_pos
        heading_cpu = self.cal_anchors_heading_cpu(anchors, origin_pos)
        heading_npu = self.cal_anchors_heading_npu(anchors, origin_pos)
        self.assertRtolEqual(heading_cpu, heading_npu)

    def test_cal_anchors_heading(self):
        """Check several shapes, each without and then with an explicit origin_pos."""
        shapes = (
            (1, 1, 1),
            (2048, 256, 32),
            (2, 10240, 32),
            (2, 256, 256),
        )
        for batch_size, anchors_num, seq_length in shapes:
            for none_origin_pos in (True, False):
                # subTest pinpoints which shape / origin variant failed.
                with self.subTest(
                    batch_size=batch_size,
                    anchors_num=anchors_num,
                    seq_length=seq_length,
                    none_origin_pos=none_origin_pos,
                ):
                    self.one_case(batch_size, anchors_num, seq_length, none_origin_pos)
if __name__ == "__main__":
    # Executed directly as a script: hand control to torch_npu's test runner.
    run_tests()