diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index f83b9a69..4952e07a 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -520,9 +520,6 @@ class MultiheadAttention(BaseModule):
                 # use query_pos if key_pos is not available
                 if query_pos.shape == key.shape:
                     key_pos = query_pos
-                else:
-                    warnings.warn(f'position encoding of key is'
-                                  f'missing in {self.__class__.__name__}.')
         if query_pos is not None:
             query = query + query_pos
         if key_pos is not None:
@@ -544,7 +541,8 @@ class MultiheadAttention(BaseModule):
             key=key,
             value=value,
             attn_mask=attn_mask,
-            key_padding_mask=key_padding_mask)[0]
+            key_padding_mask=key_padding_mask,
+            need_weights=False)[0]
 
         if self.batch_first:
             out = out.transpose(0, 1)
diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index 7459263c..e7cfe9ab 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -107,7 +107,7 @@ class MultiScaleDeformableAttnFunction(Function):
 
 
 def multi_scale_deformable_attn_pytorch(
-        value: torch.Tensor, value_spatial_shapes: torch.Tensor,
+        value: torch.Tensor, value_spatial_shapes: list,
         sampling_locations: torch.Tensor,
         attention_weights: torch.Tensor) -> torch.Tensor:
     """CPU version of multi-scale deformable attention.
@@ -280,7 +280,8 @@ class MultiScaleDeformableAttention(BaseModule):
                 query_pos: Optional[torch.Tensor] = None,
                 key_padding_mask: Optional[torch.Tensor] = None,
                 reference_points: Optional[torch.Tensor] = None,
-                spatial_shapes: Optional[torch.Tensor] = None,
+                spatial_shapes: Optional[list] = None,
+                spatial_shapes_tensor: Optional[torch.Tensor] = None,
                 level_start_index: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
         """Forward Function of MultiScaleDeformAttention.
@@ -332,7 +333,6 @@ class MultiScaleDeformableAttention(BaseModule):
 
         bs, num_query, _ = query.shape
         bs, num_value, _ = value.shape
-        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
 
         value = self.value_proj(value)
         if key_padding_mask is not None:
@@ -350,7 +350,7 @@ class MultiScaleDeformableAttention(BaseModule):
                                                    self.num_points)
         if reference_points.shape[-1] == 2:
             offset_normalizer = torch.stack(
-                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
+                [spatial_shapes_tensor[..., 1], spatial_shapes_tensor[..., 0]], -1)
             sampling_locations = reference_points[:, :, None, :, None, :] \
                 + sampling_offsets \
                 / offset_normalizer[None, None, None, :, None, :]