@@ -520,9 +520,6 @@ class MultiheadAttention(BaseModule):
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
- else:
- warnings.warn(f'position encoding of key is'
- f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
@@ -544,7 +541,8 @@ class MultiheadAttention(BaseModule):
key=key,
value=value,
attn_mask=attn_mask,
- key_padding_mask=key_padding_mask)[0]
+ key_padding_mask=key_padding_mask,
+ need_weights=False)[0]
if self.batch_first:
out = out.transpose(0, 1)
@@ -107,7 +107,7 @@ class MultiScaleDeformableAttnFunction(Function):
def multi_scale_deformable_attn_pytorch(
- value: torch.Tensor, value_spatial_shapes: torch.Tensor,
+ value: torch.Tensor, value_spatial_shapes: list,
sampling_locations: torch.Tensor,
attention_weights: torch.Tensor) -> torch.Tensor:
"""CPU version of multi-scale deformable attention.
@@ -280,7 +280,8 @@ class MultiScaleDeformableAttention(BaseModule):
query_pos: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
reference_points: Optional[torch.Tensor] = None,
- spatial_shapes: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[list] = None,
+ spatial_shapes_tensor: Optional[torch.Tensor] = None,
level_start_index: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""Forward Function of MultiScaleDeformAttention.
@@ -332,7 +333,6 @@ class MultiScaleDeformableAttention(BaseModule):
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
- assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
@@ -350,7 +350,7 @@ class MultiScaleDeformableAttention(BaseModule):
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
- [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+ [spatial_shapes_tensor[..., 1], spatial_shapes_tensor[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]