deformable_aggregation

接口原型

mx_driving.deformable_aggregation(Tensor feature_maps, Tensor spatial_shape, Tensor scale_start_index, Tensor sample_locations, Tensor weight) -> Tensor

兼容:

mx_driving.fused.npu_deformable_aggregation(Tensor feature_maps, Tensor spatial_shape, Tensor scale_start_index, Tensor sample_locations, Tensor weight) -> Tensor
mx_driving.npu_deformable_aggregation(Tensor feature_maps, Tensor spatial_shape, Tensor scale_start_index, Tensor sample_locations, Tensor weight) -> Tensor

功能描述

可变形聚合,对于每个锚点实例,对多个关键点的多时间戳、视图、缩放特征进行稀疏采样后分层融合为实例特征,实现精确的锚点细化。

参数说明

  • feature_maps(Tensor):特征张量,数据类型为float32。shape为[bs, num_feat, c]。其中bs为batch size,num_feat为特征图的大小,c为特征图的维度。
  • spatial_shape(Tensor):特征图的形状,数据类型为int32。shape为[cam, scale, 2]。其中cam为相机数量,其中scale为每个相机的特征图数量,2分别代表H, W。
  • scale_start_index(Tensor):每个特征图的偏移位置张量,数据类型为int32。shape为[cam, scale],其中cam为相机数量,其中scale为每个相机的特征图数量。
  • sample_locations(Tensor):位置张量,数据类型为float32。shape为[bs, anchor, pts, cam, 2]。其中bs为batch size,anchor为锚点数量,pts为采样点的数量,cam为相机的数量,2分别代表y, x。
  • weight(Tensor):权重张量,数据类型为float32。shape为[bs, anchor, pts, cam, scale, group]。其中bs为batch size,anchor为锚点数量,pts为采样点的数量,cam为相机的数量,scale每个相机的特征图数量,group为分组数。

返回值

  • output(Tensor):输出结果张量,数据类型为float32。shape为[bs, anchor, c]

支持的型号

  • Atlas A2 训练系列产品

约束说明

  • bs <= 128
  • num_feat的值为spatial_shape中每幅图的特征数量之和
  • c <= 256,且c / group为8的整数倍
  • cam <= 6
  • scale <= 4
  • anchor <= 2048
  • pts <= 2048
  • group = 8
  • sample_locations的值在[0, 1]之间。
  • 每个输入tensor的数据量不超过1.5亿。
  • 反向具有相同约束。

调用示例

import torch, torch_npu
from mx_driving import deformable_aggregation

bs, num_feat, c, cam, anchor, pts, scale, group = 1, 2816, 256, 1, 10, 2000, 1, 8

feature_maps = torch.ones_like(torch.randn(bs,num_feat ,c))
spatial_shape = torch.tensor([[[32, 88]]], dtype=torch.int32)
scale_start_index = torch.tensor([[0]], dtype=torch.int32)
sampling_location = torch.rand(bs, anchor, pts, cam, 2)
weights = torch.randn(bs, anchor, pts, cam, scale, group)
feature_maps.requires_grad = True
out = deformable_aggregation(feature_maps.npu(), spatial_shape.npu(), scale_start_index.npu(), sampling_location.npu(), weights.npu())
grad_out_tensor = torch.ones_like(out)
out.backward(grad_out_tensor)