 data/vqa_dataset.py |  5 +++-
 models/blip_vqa.py  | 72 ++++++++++++++++++++++++++++++++++++---------
 models/med.py       |  4 +--
 3 files changed, 64 insertions(+), 17 deletions(-)

diff --git a/data/vqa_dataset.py b/data/vqa_dataset.py
index 92ec1df..c22db69 100644
--- a/data/vqa_dataset.py
+++ b/data/vqa_dataset.py
@@ -9,6 +9,8 @@ from data.utils import pre_question
 
 from torchvision.datasets.utils import download_url
 
+import numpy as np
+
 class vqa_dataset(Dataset):
     def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
         self.split = split        
@@ -48,11 +50,12 @@ class vqa_dataset(Dataset):
             
         image = Image.open(image_path).convert('RGB')   
         image = self.transform(image)          
+        image = np.asarray(image)
         
         if self.split == 'test':
             question = pre_question(ann['question'])   
             question_id = ann['question_id']            
-            return image, question, question_id
+            return image_path, image, question, question_id
 
 
         elif self.split=='train':                       
diff --git a/models/blip_vqa.py b/models/blip_vqa.py
index d4cb368..0f7a4db 100644
--- a/models/blip_vqa.py
+++ b/models/blip_vqa.py
@@ -36,10 +36,15 @@ class BLIP_VQA(nn.Module):
 
     def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128):
         
-        image_embeds = self.visual_encoder(image) 
+        if 'om' in dir(self.visual_encoder):
+            image_embeds = torch.from_numpy(
+                self.visual_encoder.om.infer([image.numpy()])[0]
+            )
+        else:
+            image_embeds = self.visual_encoder(image)
         image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
         
-        question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, 
+        question = self.tokenizer(question, padding='max_length', truncation=True, max_length=35,
                                   return_tensors="pt").to(image.device) 
         question.input_ids[:,0] = self.tokenizer.enc_token_id
         
@@ -82,15 +87,27 @@ class BLIP_VQA(nn.Module):
             
 
         else: 
-            question_output = self.text_encoder(question.input_ids, 
-                                                attention_mask = question.attention_mask, 
-                                                encoder_hidden_states = image_embeds,
-                                                encoder_attention_mask = image_atts,                                    
-                                                return_dict = True) 
+            if 'om' in dir(self.text_encoder):
+                question_states = torch.from_numpy(
+                    self.text_encoder.om.infer(
+                        [
+                            question.input_ids.numpy(),
+                            question.attention_mask.numpy(),
+                            image_embeds.numpy(),
+                        ]
+                    )[0]
+                )
+            else:
+                question_output = self.text_encoder(question.input_ids,
+                                                    attention_mask = question.attention_mask,
+                                                    encoder_hidden_states = image_embeds,
+                                                    encoder_attention_mask = image_atts,
+                                                    return_dict = True)
+                question_states = question_output.last_hidden_state
             
             if inference=='generate':
-                num_beams = 3
-                question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0)
+                num_beams = 1
+                question_states = question_states.repeat_interleave(num_beams,dim=0)
                 question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device)
                 model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts}
                 
@@ -111,7 +128,7 @@ class BLIP_VQA(nn.Module):
                 return answers
             
             elif inference=='rank':
-                max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, 
+                max_ids = self.rank_answer(question_states, question.attention_mask,
                                            answer.input_ids, answer.attention_mask, k_test) 
                 return max_ids
  
@@ -122,12 +139,24 @@ class BLIP_VQA(nn.Module):
         num_ques = question_states.size(0)
         start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token
         
-        start_output = self.text_decoder(start_ids, 
+        if 'rank_1_om' in dir(self.text_decoder):
+            logits = torch.from_numpy(
+                self.text_decoder.rank_1_om.infer(
+                    [
+                        start_ids.numpy(),
+                        question_states.numpy(),
+                        question_atts.numpy()
+                    ]
+                )[0]
+            )
+        else:
+            start_output = self.text_decoder(start_ids,
                                          encoder_hidden_states = question_states,
                                          encoder_attention_mask = question_atts,                                      
                                          return_dict = True,
                                          reduction = 'none')              
-        logits = start_output.logits[:,0,:] # first token's logit
+            logits = start_output.logits
+        logits = logits[:,0,:] # first token's logit
         
         # topk_probs: top-k probability 
         # topk_ids: [num_question, k]        
@@ -150,15 +179,30 @@ class BLIP_VQA(nn.Module):
         question_states = tile(question_states, 0, k)
         question_atts = tile(question_atts, 0, k)
         
-        output = self.text_decoder(input_ids, 
+        if 'rank_2_om' in dir(self.text_decoder):
+            output = torch.from_numpy(
+                self.text_decoder.rank_2_om.infer(
+                    [
+                        input_ids.numpy(),
+                        input_atts.numpy(),
+                        question_states.numpy(),
+                        question_atts.numpy(),
+                        targets_ids.numpy(),
+                    ]
+                )[0]
+            )
+            loss = output
+        else:
+            output = self.text_decoder(input_ids,
                                    attention_mask = input_atts, 
                                    encoder_hidden_states = question_states,
                                    encoder_attention_mask = question_atts,     
                                    labels = targets_ids,
                                    return_dict = True, 
                                    reduction = 'none')   
+            loss = output.loss
         
-        log_probs_sum = -output.loss
+        log_probs_sum = -loss
         log_probs_sum = log_probs_sum.view(num_ques,k)
 
         max_topk_ids = log_probs_sum.argmax(dim=1) 
diff --git a/models/med.py b/models/med.py
index 7b00a35..6379967 100644
--- a/models/med.py
+++ b/models/med.py
@@ -904,14 +904,14 @@ class BertLMHeadModel(BertPreTrainedModel):
         prediction_scores = self.cls(sequence_output)
         
         if return_logits:
-            return prediction_scores[:, :-1, :].contiguous()  
+            return prediction_scores
 
         lm_loss = None
         if labels is not None:
             # we are doing next-token prediction; shift prediction scores and input ids by one
             shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
             labels = labels[:, 1:].contiguous()
-            loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) 
+            loss_fct = CrossEntropyLoss(reduction=reduction)
             lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
             if reduction=='none':
                 lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)