three_nn
接口原型
mx_driving.three_nn(Tensor target, Tensor source) -> (Tensor dist, Tensor idx)
功能描述
对target中的每个点找到source中对应batch中的距离最近的3个点,并且返回此3个点的距离和索引值。
参数说明
target(Tensor):点数据,表示(x, y, z)三维坐标,数据类型为float32/float16。shape为[B, npoint, 3]。其中B为batch size,npoint为点的数量。source(Tensor):点数据,表示(x, y, z)三维坐标,数据类型为float32/float16。shape为[B, N, 3]。其中B为batch size,N为点的数量。
返回值
dist(Tensor):采样后的距离数据,数据类型为float32/float16。shape为[B, npoint, 3]。idx(Tensor):采样后的索引数据,数据类型为int32/int64。shape为[B, npoint, 3]。
算子约束
- source和target的shape必须是3维,且source和target的shape的dim的第2维必须是3。
- 距离相同时,排序为不稳定排序;此时存在距离精度符合要求但索引精度错误问题,导致与竞品无法完全对齐。
- 性能在N值较大的场景下较优。
支持的型号
- Atlas A2 训练系列产品
调用示例
import torch, torch_npu
from mx_driving import three_nn
source = torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]], dtype=torch.float32).npu()
target = torch.tensor([[[1, 2, 3]], [[1, 2, 3]]], dtype=torch.float32).npu()
dist, idx = three_nn(target, source)