npu_batch_matmul

Prototype

mx_driving.npu_batch_matmul(Tensor projection_mat, Tensor pts_extend) -> Tensor

Description

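Performs batched matrix multiplication, broadcasting over the leading (batch) dimensions. It is functionally equivalent to torch.matmul.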
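
For cross-checking on the CPU, the computation can be written out with NumPy. This is a minimal sketch, not mx_driving's implementation. `reference_batch_matmul` is a hypothetical helper, and the check assumes that `np.matmul` follows the same broadcasting rules as `torch.matmul`.

```python
import numpy as np

def reference_batch_matmul(projection_mat, pts_extend):
    """Hypothetical CPU reference for the documented semantics."""
    # out[..., i, k] = sum_j projection_mat[..., i, j] * pts_extend[..., j, k],
    # where the leading "..." (batch) dimensions are broadcast against each other.
    return np.einsum("...ij,...jk->...ik", projection_mat, pts_extend)

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)
b = rng.standard_normal((2, 3, 4, 1)).astype(np.float32)
out = reference_batch_matmul(a, b)
print(out.shape)  # (2, 3, 4, 1)

# "Batched" means one independent matrix product per batch index.
for i in range(2):
    for j in range(3):
        assert np.allclose(out[i, j], a[i, j] @ b[i, j], atol=1e-5)
# np.matmul gives the same result.
assert np.allclose(out, np.matmul(a, b), atol=1e-5)
```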

Parameters

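  • projection_mat (Tensor): the projection matrix, of dtype float32. Must be 4-D to 6-D; the last two dimensions must be (4, 4) or (3, 3), and the batch dimensions must be broadcastable with those of pts_extend.
  • pts_extend (Tensor): the features of all points, of dtype float32. Must be 4-D to 6-D; the last two dimensions must be (4, 1) or (3, 1), and the batch dimensions must be broadcastable with those of projection_mat.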
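
The shape rules listed above can be checked before calling the operator. `check_batch_matmul_shapes` is a hypothetical helper, not part of mx_driving. It assumes that (4, 4) pairs with (4, 1) and (3, 3) with (3, 1), since matrix multiplication requires matching inner dimensions, and it uses `np.broadcast_shapes` (NumPy 1.20+) for the batch dimensions.

```python
import numpy as np

# Hypothetical helper encoding the documented shape rules; not an mx_driving API.
VALID_PAIRS = {(4, 4): (4, 1), (3, 3): (3, 1)}

def check_batch_matmul_shapes(proj_shape, pts_shape):
    """Validate input shapes and return the expected output shape."""
    proj_shape, pts_shape = tuple(proj_shape), tuple(pts_shape)
    for name, shape in (("projection_mat", proj_shape), ("pts_extend", pts_shape)):
        if not 4 <= len(shape) <= 6:
            raise ValueError(f"{name} must be 4-D to 6-D, got {len(shape)}-D")
    mat, vec = proj_shape[-2:], pts_shape[-2:]
    if VALID_PAIRS.get(mat) != vec:
        raise ValueError(f"unsupported trailing dimensions {mat} and {vec}")
    # Batch dimensions must broadcast; np.broadcast_shapes raises ValueError otherwise.
    batch = np.broadcast_shapes(proj_shape[:-2], pts_shape[:-2])
    return batch + (mat[0], vec[1])

print(check_batch_matmul_shapes((6, 6, 1, 1, 4, 4), (6, 1, 1220, 13, 4, 1)))
# (6, 6, 1220, 13, 4, 1)
```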

Returns

  • output (Tensor): the result of the matrix multiplication, of dtype float32.

Supported Hardware

  • Atlas A2 training series products

Examples

6-D inputs

import torch
import torch_npu
import mx_driving

# projection_mat: (6, 6, 4, 4); pts_extend: (6, 1220, 13, 4)
projection_mat = torch.randn(6, 6, 4, 4).npu()
pts_extend = torch.randn(6, 1220, 13, 4).npu()
# Insert singleton axes so that the batch dimensions broadcast:
# (6, 6, 1, 1, 4, 4) @ (6, 1, 1220, 13, 4, 1) -> (6, 6, 1220, 13, 4, 1)
projection_mat_fused = projection_mat[:, :, None, None].contiguous()
pts_extend2_fused = pts_extend[:, None, ..., None].contiguous()
projection_mat_fused.requires_grad = True
pts_extend2_fused.requires_grad = True
result = mx_driving.npu_batch_matmul(projection_mat_fused, pts_extend2_fused)
# Backpropagate an all-ones gradient
grad = torch.ones_like(result)
result.backward(grad)
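
Because projection_mat_fused and pts_extend2_fused are broadcast against each other, their gradients must be summed back over the broadcast axes so that each .grad matches its input's shape. The NumPy sketch that follows mimics that reduction, assuming the backward of npu_batch_matmul matches that of torch.matmul; it shrinks the 1220 points to 5 to keep memory small.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pts = 5  # stands in for the 1220 points of the example, to keep memory small
proj = rng.standard_normal((6, 6, 1, 1, 4, 4)).astype(np.float32)
pts = rng.standard_normal((6, 1, n_pts, 13, 4, 1)).astype(np.float32)

out = proj @ pts                # batch dimensions broadcast to (6, 6, n_pts, 13)
grad_out = np.ones_like(out)    # same role as torch.ones_like(result)

# Unreduced gradients (dA = G @ B^T, dB = A^T @ G) have the broadcast batch shape...
grad_proj_full = grad_out @ np.swapaxes(pts, -1, -2)  # (6, 6, n_pts, 13, 4, 4)
grad_pts_full = np.swapaxes(proj, -1, -2) @ grad_out  # (6, 6, n_pts, 13, 4, 1)
# ...and are summed over the axes that were broadcast, restoring the input shapes.
grad_proj = grad_proj_full.sum(axis=(2, 3), keepdims=True)
grad_pts = grad_pts_full.sum(axis=1, keepdims=True)

print(out.shape, grad_proj.shape, grad_pts.shape)
# (6, 6, 5, 13, 4, 1) (6, 6, 1, 1, 4, 4) (6, 1, 5, 13, 4, 1)
```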

4-D inputs

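import torch
import torch_npu
import mx_driving

# projection_mat: (6, 1220, 4, 4); pts_extend: (6, 1220, 4, 1)
# The batch dimensions already match, so no reshaping is needed.
projection_mat = torch.randn(6, 1220, 4, 4).npu()
pts_extend = torch.randn(6, 1220, 4, 1).npu()
projection_mat.requires_grad = True
pts_extend.requires_grad = True
# result: (6, 1220, 4, 1)
result = mx_driving.npu_batch_matmul(projection_mat, pts_extend)
# Backpropagate an all-ones gradient
grad = torch.ones_like(result)
result.backward(grad)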
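
For the 4-D case, the gradients produced by backward with an all-ones upstream gradient have a simple closed form. The NumPy sketch that follows assumes the backward of npu_batch_matmul uses the standard matrix-multiplication gradients (as torch.matmul does). It is intended for checking projection_mat.grad and pts_extend.grad on the CPU, not as the operator's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
proj = rng.standard_normal((6, 1220, 4, 4)).astype(np.float32)
pts = rng.standard_normal((6, 1220, 4, 1)).astype(np.float32)

out = proj @ pts               # (6, 1220, 4, 1); no broadcasting needed
grad_out = np.ones_like(out)   # same role as torch.ones_like(result)

# Standard backward for out = A @ B with upstream gradient G:
#   dA = G @ B^T,   dB = A^T @ G
grad_proj = grad_out @ np.swapaxes(pts, -1, -2)  # (6, 1220, 4, 4)
grad_pts = np.swapaxes(proj, -1, -2) @ grad_out  # (6, 1220, 4, 1)

# With G = ones, every row of dA equals B^T, and dB holds the column sums of A.
assert np.allclose(grad_proj, np.broadcast_to(np.swapaxes(pts, -1, -2), grad_proj.shape))
assert np.allclose(grad_pts[..., 0], proj.sum(axis=-2), atol=1e-5)
```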