@@ -103,6 +103,7 @@ class GroundingDINOHead(DINOHead):
self.contrastive_cfg = contrastive_cfg
self.max_text_len = contrastive_cfg.get('max_text_len', 256)
super().__init__(**kwargs)
+ self.max_per_img = self.test_cfg.get('max_per_img')
def _init_layers(self) -> None:
"""Initialize classification branch and regression branch of head."""
@@ -413,15 +414,13 @@ class GroundingDINOHead(DINOHead):
- bboxes (Tensor): Has a shape (num_instances, 4),
the last dimension 4 arrange as (x1, y1, x2, y2).
"""
- assert len(cls_score) == len(bbox_pred) # num_queries
- max_per_img = self.test_cfg.get('max_per_img', len(cls_score))
img_shape = img_meta['img_shape']
if token_positive_maps is not None:
cls_score = convert_grounding_to_cls_scores(
logits=cls_score.sigmoid()[None],
positive_maps=[token_positive_maps])[0]
- scores, indexes = cls_score.view(-1).topk(max_per_img)
+ scores, indexes = cls_score.view(-1).topk(self.max_per_img)
num_classes = cls_score.shape[-1]
det_labels = indexes % num_classes
bbox_index = indexes // num_classes
@@ -429,7 +428,7 @@ class GroundingDINOHead(DINOHead):
else:
cls_score = cls_score.sigmoid()
scores, _ = cls_score.max(-1)
- scores, indexes = scores.topk(max_per_img)
+ scores, indexes = scores.topk(self.max_per_img)
bbox_pred = bbox_pred[indexes]
det_labels = scores.new_zeros(scores.shape, dtype=torch.long)
@@ -225,11 +225,15 @@ class DeformableDETR(DetectionTransformer):
valid_ratios = mlvl_feats[0].new_ones(batch_size, len(mlvl_feats),
2)
+ spatial_shapes_tensor = spatial_shapes
+ spatial_shapes = spatial_shapes.tolist()
+ lvl_pos_embed_flatten = lvl_pos_embed_flatten.half()
encoder_inputs_dict = dict(
feat=feat_flatten,
feat_mask=mask_flatten,
feat_pos=lvl_pos_embed_flatten,
spatial_shapes=spatial_shapes,
+ spatial_shapes_tensor=spatial_shapes_tensor,
level_start_index=level_start_index,
valid_ratios=valid_ratios)
decoder_inputs_dict = dict(
@@ -25,13 +25,25 @@ def find_noun_phrases(caption: str) -> list:
>>> caption = 'There is two cat and a remote in the picture'
>>> find_noun_phrases(caption) # ['cat', 'a remote', 'the picture']
"""
+ # try:
+ # import nltk
+ # nltk.download('punkt', download_dir='~/nltk_data')
+ # nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
+ # except ImportError:
+ # raise RuntimeError('nltk is not installed, please install it by: '
+ # 'pip install nltk.')
+
+ import nltk
try:
- import nltk
+ nltk.data.find("tokenizers/punkt")
+ except LookupError:
nltk.download('punkt', download_dir='~/nltk_data')
+
+ try:
+ nltk.data.find("taggers/averaged_perceptron_tagger")
+ except LookupError:
nltk.download('averaged_perceptron_tagger', download_dir='~/nltk_data')
- except ImportError:
- raise RuntimeError('nltk is not installed, please install it by: '
- 'pip install nltk.')
+
caption = caption.lower()
tokens = nltk.word_tokenize(caption)
@@ -7,17 +7,20 @@ from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.runner.amp import autocast
+from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import ConfigType
+from mmdet.utils import InstanceList
from ..layers import SinePositionalEncoding
from ..layers.transformer.grounding_dino_layers import (
GroundingDinoTransformerDecoder, GroundingDinoTransformerEncoder)
from .dino import DINO
from .glip import (create_positive_map, create_positive_map_label_to_token,
run_ner)
+from ..language_models.bert import generate_masks_with_special_tokens_and_transfer_map
def clean_label_name(name: str) -> str:
@@ -321,22 +324,29 @@ class GroundingDINO(DINO):
return head_inputs_dict
def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
- feat_pos: Tensor, spatial_shapes: Tensor,
+ feat_pos: Tensor, spatial_shapes: list,
+ spatial_shapes_tensor: Tensor,
level_start_index: Tensor, valid_ratios: Tensor,
text_dict: Dict) -> Dict:
text_token_mask = text_dict['text_token_mask']
+ reference_points = self.encoder.get_encoder_reference_points(
+ spatial_shapes_tensor, valid_ratios, device=feat.device
+ )
+ reference_points = reference_points.half()
memory, memory_text = self.encoder(
query=feat,
query_pos=feat_pos,
key_padding_mask=feat_mask, # for self_attn
spatial_shapes=spatial_shapes,
+ spatial_shapes_tensor=spatial_shapes_tensor,
level_start_index=level_start_index,
valid_ratios=valid_ratios,
# for text encoder
memory_text=text_dict['embedded'],
text_attention_mask=~text_token_mask,
position_ids=text_dict['position_ids'],
- text_self_attention_masks=text_dict['masks'])
+ text_self_attention_masks=text_dict['masks'],
+ reference_points=reference_points)
encoder_outputs_dict = dict(
memory=memory,
memory_mask=feat_mask,
@@ -358,6 +368,7 @@ class GroundingDINO(DINO):
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, memory_mask, spatial_shapes)
+ output_proposals = output_proposals.half()
enc_outputs_class = self.bbox_head.cls_branches[
self.decoder.num_layers](output_memory, memory_text,
@@ -416,6 +427,20 @@ class GroundingDINO(DINO):
head_inputs_dict['text_token_mask'] = text_token_mask
return decoder_inputs_dict, head_inputs_dict
+ def tokenizes(self, texts):
+ device = next(self.language_model.language_backbone.parameters()).device
+ tokenized = self.language_model.tokenizer.batch_encode_plus(
+ texts,
+ max_length=self.language_model.max_tokens,
+ padding='max_length' if self.language_model.pad_to_max else 'longest',
+ return_special_tokens_mask=True,
+ return_tensors='pt',
+ truncation=True).to(device)
+ attention_mask, position_ids = \
+ generate_masks_with_special_tokens_and_transfer_map(
+ tokenized, self.language_model.special_tokens)
+ return tokenized, attention_mask, position_ids
+
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> Union[dict, list]:
text_prompts = [
@@ -477,7 +502,8 @@ class GroundingDINO(DINO):
positive_maps.append(positive_map)
new_text_prompts.append(caption_string)
- text_dict = self.language_model(new_text_prompts)
+ tokenized, attention_mask, position_ids = self.tokenizes(new_text_prompts)
+ text_dict = self.language_model(tokenized, attention_mask, position_ids)
if self.text_feat_map is not None:
text_dict['embedded'] = self.text_feat_map(text_dict['embedded'])
@@ -501,7 +527,8 @@ class GroundingDINO(DINO):
**head_inputs_dict, batch_data_samples=batch_data_samples)
return losses
- def predict(self, batch_inputs, batch_data_samples, rescale: bool = True):
+ def predict(self, batch_inputs, batch_data_samples, tokenized=None, attention_mask=None, position_ids=None, isvisualize=True, rescale: bool = True):
+ batch_inputs = batch_inputs.half()
text_prompts = []
enhanced_text_prompts = []
tokens_positives = []
@@ -553,7 +580,8 @@ class GroundingDINO(DINO):
for b in range(len(text_prompts[0])):
text_prompts_once = [text_prompts[0][b]]
token_positive_maps_once = token_positive_maps[0][b]
- text_dict = self.language_model(text_prompts_once)
+ tokenized, attention_mask, position_ids = self.tokenizes(text_prompts_once)
+ text_dict = self.language_model(tokenized, attention_mask, position_ids)
# text feature map layer
if self.text_feat_map is not None:
text_dict['embedded'] = self.text_feat_map(
@@ -577,7 +605,9 @@ class GroundingDINO(DINO):
is_rec_tasks = [False] * len(results_list)
else:
# extract text feats
- text_dict = self.language_model(list(text_prompts))
+ if tokenized is None:
+ tokenized, attention_mask, position_ids = self.tokenizes(list(text_prompts))
+ text_dict = self.language_model(tokenized, attention_mask, position_ids)
# text feature map layer
if self.text_feat_map is not None:
text_dict['embedded'] = self.text_feat_map(
@@ -598,24 +628,27 @@ class GroundingDINO(DINO):
rescale=rescale,
batch_data_samples=batch_data_samples)
- for data_sample, pred_instances, entity, is_rec_task in zip(
- batch_data_samples, results_list, entities, is_rec_tasks):
- if len(pred_instances) > 0:
- label_names = []
- for labels in pred_instances.labels:
- if is_rec_task:
- label_names.append(entity)
- continue
- if labels >= len(entity):
- warnings.warn(
- 'The unexpected output indicates an issue with '
- 'named entity recognition. You can try '
- 'setting custom_entities=True and running '
- 'again to see if it helps.')
- label_names.append('unobject')
- else:
- label_names.append(entity[labels])
- # for visualization
- pred_instances.label_names = label_names
- data_sample.pred_instances = pred_instances
- return batch_data_samples
+ if isvisualize:
+ for data_sample, pred_instances, entity, is_rec_task in zip(
+ batch_data_samples, results_list, entities, is_rec_tasks):
+ if len(pred_instances) > 0:
+ label_names = []
+ for labels in pred_instances.labels:
+ if is_rec_task:
+ label_names.append(entity)
+ continue
+ if labels >= len(entity):
+ warnings.warn(
+ 'The unexpected output indicates an issue with '
+ 'named entity recognition. You can try '
+ 'setting custom_entities=True and running '
+ 'again to see if it helps.')
+ label_names.append('unobject')
+ else:
+ label_names.append(entity[labels])
+ # for visualization
+ pred_instances.label_names = label_names
+ data_sample.pred_instances = pred_instances
+ return batch_data_samples
+ else:
+ return batch_data_samples, results_list
@@ -134,21 +134,10 @@ class BertModel(BaseModel):
self.special_tokens = self.tokenizer.convert_tokens_to_ids(
special_tokens_list)
- def forward(self, captions: Sequence[str], **kwargs) -> dict:
+ def forward(self, tokenized, attention_mask, position_ids, **kwargs) -> dict:
"""Forward function."""
- device = next(self.language_backbone.parameters()).device
- tokenized = self.tokenizer.batch_encode_plus(
- captions,
- max_length=self.max_tokens,
- padding='max_length' if self.pad_to_max else 'longest',
- return_special_tokens_mask=True,
- return_tensors='pt',
- truncation=True).to(device)
input_ids = tokenized.input_ids
if self.use_sub_sentence_represent:
- attention_mask, position_ids = \
- generate_masks_with_special_tokens_and_transfer_map(
- tokenized, self.special_tokens)
token_type_ids = tokenized['token_type_ids']
else:
@@ -82,12 +82,13 @@ class DinoTransformerDecoder(DeformableDetrTransformerDecoder):
query_sine_embed = coordinate_to_encoding(
reference_points_input[:, :, 0, :])
+ query_sine_embed = query_sine_embed.half()
query_pos = self.ref_point_head(query_sine_embed)
query = layer(
- query,
+ query.contiguous(),
query_pos=query_pos,
- value=value,
+ value=value.contiguous(),
key_padding_mask=key_padding_mask,
self_attn_mask=self_attn_mask,
spatial_shapes=spatial_shapes,
@@ -98,7 +99,6 @@ class DinoTransformerDecoder(DeformableDetrTransformerDecoder):
if reg_branches is not None:
tmp = reg_branches[lid](query)
- assert reference_points.shape[-1] == 4
new_reference_points = tmp + inverse_sigmoid(
reference_points, eps=1e-3)
new_reference_points = new_reference_points.sigmoid()
@@ -155,29 +155,21 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
for _ in range(self.num_layers)
])
self.embed_dims = self.layers[0].embed_dims
- if self.num_cp > 0:
- if checkpoint_wrapper is None:
- raise NotImplementedError(
- 'If you want to reduce GPU memory usage, \
- please install fairscale by executing the \
- following command: pip install fairscale.')
- for i in range(self.num_cp):
- self.layers[i] = checkpoint_wrapper(self.layers[i])
- self.fusion_layers[i] = checkpoint_wrapper(
- self.fusion_layers[i])
def forward(self,
query: Tensor,
query_pos: Tensor,
key_padding_mask: Tensor,
- spatial_shapes: Tensor,
+ spatial_shapes: list,
+ spatial_shapes_tensor: Tensor,
level_start_index: Tensor,
valid_ratios: Tensor,
memory_text: Tensor = None,
text_attention_mask: Tensor = None,
pos_text: Tensor = None,
text_self_attention_masks: Tensor = None,
- position_ids: Tensor = None):
+ position_ids: Tensor = None,
+ reference_points: Tensor = None):
"""Forward function of Transformer encoder.
Args:
@@ -206,8 +198,6 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
Defaults to None.
"""
output = query
- reference_points = self.get_encoder_reference_points(
- spatial_shapes, valid_ratios, device=query.device)
if self.text_layers:
# generate pos_text
bs, n_text, _ = memory_text.shape
@@ -223,6 +213,7 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
position_ids[..., None],
num_pos_feats=256,
exchange_xy=False)
+ pos_text = pos_text.half()
# main process
for layer_id, layer in enumerate(self.layers):
@@ -248,6 +239,7 @@ class GroundingDinoTransformerEncoder(DeformableDetrTransformerEncoder):
query_pos=query_pos,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
+ spatial_shapes_tensor=spatial_shapes_tensor,
level_start_index=level_start_index,
key_padding_mask=key_padding_mask)
return output, memory_text
@@ -124,7 +124,7 @@ class DetLocalVisualizer(Visualizer):
self.set_image(image)
if 'bboxes' in instances and instances.bboxes.sum() > 0:
- bboxes = instances.bboxes
+ bboxes = instances.bboxes.to(torch.float32)
labels = instances.labels
max_label = int(max(labels) if len(labels) > 0 else 0)