diff --git a/mmdet/models/dense_heads/grounding_dino_head.py b/mmdet/models/dense_heads/grounding_dino_head.py
index 80883225..9e9af938 100644
--- a/mmdet/models/dense_heads/grounding_dino_head.py
+++ b/mmdet/models/dense_heads/grounding_dino_head.py
@@ -103,6 +103,7 @@ class GroundingDINOHead(DINOHead):
         self.contrastive_cfg = contrastive_cfg
         self.max_text_len = contrastive_cfg.get('max_text_len', 256)
         super().__init__(**kwargs)
+        self.max_per_img = self.test_cfg.get('max_per_img')
 
     def _init_layers(self) -> None:
         """Initialize classification branch and regression branch of head."""
@@ -413,15 +414,13 @@ class GroundingDINOHead(DINOHead):
                 - bboxes (Tensor): Has a shape (num_instances, 4),
                   the last dimension 4 arrange as (x1, y1, x2, y2).
         """
-        assert len(cls_score) == len(bbox_pred)  # num_queries
-        max_per_img = self.test_cfg.get('max_per_img', len(cls_score))
         img_shape = img_meta['img_shape']
 
         if token_positive_maps is not None:
             cls_score = convert_grounding_to_cls_scores(
                 logits=cls_score.sigmoid()[None],
                 positive_maps=[token_positive_maps])[0]
-            scores, indexes = cls_score.view(-1).topk(max_per_img)
+            scores, indexes = cls_score.view(-1).topk(self.max_per_img)
             num_classes = cls_score.shape[-1]
             det_labels = indexes % num_classes
             bbox_index = indexes // num_classes
@@ -429,7 +428,7 @@ class GroundingDINOHead(DINOHead):
         else:
             cls_score = cls_score.sigmoid()
             scores, _ = cls_score.max(-1)
-            scores, indexes = scores.topk(max_per_img)
+            scores, indexes = scores.topk(self.max_per_img)
             bbox_pred = bbox_pred[indexes]
             det_labels = scores.new_zeros(scores.shape, dtype=torch.long)
 
diff --git a/mmdet/models/detectors/deformable_detr.py b/mmdet/models/detectors/deformable_detr.py
index 0eb5cd2f..a9cf3ffc 100644
--- a/mmdet/models/detectors/deformable_detr.py
+++ b/mmdet/models/detectors/deformable_detr.py
@@ -225,11 +225,15 @@ class DeformableDETR(DetectionTransformer):
             valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats),
                                                   2)
 
+        spatial_shapes_tensor = spatial_shapes
+        spatial_shapes = spatial_shapes.tolist()
+        lvl_pos_embed_flatten = lvl_pos_embed_flatten.half()
         encoder_inputs_dict = dict(
             feat=feat_flatten,
             feat_mask=mask_flatten,
             feat_pos=lvl_pos_embed_flatten,
             spatial_shapes=spatial_shapes,
+            spatial_shapes_tensor=spatial_shapes_tensor,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios)
         decoder_inputs_dict = dict(
diff --git a/mmdet/models/detectors/glip.py b/mmdet/models/detectors/glip.py
index 45cfe7d3..f14e6003 100644
--- a/mmdet/models/detectors/glip.py
+++ b/mmdet/models/detectors/glip.py
@@ -25,13 +25,25 @@ def find_noun_phrases(caption: str) -> list:
         >>> caption = 'There is two cat and a remote in the picture'
         >>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture']
     """
+    # try:
+    #     import nltk
+    #     nltk.download('punkt', download_dir='~/nltk_data')
+    #     nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
+    # except ImportError:
+    #     raise RuntimeError('nltk is not installed, please install it by: '
+    #                        'pip install nltk.')
+        
+    import nltk 
     try:
-        import nltk
+        nltk.data.find("tokenizers/punkt")
+    except LookupError:
         nltk.download('punkt', download_dir='~/nltk_data')
+
+    try:
+        nltk.data.find("taggers/averaged_perceptron_tagger")
+    except LookupError:
         nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
-    except ImportError:
-        raise RuntimeError('nltk is not installed, please install it by: '
-                           'pip install nltk.')
+
 
     caption = caption.lower()
     tokens = nltk.word_tokenize(caption)
diff --git a/mmdet/models/detectors/grounding_dino.py b/mmdet/models/detectors/grounding_dino.py
index b1ab7c2d..d4a15e98 100644
--- a/mmdet/models/detectors/grounding_dino.py
+++ b/mmdet/models/detectors/grounding_dino.py
@@ -7,17 +7,20 @@ from typing import Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 from mmengine.runner.amp import autocast
+from mmengine.structures import InstanceData
 from torch import Tensor
 
 from mmdet.registry import MODELS
 from mmdet.structures import OptSampleList, SampleList
 from mmdet.utils import ConfigType
+from mmdet.utils import InstanceList
 from ..layers import SinePositionalEncoding
 from ..layers.transformer.grounding_dino_layers import (
     GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder)
 from .dino import DINO
 from .glip import (create_positive_map, create_positive_map_label_to_token,
                    run_ner)
+from ..language_models.bert import generate_masks_with_special_tokens_and_transfer_map
 
 
 def clean_label_name(name: str) -> str:
@@ -321,22 +324,29 @@ class GroundingDINO(DINO):
         return head_inputs_dict
 
     def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
-                        feat_pos: Tensor, spatial_shapes: Tensor,
+                        feat_pos: Tensor, spatial_shapes: list,
+                        spatial_shapes_tensor: Tensor,
                         level_start_index: Tensor, valid_ratios: Tensor,
                         text_dict: Dict) -> Dict:
         text_token_mask = text_dict['text_token_mask']
+        reference_points = self.encoder.get_encoder_reference_points(
+            spatial_shapes_tensor, valid_ratios, device=feat.device
+        )
+        reference_points = reference_points.half()
         memory, memory_text = self.encoder(
             query=feat,
             query_pos=feat_pos,
             key_padding_mask=feat_mask,  # for self_attn
             spatial_shapes=spatial_shapes,
+            spatial_shapes_tensor=spatial_shapes_tensor,
             level_start_index=level_start_index,
             valid_ratios=valid_ratios,
             # for text encoder
             memory_text=text_dict['embedded'],
             text_attention_mask=~text_token_mask,
             position_ids=text_dict['position_ids'],
-            text_self_attention_masks=text_dict['masks'])
+            text_self_attention_masks=text_dict['masks'],
+            reference_points=reference_points)
         encoder_outputs_dict = dict(
             memory=memory,
             memory_mask=feat_mask,
@@ -358,6 +368,7 @@ class GroundingDINO(DINO):
 
         output_memory, output_proposals = self.gen_encoder_output_proposals(
             memory, memory_mask, spatial_shapes)
+        output_proposals = output_proposals.half()
 
         enc_outputs_class = self.bbox_head.cls_branches[
             self.decoder.num_layers](output_memory, memory_text,
@@ -416,6 +427,20 @@ class GroundingDINO(DINO):
         head_inputs_dict['text_token_mask'] = text_token_mask
         return decoder_inputs_dict, head_inputs_dict
 
+    def tokenizes(self, texts):
+        device = next(self.language_model.language_backbone.parameters()).device
+        tokenized = self.language_model.tokenizer.batch_encode_plus(
+            texts,
+            max_length=self.language_model.max_tokens,
+            padding='max_length' if self.language_model.pad_to_max else 'longest',
+            return_special_tokens_mask=True,
+            return_tensors='pt',
+            truncation=True).to(device)
+        attention_mask, position_ids = \
+                generate_masks_with_special_tokens_and_transfer_map(
+                    tokenized, self.language_model.special_tokens)
+        return tokenized, attention_mask, position_ids
+
     def loss(self, batch_inputs: Tensor,
              batch_data_samples: SampleList) -> Union[dict, list]:
         text_prompts = [
@@ -477,7 +502,8 @@ class GroundingDINO(DINO):
                     positive_maps.append(positive_map)
                     new_text_prompts.append(caption_string)
 
-        text_dict = self.language_model(new_text_prompts)
+        tokenized, attention_mask, position_ids = self.tokenizes(new_text_prompts)
+        text_dict = self.language_model(tokenized, attention_mask, position_ids)
         if self.text_feat_map is not None:
             text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])
 
@@ -501,7 +527,8 @@ class GroundingDINO(DINO):
             **head_inputs_dict, batch_data_samples=batch_data_samples)
         return losses
 
-    def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
+    def predict(self, batch_inputs, batch_data_samples, tokenized=None, attention_mask=None, position_ids=None, isvisualize=True, rescale: bool = True):
+        batch_inputs = batch_inputs.half()
         text_prompts = []
         enhanced_text_prompts = []
         tokens_positives = []
@@ -553,7 +580,8 @@ class GroundingDINO(DINO):
             for b in range(len(text_prompts[0])):
                 text_prompts_once = [text_prompts[0][b]]
                 token_positive_maps_once = token_positive_maps[0][b]
-                text_dict = self.language_model(text_prompts_once)
+                tokenized, attention_mask, position_ids = self.tokenizes(text_prompts_once)
+                text_dict = self.language_model(tokenized, attention_mask, position_ids)
                 # text feature map layer
                 if self.text_feat_map is not None:
                     text_dict['embedded'] = self.text_feat_map(
@@ -577,7 +605,9 @@ class GroundingDINO(DINO):
             is_rec_tasks = [False] * len(results_list)
         else:
             # extract text feats
-            text_dict = self.language_model(list(text_prompts))
+            if tokenized is None:
+                tokenized, attention_mask, position_ids = self.tokenizes(list(text_prompts))
+            text_dict = self.language_model(tokenized, attention_mask, position_ids)
             # text feature map layer
             if self.text_feat_map is not None:
                 text_dict['embedded'] = self.text_feat_map(
@@ -598,24 +628,27 @@ class GroundingDINO(DINO):
                 rescale=rescale,
                 batch_data_samples=batch_data_samples)
 
-        for data_sample, pred_instances, entity, is_rec_task in zip(
-                batch_data_samples, results_list, entities, is_rec_tasks):
-            if len(pred_instances) > 0:
-                label_names = []
-                for labels in pred_instances.labels:
-                    if is_rec_task:
-                        label_names.append(entity)
-                        continue
-                    if labels >= len(entity):
-                        warnings.warn(
-                            'The unexpected output indicates an issue with '
-                            'named entity recognition. You can try '
-                            'setting custom_entities=True and running '
-                            'again to see if it helps.')
-                        label_names.append('unobject')
-                    else:
-                        label_names.append(entity[labels])
-                # for visualization
-                pred_instances.label_names = label_names
-            data_sample.pred_instances = pred_instances
-        return batch_data_samples
+        if isvisualize:
+            for data_sample, pred_instances, entity, is_rec_task in zip(
+                    batch_data_samples, results_list, entities, is_rec_tasks):
+                if len(pred_instances) > 0:
+                    label_names = []
+                    for labels in pred_instances.labels:
+                        if is_rec_task:
+                            label_names.append(entity)
+                            continue
+                        if labels >= len(entity):
+                            warnings.warn(
+                                'The unexpected output indicates an issue with '
+                                'named entity recognition. You can try '
+                                'setting custom_entities=True and running '
+                                'again to see if it helps.')
+                            label_names.append('unobject')
+                        else:
+                            label_names.append(entity[labels])
+                    # for visualization
+                    pred_instances.label_names = label_names
+                data_sample.pred_instances = pred_instances
+            return batch_data_samples
+        else:
+            return batch_data_samples, results_list
diff --git a/mmdet/models/language_models/bert.py b/mmdet/models/language_models/bert.py
index efb0f46b..a99e6fcd 100644
--- a/mmdet/models/language_models/bert.py
+++ b/mmdet/models/language_models/bert.py
@@ -134,21 +134,10 @@ class BertModel(BaseModel):
             self.special_tokens = self.tokenizer.convert_tokens_to_ids(
                 special_tokens_list)
 
-    def forward(self, captions: Sequence[str], **kwargs) -> dict:
+    def forward(self, tokenized, attention_mask, position_ids, **kwargs) -> dict:
         """Forward function."""
-        device = next(self.language_backbone.parameters()).device
-        tokenized = self.tokenizer.batch_encode_plus(
-            captions,
-            max_length=self.max_tokens,
-            padding='max_length' if self.pad_to_max else 'longest',
-            return_special_tokens_mask=True,
-            return_tensors='pt',
-            truncation=True).to(device)
         input_ids = tokenized.input_ids
         if self.use_sub_sentence_represent:
-            attention_mask, position_ids = \
-                generate_masks_with_special_tokens_and_transfer_map(
-                    tokenized, self.special_tokens)
             token_type_ids = tokenized['token_type_ids']
 
         else:
diff --git a/mmdet/models/layers/transformer/dino_layers.py b/mmdet/models/layers/transformer/dino_layers.py
index 64610d0a..e97733f7 100644
--- a/mmdet/models/layers/transformer/dino_layers.py
+++ b/mmdet/models/layers/transformer/dino_layers.py
@@ -82,12 +82,13 @@ class DinoTransformerDecoder(DeformableDetrTransformerDecoder):
 
             query_sine_embed = coordinate_to_encoding(
                 reference_points_input[:, :, 0, :])
+            query_sine_embed = query_sine_embed.half()
             query_pos = self.ref_point_head(query_sine_embed)
 
             query = layer(
-                query,
+                query.contiguous(),
                 query_pos=query_pos,
-                value=value,
+                value=value.contiguous(),
                 key_padding_mask=key_padding_mask,
                 self_attn_mask=self_attn_mask,
                 spatial_shapes=spatial_shapes,
@@ -98,7 +99,6 @@ class DinoTransformerDecoder(DeformableDetrTransformerDecoder):
 
             if reg_branches is not None:
                 tmp = reg_branches[lid](query)
-                assert reference_points.shape[-1] == 4
                 new_reference_points = tmp + inverse_sigmoid(
                     reference_points, eps=1e-3)
                 new_reference_points = new_reference_points.sigmoid()
diff --git a/mmdet/models/layers/transformer/grounding_dino_layers.py b/mmdet/models/layers/transformer/grounding_dino_layers.py
index 3c285768..e391235c 100644
--- a/mmdet/models/layers/transformer/grounding_dino_layers.py
+++ b/mmdet/models/layers/transformer/grounding_dino_layers.py
@@ -155,29 +155,21 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
             for _ in range(self.num_layers)
         ])
         self.embed_dims = self.layers[0].embed_dims
-        if self.num_cp > 0:
-            if checkpoint_wrapper is None:
-                raise NotImplementedError(
-                    'If you want to reduce GPU memory usage, \
-                    please install fairscale by executing the \
-                    following command: pip install fairscale.')
-            for i in range(self.num_cp):
-                self.layers[i] = checkpoint_wrapper(self.layers[i])
-                self.fusion_layers[i] = checkpoint_wrapper(
-                    self.fusion_layers[i])
 
     def forward(self,
                 query: Tensor,
                 query_pos: Tensor,
                 key_padding_mask: Tensor,
-                spatial_shapes: Tensor,
+                spatial_shapes: list,
+                spatial_shapes_tensor: Tensor,
                 level_start_index: Tensor,
                 valid_ratios: Tensor,
                 memory_text: Tensor = None,
                 text_attention_mask: Tensor = None,
                 pos_text: Tensor = None,
                 text_self_attention_masks: Tensor = None,
-                position_ids: Tensor = None):
+                position_ids: Tensor = None,
+                reference_points: Tensor = None):
         """Forward function of Transformer encoder.
 
         Args:
@@ -206,8 +198,6 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
                 Defaults to None.
         """
         output = query
-        reference_points = self.get_encoder_reference_points(
-            spatial_shapes, valid_ratios, device=query.device)
         if self.text_layers:
             # generate pos_text
             bs, n_text, _ = memory_text.shape
@@ -223,6 +213,7 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
                     position_ids[..., None],
                     num_pos_feats=256,
                     exchange_xy=False)
+        pos_text = pos_text.half()
 
         # main process
         for layer_id, layer in enumerate(self.layers):
@@ -248,6 +239,7 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
                 query_pos=query_pos,
                 reference_points=reference_points,
                 spatial_shapes=spatial_shapes,
+                spatial_shapes_tensor=spatial_shapes_tensor,
                 level_start_index=level_start_index,
                 key_padding_mask=key_padding_mask)
         return output, memory_text
diff --git a/mmdet/visualization/local_visualizer.py b/mmdet/visualization/local_visualizer.py
index cc6521c5..588bec81 100644
--- a/mmdet/visualization/local_visualizer.py
+++ b/mmdet/visualization/local_visualizer.py
@@ -124,7 +124,7 @@ class DetLocalVisualizer(Visualizer):
         self.set_image(image)
 
         if 'bboxes' in instances and instances.bboxes.sum() > 0:
-            bboxes = instances.bboxes
+            bboxes = instances.bboxes.to(torch.float32)
             labels = instances.labels
 
             max_label = int(max(labels) if len(labels) > 0 else 0)