707,709c707,712
< model_kwargs["past_key_values"] = self._extract_past_from_model_output(
< outputs, standardize_cache_format=standardize_cache_format
< )
> if 'om' in dir(self):
> model_kwargs["past_key_values"] = None
> else:
> model_kwargs["past_key_values"] = self._extract_past_from_model_output(
> outputs, standardize_cache_format=standardize_cache_format
> )
2722,2727c2725,2741
< outputs = self(
< **model_inputs,
< return_dict=True,
< output_attentions=output_attentions,
< output_hidden_states=output_hidden_states,
< )
> if 'om' in dir(self):
> outputs = self.om.infer(
> [
> model_inputs['input_ids'].numpy(),
> model_inputs['attention_mask'].numpy(),
> model_inputs['encoder_hidden_states'].numpy()
> ],
> mode=self.om.mode,
> custom_sizes=100000000
> )
> else:
> outputs = self(
> **model_inputs,
> return_dict=True,
> output_attentions=output_attentions,
> output_hidden_states=output_hidden_states,
> )
2733c2747,2750
< next_token_logits = outputs.logits[:, -1, :]
> if 'om' in dir(self):
> next_token_logits = torch.from_numpy(outputs[0])[:, -1, :]
> else:
> next_token_logits = outputs.logits[:, -1, :]