707,709c707,712
<         model_kwargs["past_key_values"] = self._extract_past_from_model_output(
<             outputs, standardize_cache_format=standardize_cache_format
<         )
---
>         if hasattr(self, "om"):  # `om` attached: no cache is extracted from its outputs
>             model_kwargs["past_key_values"] = None
>         else:
>             model_kwargs["past_key_values"] = self._extract_past_from_model_output(
>                 outputs, standardize_cache_format=standardize_cache_format
>             )
2722,2727c2725,2741
<             outputs = self(
<                 **model_inputs,
<                 return_dict=True,
<                 output_attentions=output_attentions,
<                 output_hidden_states=output_hidden_states,
<             )
---
>             if hasattr(self, "om"):  # route this step through self.om.infer with NumPy inputs
>                 outputs = self.om.infer(
>                     [
>                         model_inputs['input_ids'].numpy(),
>                         model_inputs['attention_mask'].numpy(),
>                         model_inputs['encoder_hidden_states'].numpy()
>                     ],
>                     mode=self.om.mode,
>                     custom_sizes=100000000
>                 )
>             else:
>                 outputs = self(
>                     **model_inputs,
>                     return_dict=True,
>                     output_attentions=output_attentions,
>                     output_hidden_states=output_hidden_states,
>                 )
2733c2747,2750
<             next_token_logits = outputs.logits[:, -1, :]
---
>             if hasattr(self, "om"):  # `om` outputs are positional; element 0 is used as logits
>                 next_token_logits = torch.from_numpy(outputs[0])[:, -1, :]
>             else:
>                 next_token_logits = outputs.logits[:, -1, :]