"""visualglm text generator adaptor for lite inference."""
import time
from typing import List, Union, Optional
import numpy as np
from mindspore_lite import Model
from mindformers.generation.streamers import BaseStreamer
from mindformers.inference import InferTask
from mindformers.inference.infers.text_generator_infer import TextGeneratorInfer, InputOfInfer, BaseInputsOfInfer
def register_task():
    """Hook the VisualGLM lite-inference classes into the mindformers registries.

    Two entries are written:

    * ``InferTask.task_mapping["visualglm_generation"]`` -> ``VisualGlmGeneratorInfer``
    * ``InputOfInfer.MAPPING["visualglm"]`` -> ``VisualGlmInputsOfInfer``
    """
    # The two registrations are independent of each other.
    InferTask.task_mapping["visualglm_generation"] = VisualGlmGeneratorInfer
    InputOfInfer.MAPPING["visualglm"] = VisualGlmInputsOfInfer
class VisualGlmInputsOfInfer(BaseInputsOfInfer):
    """
    Input builder that turns VisualGLM generation state into lite tensors.
    """

    def get_inputs(self, model: Model, input_ids=None, input_embeddings=None, current_index=None, valid_length=None,
                   init_reset=None, tokenizer=None, is_first_iteration=True,
                   attention_mask=None, position_ids=None, **kwargs):
        """
        Assemble the lite input tensors for one inference step.

        First iteration: the full ``input_embeddings``, ``input_ids``,
        ``position_ids`` and ``attention_mask`` are passed through.
        Later iterations: for every batch row only the single position
        selected by ``current_index`` is sliced out of ``input_ids``,
        ``position_ids`` and ``attention_mask``; no embeddings are passed.

        ``tokenizer`` and any extra ``kwargs`` are accepted but ignored.

        Returns:
            The result of ``self.get_lite_tensor_list(inputs, model)``.
        """
        del tokenizer, kwargs

        # Both are required: ``.asnumpy()`` is called unconditionally.
        position_ids = position_ids.asnumpy()
        attention_mask = attention_mask.asnumpy()

        if is_first_iteration:
            input_embeddings = input_embeddings.asnumpy()
            # The ``None`` entry is forwarded as-is; how it is handled is up to
            # ``get_lite_tensor_list`` (not visible here).
            full_inputs = [input_embeddings, input_ids, None, position_ids, attention_mask,
                           current_index, init_reset, valid_length]
            return self.get_lite_tensor_list(full_inputs, model)

        # ``current_index`` presumably indexes the flattened (batch * seq) ids;
        # subtracting ``row * input_ids.shape[1]`` yields the offset inside the row.
        offsets = [int(flat_index) - row * input_ids.shape[1]
                   for row, flat_index in enumerate(current_index)]

        token_slices = [input_ids[row][pos:pos + 1] for row, pos in enumerate(offsets)]
        position_slices = [position_ids[row][:, pos:pos + 1] for row, pos in enumerate(offsets)]
        mask_slices = [attention_mask[row][:, pos:pos + 1, :] for row, pos in enumerate(offsets)]

        input_ids = np.array(token_slices, np.int32)
        position_ids = np.array(position_slices, np.int32)
        attention_mask = np.array(mask_slices, np.int32)

        # Incremental step: no embeddings entry, one element shorter than above.
        step_inputs = [input_ids, None, position_ids, attention_mask,
                       current_index, init_reset, valid_length]
        return self.get_lite_tensor_list(step_inputs, model)
class VisualGlmGeneratorInfer(TextGeneratorInfer):
    """
    VisualGlm generator infer implement class.
    """

    def infer(self,
              inputs: Optional[Union[List[int], List[List[int]]]],
              do_sample: bool = False,
              top_k: int = 1,
              top_p: float = 1.0,
              temperature: float = 1.0,
              repetition_penalty: float = 1.0,
              eos_token_id: int = 2,
              pad_token_id: int = 0,
              max_length: int = 256,
              is_sample_acceleration: bool = False,
              add_special_tokens: bool = False,
              streamer: Optional[BaseStreamer] = None,
              **kwargs):
        """
        Text generator inference api.

        Runs ``self.generate`` on the given token ids, times the call, and
        post-processes the generated ids with ``self.postprocess``. All
        generation options are forwarded to ``self.generate`` unchanged.

        Args:
            inputs (List[int] or List[List[int]]): A token id list or a list of token id lists.
            do_sample (bool): Whether to sample from the candidate ids. Default False.
            top_k (int): Number of top token ids kept as candidates. Default 1.
            top_p (float): Cumulative-probability threshold used to select candidate
                token ids. Default 1.0.
            temperature (float): Value used to modulate the next-token probabilities.
                Default 1.0.
            repetition_penalty (float): Penalty factor applied to already generated
                tokens. Default 1.0.
            eos_token_id (int): The end-of-sentence token id. Default 2.
            pad_token_id (int): The padding token id. Default 0.
            max_length (int): The maximum length of the generated sequence. Default 256.
            is_sample_acceleration (bool): Whether the post-processing is done inside
                the model. Default False.
            add_special_tokens (bool): Accepted for interface compatibility; it is not
                used by this implementation. Default False.
            streamer (BaseStreamer, optional): The streamer the generator uses.
                Default None.
            **kwargs: Extra keyword arguments forwarded to ``self.generate``.

        Returns:
            A tuple ``(outputs, output_ids, generate_time)``: ``outputs`` is
            ``self.postprocess(output_ids)``, ``output_ids`` is the raw result of
            ``self.generate``, and ``generate_time`` is the time in seconds spent
            inside ``self.generate`` (post-processing is not included).
        """
        del add_special_tokens

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # (or go negative) when the system wall clock is adjusted mid-generation.
        start_generate_time = time.perf_counter()
        # Arguments stay positional: the parameter names of ``generate`` are
        # defined by the base class and are not visible here.
        output_ids = self.generate(inputs, do_sample, top_k, top_p, temperature,
                                   repetition_penalty, eos_token_id, pad_token_id,
                                   max_length, is_sample_acceleration, streamer, **kwargs)
        generate_time = time.perf_counter() - start_generate_time

        outputs = self.postprocess(output_ids)
        return outputs, output_ids, generate_time