three_interpolate

接口原型

mx_driving.three_interpolate(features: torch.Tensor, indices: torch.Tensor, weight: torch.Tensor) -> torch.Tensor

兼容:

mx_driving.common.three_interpolate(features: torch.Tensor, indices: torch.Tensor, weight: torch.Tensor) -> torch.Tensor

功能描述

对三维数据进行加权最近邻线性插值处理

参数说明

  • features(Tensor):需要被插值的特征,数据类型为float32|float16,维度为(B, C, M)。
  • indices(Tensor):获取目标特征计算的索引,数据类型为int32,维度为(B, N, 3),
    • indices的元素值需小于features的第三维度,即值在[0, M)。
  • weight(Tensor):获取目标特征计算的权重,数据类型为float32|float16,维度为(B, N, 3)。
    • weight数据类型与features须一致。
  • featuresindicesweight三个参数的每个维度须小于10000。
  • featuresindicesweight三个参数的大小请勿超过2^24。

返回值

  • output(Tensor):目标特征张量,数据类型为float32|float16,维度为(B, C, N)。

支持的型号

  • Atlas A2 训练系列产品

调用示例

import torch, torch_npu
from mx_driving import three_interpolate


features = torch.tensor(
            [[[2.4350, 4.7516, 4.4995, 2.4350, 2.4350, 2.4350],
            [3.1236, 2.6278, 3.0447, 3.1236, 3.1236, 3.1236],
            [2.6732, 2.8677, 2.6436, 2.6732, 2.6732, 2.6732],
            [0.0124, 7.0150, 7.0199, 0.0124, 0.0124, 0.0124],
            [0.3207, 0.0000, 0.3411, 0.3207, 0.3207, 0.3207]],
            [[0.0000, 0.9544, 2.4532, 0.0000, 0.0000, 0.0000],
            [0.5346, 1.9176, 1.4715, 0.5346, 0.5346, 0.5346],
            [0.0000, 0.2744, 2.0842, 0.0000, 0.0000, 0.0000],
            [0.3414, 1.5063, 1.6209, 0.3414, 0.3414, 0.3414],
            [0.5814, 0.0103, 0.0000, 0.5814, 0.5814, 0.5814]]],
            ).npu()
idx = torch.tensor(
            [[[0, 1, 2], [2, 3, 4], [2, 3, 4], [0, 1, 2], [0, 1, 2], [0, 1, 3]],
            [[0, 2, 3], [1, 3, 4], [2, 1, 4], [0, 2, 4], [0, 2, 4], [0, 1, 2]]],
            ).int().npu()
weight = torch.tensor(
            [[[3.3333e-01, 3.3333e-01, 3.3333e-01],
              [1.0000e+00, 5.8155e-08, 2.2373e-08],
              [1.0000e+00, 1.7737e-08, 1.7356e-08],
              [3.3333e-01, 3.3333e-01, 3.3333e-01],
              [3.3333e-01, 3.3333e-01, 3.3333e-01],
              [3.3333e-01, 3.3333e-01, 3.3333e-01]],
             [[3.3333e-01, 3.3333e-01, 3.3333e-01],
              [1.0000e+00, 1.3651e-08, 7.7312e-09],
              [1.0000e+00, 1.7148e-08, 1.4070e-08],
              [3.3333e-01, 3.3333e-01, 3.3333e-01],
              [3.3333e-01, 3.3333e-01, 3.3333e-01],
              [3.3333e-01, 3.3333e-01, 3.3333e-01]]],
            ).npu()

features.requires_grad = True
output = three_interpolate(features, idx, weight)
grad_out_tensor = torch.ones_like(output)
output.backward(grad_out_tensor)