multi_scale_deformable_attn(MultiScaleDeformableAttnFunction.Apply)

接口原型

mx_driving.multi_scale_deformable_attn(Tensor value, Tensor value_spatial_shapes, Tensor value_level_start_index, Tensor sampling_locations, Tensor attention_weights) -> Tensor

功能描述

多尺度可变形注意力机制, 将多个视角的特征图进行融合。

参数说明

  • value(Tensor):特征张量,数据类型为float32, float16。shape为[bs, num_keys, num_heads, embed_dims]。其中bs为batch size,num_keys为特征图的大小,num_heads为头的数量,embed_dims为特征图的维度,其中embed_dims需要为8的倍数。
  • value_spatial_shapes(Tensor):特征图的形状,数据类型为int32, int64。shape为[num_levels, 2]。其中num_levels为特征图的数量,2分别代表H, W
  • value_level_start_index(Tensor):偏移量张量,数据类型为int32, int64。shape为[num_levels]
  • sampling_locations(Tensor):位置张量,数据类型为float32, float16。shape为[bs, num_queries, num_heads, num_levels, num_points, 2]。其中bs为batch size,num_queries为查询的数量,num_heads为头的数量,num_levels为特征图的数量,num_points为采样点的数量,2分别代表x, y
  • attention_weights(Tensor):权重张量,数据类型为float32, float16。shape为[bs, num_queries, num_heads, num_levels, num_points]。其中bs为batch size,num_queries为查询的数量,num_heads为头的数量,num_levels为特征图的数量,num_points为采样点的数量。

返回值

  • output(Tensor):融合后的特征张量,数据类型为float32, float16。shape为[bs, num_queries, num_heads*embed_dims]

支持的型号

  • Atlas A2 训练系列产品

约束说明

  • 当前版本只支持num_points * num_levels ≤ 64,num_heads ≤ 8,embed_dims ≤ 256。

调用示例

import torch, torch_npu
from mx_driving import multi_scale_deformable_attn
bs, num_levels, num_heads, num_points, num_queries, embed_dims = 1, 1, 4, 8, 16, 32

shapes = torch.as_tensor([(100, 100)], dtype=torch.long)
num_keys = sum((H * W).item() for H, W in shapes)

value = torch.rand(bs, num_keys, num_heads, embed_dims) * 0.01
sampling_locations = torch.ones(bs, num_queries, num_heads, num_levels, num_points, 2) * 0.005
attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points) + 1e-5
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))

value.requires_grad_()
sampling_locations.requires_grad_()
attention_weights.requires_grad_()

out = multi_scale_deformable_attn(value.npu(), shapes.npu(), level_start_index.npu(), sampling_locations.npu(), attention_weights.npu())
out.backward(torch.ones_like(out))