from abc import ABC, abstractmethod
import torch
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving
import mx_driving.point
class CreateBenchMarkTest(ABC):
    """Base class for furthest-point-sampling benchmark cases.

    Subclasses implement ``createData`` to fill in the size attributes and the
    input tensors; ``getCpuRes`` then computes the CPU reference indices from
    those attributes.
    """

    def __init__(self, datatype):
        """Remember ``datatype``; every data attribute stays ``None`` until ``createData`` runs."""
        self.datatype = datatype
        self.batch = None
        self.N = None
        self.numPoints = None
        self.point = None
        self.nearestDist = None

    @abstractmethod
    def createData(self, datatype):
        """Populate ``batch``, ``N``, ``numPoints``, ``point`` and ``nearestDist``."""

    def compare_min(self, a):
        """Return ``a[1]`` when ``a[0] > a[1]``, otherwise ``a[0]``."""
        return a[1] if a[0] > a[1] else a[0]

    # golden_data_cache is a project helper imported from data_cache; its
    # caching behaviour is defined there, not in this file.
    @golden_data_cache(__file__)
    def getCpuRes(self):
        """Compute the CPU reference sample indices.

        For every batch the first selected index is 0. Each following slot
        takes the argmax of the running element-wise minimum of squared
        distances to the points selected so far.

        Returns:
            An int32 tensor of shape ``[batch, numPoints]`` holding the
            selected point indices.
        """
        selected = torch.zeros([self.batch, self.numPoints], dtype=torch.int32)
        running_min = self.nearestDist.clone()
        if self.datatype == torch.bfloat16:
            # bfloat16 inputs are evaluated in float32 on the CPU side.
            coords = self.point.clone().to(dtype=torch.float32).contiguous()
        else:
            coords = self.point.clone()

        for b in range(self.batch):
            last = 0
            for slot in range(1, self.numPoints):
                dx = coords[b][0] - coords[b][0][last]
                dy = coords[b][1] - coords[b][1][last]
                dz = coords[b][2] - coords[b][2][last]
                # Keep the x, y, z summation order of the squared terms.
                sq_dist = dx * dx + dy * dy + dz * dz
                running_min[b] = torch.min(sq_dist, running_min[b])
                last = torch.argmax(running_min[b])
                selected[b][slot] = last
        return selected
class Test1(CreateBenchMarkTest):
    """Case with 47 batches of 717 points and 580 samples; only channel 0 is non-zero."""

    def createData(self, datatype):
        """Fill in the sizes, the ``point`` tensor and the initial ``nearestDist``.

        Args:
            datatype: torch dtype used for ``point``. ``nearestDist`` uses the
                same dtype, except that bfloat16 is replaced by float32.
        """
        self.batch = 47
        self.N = 717
        self.numPoints = 580
        self.point = torch.zeros([self.batch, 3, self.N], dtype=datatype)
        # One broadcast slice assignment replaces batch * N scalar writes:
        # channel 0 of every batch holds the point index 0..N-1 cast to datatype.
        self.point[:, 0, :] = torch.arange(self.N)
        dist_dtype = torch.float32 if datatype == torch.bfloat16 else datatype
        self.nearestDist = torch.ones([self.batch, self.N], dtype=dist_dtype) * 1e10
class Test2(CreateBenchMarkTest):
    """Case with 193 batches of 579 points and 123 samples; channels are j, j + 1, j + 3."""

    def createData(self, datatype):
        """Fill in the sizes, the ``point`` tensor and the initial ``nearestDist``.

        Args:
            datatype: torch dtype used for ``point``. ``nearestDist`` uses the
                same dtype, except that bfloat16 is replaced by float32.
        """
        self.batch = 193
        self.N = 579
        self.numPoints = 123
        self.point = torch.zeros([self.batch, 3, self.N], dtype=datatype)
        # Broadcast slice assignments replace batch * N * 3 scalar writes.
        # Offsets are added on the integer index before the cast to datatype.
        idx = torch.arange(self.N)
        self.point[:, 0, :] = idx
        self.point[:, 1, :] = idx + 1
        self.point[:, 2, :] = idx + 3
        dist_dtype = torch.float32 if datatype == torch.bfloat16 else datatype
        self.nearestDist = torch.ones([self.batch, self.N], dtype=dist_dtype) * 1e10
class Test3(CreateBenchMarkTest):
    """Case with 21 batches of 3901 points and 671 samples; only channel 0 is non-zero."""

    def createData(self, datatype):
        """Fill in the sizes, the ``point`` tensor and the initial ``nearestDist``.

        Args:
            datatype: torch dtype used for ``point``. ``nearestDist`` uses the
                same dtype, except that bfloat16 is replaced by float32.
        """
        self.batch = 21
        self.N = 3901
        self.numPoints = 671
        self.point = torch.zeros([self.batch, 3, self.N], dtype=datatype)
        # One broadcast slice assignment replaces batch * N scalar writes:
        # channel 0 of every batch holds the point index 0..N-1 cast to datatype.
        self.point[:, 0, :] = torch.arange(self.N)
        dist_dtype = torch.float32 if datatype == torch.bfloat16 else datatype
        self.nearestDist = torch.ones([self.batch, self.N], dtype=dist_dtype) * 1e10
class Test4(CreateBenchMarkTest):
    """Case with 151 batches of 3901 points and 671 samples; channels are j, j + 1, j + 3."""

    def createData(self, datatype):
        """Fill in the sizes, the ``point`` tensor and the initial ``nearestDist``.

        Args:
            datatype: torch dtype used for ``point``. ``nearestDist`` uses the
                same dtype, except that bfloat16 is replaced by float32.
        """
        self.batch = 151
        self.N = 3901
        self.numPoints = 671
        self.point = torch.zeros([self.batch, 3, self.N], dtype=datatype)
        # Broadcast slice assignments replace batch * N * 3 scalar writes.
        # Offsets are added on the integer index before the cast to datatype.
        idx = torch.arange(self.N)
        self.point[:, 0, :] = idx
        self.point[:, 1, :] = idx + 1
        self.point[:, 2, :] = idx + 3
        dist_dtype = torch.float32 if datatype == torch.bfloat16 else datatype
        self.nearestDist = torch.ones([self.batch, self.N], dtype=dist_dtype) * 1e10
class TestFurthestPointSample(TestCase):
    """Compares the mx_driving furthest-point-sampling entry points with the CPU reference."""

    def cpu_op_exec(self, myTest):
        """Return the CPU reference produced by ``myTest.getCpuRes()``."""
        reference = myTest.getCpuRes()
        return reference

    def npu_op_exec(self, myTest):
        """Call the three mx_driving entry points, each on its own NPU copy of the input.

        Returns:
            A 3-tuple of results, in the order
            ``mx_driving.point.furthest_point_sampling``,
            ``mx_driving.point.npu_furthest_point_sampling``,
            ``mx_driving.furthest_point_sampling``.
        """

        def npu_points():
            # Build a fresh tensor per call: clone, swap the last two axes, move to NPU.
            return myTest.point.clone().permute(0, 2, 1).npu()

        out_point = mx_driving.point.furthest_point_sampling(npu_points(), myTest.numPoints)
        out_point_npu = mx_driving.point.npu_furthest_point_sampling(npu_points(), myTest.numPoints)
        out_top = mx_driving.furthest_point_sampling(npu_points(), myTest.numPoints)
        return out_point, out_point_npu, out_top

    def compare_res(self, myTest):
        """Generate the case data, then assert every NPU result equals the CPU reference."""
        myTest.createData(myTest.datatype)
        expected = self.cpu_op_exec(myTest)
        for actual in self.npu_op_exec(myTest):
            self.assertRtolEqual(expected, actual)

    # float32 inputs, cases Test1..Test3 in order.
    def test_FurthestPointSampleF32(self):
        print("start float32")
        for case_cls in (Test1, Test2, Test3):
            self.compare_res(case_cls(torch.float32))

    # float16 inputs, cases Test1..Test3 in order.
    def test_FurthestPointSampleF16(self):
        print("start float16")
        for case_cls in (Test1, Test2, Test3):
            self.compare_res(case_cls(torch.float16))

    # bfloat16 inputs, cases Test1..Test3 in order.
    def test_FurthestPointSampleBF16(self):
        print("start bfloat16")
        for case_cls in (Test1, Test2, Test3):
            self.compare_res(case_cls(torch.bfloat16))
# Hand control to run_tests() only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()