data/vqa_dataset.py | 5 +++-
models/blip_vqa.py | 72 ++++++++++++++++++++++++++++++++++++---------
models/med.py | 4 +--
3 files changed, 64 insertions(+), 17 deletions(-)
@@ -9,6 +9,8 @@ from data.utils import pre_question
from torchvision.datasets.utils import download_url
+import numpy as np
+
class vqa_dataset(Dataset):
def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
self.split = split
@@ -48,11 +50,12 @@ class vqa_dataset(Dataset):
image = Image.open(image_path).convert('RGB')
image = self.transform(image)
+ image = np.asarray(image)
if self.split == 'test':
question = pre_question(ann['question'])
question_id = ann['question_id']
- return image, question, question_id
+ return image_path, image, question, question_id
elif self.split=='train':
@@ -36,10 +36,15 @@ class BLIP_VQA(nn.Module):
def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
- image_embeds = self.visual_encoder(image)
+ if 'om' in dir(self.visual_encoder):
+ image_embeds = torch.from_numpy(
+ self.visual_encoder.om.infer([image.numpy()])[0]
+ )
+ else:
+ image_embeds = self.visual_encoder(image)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
- question = self.tokenizer(question, padding='longest', truncation=True, max_length=35,
+ question = self.tokenizer(question, padding='max_length', truncation=True, max_length=35,
return_tensors="pt").to(image.device)
question.input_ids[:,0] = self.tokenizer.enc_token_id
@@ -82,15 +87,27 @@ class BLIP_VQA(nn.Module):
else:
- question_output = self.text_encoder(question.input_ids,
- attention_mask = question.attention_mask,
- encoder_hidden_states = image_embeds,
- encoder_attention_mask = image_atts,
- return_dict = True)
+ if 'om' in dir(self.text_encoder):
+ question_states = torch.from_numpy(
+ self.text_encoder.om.infer(
+ [
+ question.input_ids.numpy(),
+ question.attention_mask.numpy(),
+ image_embeds.numpy(),
+ ]
+ )[0]
+ )
+ else:
+ question_output = self.text_encoder(question.input_ids,
+ attention_mask = question.attention_mask,
+ encoder_hidden_states = image_embeds,
+ encoder_attention_mask = image_atts,
+ return_dict = True)
+ question_states = question_output.last_hidden_state
if inference=='generate':
- num_beams = 3
- question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+ num_beams = 1
+ question_states = question_states.repeat_interleave(num_beams,dim=0)
question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
@@ -111,7 +128,7 @@ class BLIP_VQA(nn.Module):
return answers
elif inference=='rank':
- max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
+ max_ids = self.rank_answer(question_states, question.attention_mask,
answer.input_ids, answer.attention_mask, k_test)
return max_ids
@@ -122,12 +139,24 @@ class BLIP_VQA(nn.Module):
num_ques = question_states.size(0)
start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
- start_output = self.text_decoder(start_ids,
+ if 'rank_1_om' in dir(self.text_decoder):
+ logits = torch.from_numpy(
+ self.text_decoder.rank_1_om.infer(
+ [
+ start_ids.numpy(),
+ question_states.numpy(),
+ question_atts.numpy()
+ ]
+ )[0]
+ )
+ else:
+ start_output = self.text_decoder(start_ids,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
return_dict = True,
reduction = 'none')
- logits = start_output.logits[:,0,:] # first token's logit
+ logits = start_output.logits
+ logits = logits[:,0,:] # first token's logit
# topk_probs: top-k probability
# topk_ids: [num_question, k]
@@ -150,15 +179,30 @@ class BLIP_VQA(nn.Module):
question_states = tile(question_states, 0, k)
question_atts = tile(question_atts, 0, k)
- output = self.text_decoder(input_ids,
+ if 'rank_2_om' in dir(self.text_decoder):
+ output = torch.from_numpy(
+ self.text_decoder.rank_2_om.infer(
+ [
+ input_ids.numpy(),
+ input_atts.numpy(),
+ question_states.numpy(),
+ question_atts.numpy(),
+ targets_ids.numpy(),
+ ]
+ )[0]
+ )
+ loss = output
+ else:
+ output = self.text_decoder(input_ids,
attention_mask = input_atts,
encoder_hidden_states = question_states,
encoder_attention_mask = question_atts,
labels = targets_ids,
return_dict = True,
reduction = 'none')
+ loss = output.loss
- log_probs_sum = -output.loss
+ log_probs_sum = -loss
log_probs_sum = log_probs_sum.view(num_ques,k)
max_topk_ids = log_probs_sum.argmax(dim=1)
@@ -904,14 +904,14 @@ class BertLMHeadModel(BertPreTrainedModel):
prediction_scores = self.cls(sequence_output)
if return_logits:
- return prediction_scores[:, :-1, :].contiguous()
+ return prediction_scores
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
- loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
+ loss_fct = CrossEntropyLoss(reduction=reduction)
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if reduction=='none':
lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)