# fb4bb61c — created on 2025-05-07 (commit history)
import unittest

import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from data_cache import golden_data_cache
from torch_npu.testing.testcase import TestCase, run_tests
import mx_driving

@golden_data_cache(__file__)
def gen_inputs(shape1, shape2, dtype):
    """Create random NPU operands for the batch-matmul tests.

    Args:
        shape1: shape of ``projection_mat`` (left matmul operand).
        shape2: shape of ``pts_extend`` (right matmul operand).
        dtype: element type of both tensors; either a NumPy dtype such as
            ``np.float32`` (what the tests pass) or a ``torch.dtype``.

    Returns:
        Tuple ``(projection_mat, pts_extend)`` of standard-normal samples,
        moved to the NPU with ``.npu()``.
    """
    # Accept NumPy dtypes by mapping them to the equivalent torch dtype, so the
    # requested dtype is actually honoured instead of torch's default.
    if not isinstance(dtype, torch.dtype):
        dtype = torch.from_numpy(np.empty(0, dtype=dtype)).dtype
    projection_mat = torch.randn(shape1, dtype=dtype).npu()
    pts_extend = torch.randn(shape2, dtype=dtype).npu()
    return projection_mat, pts_extend


def gen_former_npu_outputs(projection_mat, pts_extend):
    """Reference result for the expanded-layout tests, built on ``torch.matmul``.

    ``projection_mat`` gets two singleton axes after its first two dims (same as
    indexing ``[:, :, None, None]``) and ``pts_extend`` gets one singleton axis
    after its first dim plus a trailing column axis (same as ``[:, None, ..., None]``).
    The two then broadcast against each other in a single matmul. For rank-4
    inputs shaped (B, C, M, K) and (B, N, P, K) the result is (B, C, N, P, M, 1).
    """
    lhs = projection_mat.unsqueeze(2).unsqueeze(3)
    rhs = pts_extend.unsqueeze(1).unsqueeze(-1)
    return lhs @ rhs


class TestBatchMatmul(TestCase):
    """Compare ``mx_driving.npu_batch_matmul`` with ``torch.matmul`` on the NPU.

    Every test checks the forward result and the gradients with respect to
    both operands. The reference path back-propagates through ``torch.matmul``
    (relying on its broadcasting). The fused path calls
    ``mx_driving.npu_batch_matmul`` on the same data.
    """

    def _check_against_matmul(self, shape1, shape2, expand_dims):
        """Run reference and fused batch matmul on identical data and compare.

        Args:
            shape1: shape of the left operand (``projection_mat``).
            shape2: shape of the right operand (``pts_extend``).
            expand_dims: if True, the reference is ``gen_former_npu_outputs``
                and the fused op receives the same singleton-expanded,
                contiguous operands. If False, both paths use the operands as-is.
        """
        projection_mat, pts_extend = gen_inputs(shape1, shape2, np.float32)
        # detach() gives tensors with the same values that are outside the
        # reference autograd graph, so the two backward passes stay independent.
        projection_mat_fused = projection_mat.detach()
        pts_extend_fused = pts_extend.detach()

        # Reference path: plain torch.matmul with broadcasting.
        projection_mat.requires_grad = True
        pts_extend.requires_grad = True
        if expand_dims:
            expected = gen_former_npu_outputs(projection_mat, pts_extend)
        else:
            expected = torch.matmul(projection_mat, pts_extend)
        expected.backward(torch.ones_like(expected))
        x_grad_expected = projection_mat.grad
        w_grad_expected = pts_extend.grad

        # Fused path.
        if expand_dims:
            # Insert the same singleton axes the reference uses, materialised
            # as contiguous tensors before they reach the custom kernel.
            projection_mat_fused = projection_mat_fused[:, :, None, None].contiguous()
            pts_extend_fused = pts_extend_fused[:, None, ..., None].contiguous()
        projection_mat_fused.requires_grad = True
        pts_extend_fused.requires_grad = True
        result = mx_driving.npu_batch_matmul(projection_mat_fused, pts_extend_fused)
        result.backward(torch.ones_like(result))
        x_grad = projection_mat_fused.grad
        w_grad = pts_extend_fused.grad
        if expand_dims:
            # Drop exactly the inserted singleton axes by reshaping to the
            # reference gradient shape (squeeze() would also drop real size-1 dims).
            x_grad = x_grad.reshape(x_grad_expected.shape)
            w_grad = w_grad.reshape(w_grad_expected.shape)

        self.assertRtolEqual(result.detach().cpu().numpy(), expected.detach().cpu().numpy())
        self.assertRtolEqual(x_grad_expected.cpu().numpy(), x_grad.cpu().numpy())
        self.assertRtolEqual(w_grad_expected.cpu().numpy(), w_grad.cpu().numpy())

    def test_npu_batch_matmul_sixdim(self, device="npu"):
        self._check_against_matmul([6, 6, 4, 4], [6, 1220, 13, 4], expand_dims=True)

    def test_npu_batch_matmul_four_dim(self, device="npu"):
        self._check_against_matmul([6, 1, 4, 4], [6, 1220, 4, 1], expand_dims=False)

    def test_npu_batch_matmul_none_brodcast(self, device="npu"):
        self._check_against_matmul([6, 1220, 4, 4], [6, 1220, 4, 1], expand_dims=False)

    def test_npu_batch_matmul_need_brodcast(self, device="npu"):
        self._check_against_matmul([1, 1, 4, 4], [6, 1220, 4, 1], expand_dims=False)

    def test_npu_batch_matmul_kernel_3(self, device="npu"):
        self._check_against_matmul([6, 6, 3, 3], [6, 1220, 13, 3], expand_dims=True)

    def test_npu_batch_matmul_kernel_3_dim_4(self, device="npu"):
        self._check_against_matmul([6, 1, 3, 3], [6, 1220, 3, 1], expand_dims=False)

if __name__ == "__main__":
    # Run the tests in this file through torch_npu's test runner.
    run_tests()