27c27
< from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
---
> from ..modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, Seq2SeqLMOutput
2179,2184c2179,2198
<             outputs = self(
<                 **model_inputs,
<                 return_dict=True,
<                 output_attentions=output_attentions,
<                 output_hidden_states=output_hidden_states,
<             )
---
>             if 'generate_om' in dir(self):
>                 logits = torch.from_numpy(
>                     self.generate_om.infer(
>                         [
>                             model_inputs['input_ids'].numpy(),
>                             model_inputs['attention_mask'].numpy(),
>                             model_inputs['encoder_hidden_states'].numpy(),
>                             model_inputs['encoder_attention_mask'].numpy(),
>                         ],
>                         mode='dymdims',
>                     )[0]
>                 )
>                 outputs = CausalLMOutput(logits=logits)
>             else:
>                 outputs = self(
>                     **model_inputs,
>                     return_dict=True,
>                     output_attentions=output_attentions,
>                     output_hidden_states=output_hidden_states,
>                 )