import numpy as np
import torch
import torch_npu
import torch_scatter
from torch_npu.testing.common_utils import create_common_tensor
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving.common
class TestScatterMaxWithArgmax(TestCase):
    """Compare mx_driving's NPU ``scatter_max`` with ``torch_scatter.scatter_max`` on CPU.

    Every case scatters along dim 0 and checks the reduced values and the argmax
    indices (cast to int32 on the CPU side). The ``with_grad`` cases also compare
    the gradient of ``sum(output)`` with respect to the source tensor.
    """

    # ------------------------------------------------------------------ runners

    def cpu_op_exec(self, updates, indices):
        """Run the torch_scatter reference on CPU, forward and backward.

        Args:
            updates: CPU source tensor; ``requires_grad`` is switched on in place.
            indices: CPU integer scatter indices (converted to int64 for torch_scatter).

        Returns:
            Tuple ``(output, argmax, grad)`` of numpy arrays. ``argmax`` is cast to
            int32 so it can be compared directly with the NPU kernel's result.
        """
        updates.requires_grad = True
        output, output_argmax = torch_scatter.scatter_max(updates, indices.long(), dim=0)
        # Upstream gradient of ones == gradient of output.sum().
        output.backward(torch.ones_like(output))
        output_grad = updates.grad.detach().numpy()
        return (
            output.detach().numpy(),
            output_argmax.to(torch.int32).numpy(),
            output_grad,
        )

    def npu_op_exec(self, updates, indices):
        """Run mx_driving's scatter_max on the NPU, forward and backward.

        Args:
            updates: NPU source tensor; ``requires_grad`` is switched on in place.
            indices: scatter indices on the NPU (the tests below pass int32).

        Returns:
            Tuple ``(output, argmax, grad)`` of numpy arrays copied back to host.
        """
        updates.requires_grad = True
        # Invoke the op exactly once. A second call would silently replace the
        # first result, and only the last call's graph would get the backward.
        output, output_argmax = mx_driving.common.scatter_max(updates, indices)
        output.backward(torch.ones_like(output))
        output_grad = updates.grad.cpu().detach().numpy()
        return (
            output.cpu().detach().numpy(),
            output_argmax.cpu().numpy(),
            output_grad,
        )

    # ------------------------------------------------------------------ helpers

    def _assert_forward_matches(self, shape_updates, shape_indices, index_max=100):
        """Check forward values and argmax of NPU vs CPU for random inputs of the given shapes."""
        cpu_updates, npu_updates = create_common_tensor(["float32", 2, shape_updates], 0, 100)
        cpu_indices, npu_indices = create_common_tensor(["int32", 2, shape_indices], 0, index_max)
        output_npu, output_argmax_npu = mx_driving.common.scatter_max(npu_updates, npu_indices)
        output_cpu, output_argmax_cpu = torch_scatter.scatter_max(
            cpu_updates, cpu_indices.to(torch.int64), dim=0)
        self.assertRtolEqual(output_cpu, output_npu)
        self.assertRtolEqual(output_argmax_cpu.to(torch.int32), output_argmax_npu)

    def _assert_out_matches(self, shape_updates, shape_indices, shape_out, index_max):
        """Like ``_assert_forward_matches`` but with a pre-filled ``out`` tensor.

        ``index_max`` is an exclusive upper bound for the random indices and must
        not exceed ``shape_out[0]``.
        """
        cpu_updates, npu_updates = create_common_tensor(["float32", 2, shape_updates], 0, 100)
        cpu_indices, npu_indices = create_common_tensor(["int32", 2, shape_indices], 0, index_max)
        cpu_out, npu_out = create_common_tensor(["float32", 2, shape_out], 0, 100)
        # Single call: running the op twice on the same `out` buffer would test a
        # second application instead of the one under test.
        output_npu, output_argmax_npu = mx_driving.common.scatter_max(npu_updates, npu_indices, npu_out)
        output_cpu, output_argmax_cpu = torch_scatter.scatter_max(
            cpu_updates, cpu_indices.to(torch.int64), dim=0, out=cpu_out)
        self.assertRtolEqual(output_cpu, output_npu)
        self.assertRtolEqual(output_argmax_cpu.to(torch.int32), output_argmax_npu)

    def _assert_grad_matches(self, shape_updates, shape_indices, value_max, index_max):
        """Check output, argmax and source gradient of NPU vs CPU.

        Source values are drawn from ``[0, value_max)`` and indices from ``[0, index_max)``.
        """
        cpu_updates, npu_updates = create_common_tensor(["float32", 2, shape_updates], 0, value_max)
        cpu_indices, npu_indices = create_common_tensor(["int32", 2, shape_indices], 0, index_max)
        cpu_output, cpu_argmax, cpu_grad = self.cpu_op_exec(cpu_updates, cpu_indices)
        npu_output, npu_argmax, npu_grad = self.npu_op_exec(npu_updates, npu_indices)
        self.assertRtolEqual(cpu_output, npu_output)
        self.assertRtolEqual(cpu_argmax, npu_argmax)
        self.assertRtolEqual(cpu_grad, npu_grad)

    # ------------------------------------------------------------ forward only

    def test_scatter_max_dim3_1(self):
        self._assert_forward_matches((100, 3, 16), (100, 1, 1))

    def test_scatter_max_dim5_2(self):
        self._assert_forward_matches((100, 4, 3, 16, 3), (100, 1, 1))

    def test_scatter_max_bigtail_3(self):
        # 8192 values per row along the non-scattered dimension.
        self._assert_forward_matches((100, 8192), (100, 1))

    def test_scatter_max_dim3_and_unaligned_4(self):
        # Trailing size 123 * 5 = 615 elements per row.
        self._assert_forward_matches((1024, 123, 5), (1024, 1, 1))

    def test_scatter_max_top_level_alias(self):
        """``mx_driving.scatter_max`` must agree with ``mx_driving.common.scatter_max``.

        NOTE(review): the previous version exercised the top-level name only by
        calling it right after the ``common`` one; this keeps that coverage explicit.
        """
        _, npu_updates = create_common_tensor(["float32", 2, (1024, 16)], 0, 100)
        _, npu_indices = create_common_tensor(["int32", 2, (1024, 1)], 0, 100)
        output_common, argmax_common = mx_driving.common.scatter_max(npu_updates, npu_indices)
        output_alias, argmax_alias = mx_driving.scatter_max(npu_updates, npu_indices)
        self.assertRtolEqual(output_common, output_alias)
        self.assertRtolEqual(argmax_common, argmax_alias)

    # ---------------------------------------------------------- with `out=`

    def test_scatter_max_with_out_1(self):
        self._assert_out_matches((1024, 16), (1024, 1), (1024, 16), index_max=1000)

    def test_scatter_max_with_out_2(self):
        self._assert_out_matches((100, 3, 16), (100, 1, 1), (20, 3, 16), index_max=20)

    # --------------------------------------------------------- with backward

    def test_scatter_max_with_grad_1(self):
        self._assert_grad_matches((262144, 16), (262144, 1), value_max=262144, index_max=262144)

    def test_scatter_max_with_grad_2(self):
        self._assert_grad_matches((78848, 16), (78848, 1), value_max=78848, index_max=78848)

    def test_scatter_max_with_grad_3(self):
        self._assert_grad_matches((1024, 16), (1024, 1), value_max=100, index_max=100)

    def test_scatter_max_with_grad_4(self):
        self._assert_grad_matches((1024, 128), (1024, 1), value_max=100, index_max=100)

    def test_scatter_max_with_grad_bigtail(self):
        self._assert_grad_matches((1024, 4096), (1024, 1), value_max=100, index_max=100)

    def test_scatter_max_with_grad_unaligned(self):
        self._assert_grad_matches((1024, 17), (1024, 1), value_max=100, index_max=100)
if __name__ == "__main__":
    # Executed only when this file is run as a script: hand every test case in
    # this module to torch_npu's runner (imported from torch_npu.testing.testcase).
    run_tests()