roipoint_pool3d
接口原型
mx_driving.roipoint_pool3d(int num_sampled_points, Tensor points, Tensor point_features, Tensor boxes3d) -> (Tensor pooled_features, Tensor pooled_empty_flag)
功能描述
对每个3D方案的几何特定特征进行编码。
参数说明
num_sampled_points(int):特征点的数量,正整数。points(Tensor):点张量,数据类型为float32, float16。shape 为[B, N, 3]。3分别代表x, y, z。point_features(Tensor):点特征张量,数据类型为float32, float16。shape 为[B, N, C]。C分别代表x, y, z。boxes3d(Tensor):框张量,数据类型为float32, float16。shape 为[B, M, 7]。7分别代表x, y, z, x_size, y_size, z_size, rz。
返回值
pooled_features(Tensor):点在框内的特征张量,数据类型为float32, float16。shape 为[B, M, num, 3+C]。pooled_empty_flag(Tensor):所有点不在框内的空标记张量,数据类型为int32。shape 为[B, M]。
约束说明
points、point_features和boxes3d的数据类型必须相同,以及B也必须相同。num_sampled_points必须小于等于N。- 数据类型为
float32时,建议B小于96、N小于等于2048、M小于等于48、num_sampled_points小于等于48,C小于等于8,个别shape值略微超过建议值无影响,但所有shape值均大于建议值时,算子执行会发生错误。 - 数据类型为
float16时,建议B小于96、N小于等于3192、M小于等于60、num_sampled_points小于等于60,C小于等于8,个别shape值略微超过建议值无影响,但所有shape值均大于建议值时,算子执行会发生错误。 N/M的值越大,性能劣化越严重,建议N小于M的六百倍,否则性能可能会低于0.1x A100。
支持的型号
- Atlas A2 训练系列产品
调用示例
import torch, torch_npu
from mx_driving import roipoint_pool3d
num_sampled_points = 1
points = torch.tensor([[[1, 2, 3]]], dtype=torch.float).npu()
point_features = points.clone()
boxes3d = torch.tensor([[[1, 2, 3, 4, 5, 6, 1]]], dtype=torch.float).npu()
pooled_features, pooled_empty_flag = roipoint_pool3d(num_sampled_points, points, point_features, boxes3d)