27c27
< from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
> from ..modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, Seq2SeqLMOutput
2179,2184c2179,2198
< outputs = self(
< **model_inputs,
< return_dict=True,
< output_attentions=output_attentions,
< output_hidden_states=output_hidden_states,
< )
> if 'generate_om' in dir(self):
> logits = torch.from_numpy(
> self.generate_om.infer(
> [
> model_inputs['input_ids'].numpy(),
> model_inputs['attention_mask'].numpy(),
> model_inputs['encoder_hidden_states'].numpy(),
> model_inputs['encoder_attention_mask'].numpy(),
> ],
> mode='dymdims',
> )[0]
> )
> outputs = CausalLMOutput(logits=logits)
> else:
> outputs = self(
> **model_inputs,
> return_dict=True,
> output_attentions=output_attentions,
> output_hidden_states=output_hidden_states,
> )