"""Test PreProcess in GenerationMixin.forward"""
from unittest import mock
import time
import pytest
import numpy as np
from mindspore import set_device, Tensor
from mindformers.tools.utils import is_pynative
from mindformers.generation.text_generator import GenerationMixin
from mindformers.tools.debug_info import DetailedLatency, Profiling
# Run every test in this module on the CPU backend; executed once at import time.
set_device(device_target='CPU')
class TestConfig:
    """Minimal stand-in config assigned as ``self.config`` by the mocked GenerationMixin constructor."""

    # Not a pytest test class despite the ``Test`` prefix. Without this flag pytest tries to
    # collect it and emits PytestCollectionWarning because the class defines ``__init__``.
    __test__ = False

    def __init__(self):
        self.is_encoder_decoder = False
        self.is_dynamic = True
class TestGenerationMixin:
    """
    Test double for GenerationMixin.

    The tests in this module patch these methods onto ``GenerationMixin`` so that
    ``GenerationMixin.forward`` can run without a real model.

    Mock:
        __init__: sets only the attributes below instead of building a model.
        prepare_inputs_for_generation: defined in model; here it only wraps ``input_ids`` in a Tensor.
        add_flags_custom: defined in model; here it is a no-op.
        call_perf: replaces ``__call__``; returns the wall-clock time at which the model would
            have been invoked, i.e. the end of preprocessing.
        call_accuracy: replaces ``__call__``; returns the arguments it was given, i.e. the
            outputs of preprocessing.
    """

    # Not a pytest test class despite the ``Test`` prefix. Without this flag pytest tries to
    # collect it and emits PytestCollectionWarning because the class defines ``__init__``.
    __test__ = False

    def __init__(self):
        """Set the instance state this mock provides in place of the real constructor."""
        self.detailed_latency = DetailedLatency()
        self.profile = Profiling()
        self.config = TestConfig()
        # NOTE(review): presumably mirrors the defaults a freshly built model would have; confirm
        # against the attributes GenerationMixin.forward reads.
        self._pre_set_phase = None
        self._exec_add_flags = True
        self.is_pynative = is_pynative()

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return ``input_ids`` (a numpy array) wrapped in a Tensor; extra kwargs are ignored."""
        return {"input_ids": Tensor.from_numpy(input_ids)}

    def add_flags_custom(self, is_first_iteration):
        """No-op replacement for the model's flag setter."""
        return

    def call_perf(self, input_ids, input_position=None, init_reset=None, batch_valid_length=None, block_tables=None,
                  slot_mapping=None):
        """Return ``time.time()`` at the moment the model would be called (end of preprocessing)."""
        time_end = time.time()
        return time_end

    def call_accuracy(self, input_ids, input_position=None, init_reset=None, batch_valid_length=None, block_tables=None,
                      slot_mapping=None):
        """Return the received model inputs unchanged, keyed by parameter name."""
        return {"input_ids": input_ids, "input_position": input_position, "init_reset": init_reset,
                "batch_valid_length": batch_valid_length, "block_tables": block_tables, "slot_mapping": slot_mapping}
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@mock.patch('mindformers.generation.text_generator.GenerationMixin.__init__', TestGenerationMixin.__init__)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.prepare_inputs_for_generation',
            TestGenerationMixin.prepare_inputs_for_generation)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.add_flags_custom',
            TestGenerationMixin.add_flags_custom)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.__call__', TestGenerationMixin.call_perf)
@pytest.mark.parametrize('batch_size', [1, 8, 16])
@pytest.mark.parametrize('seq_length', [256, 512, 1024, 2048])
@pytest.mark.parametrize('prefill', [True, False])
def test_preprocess_perf(batch_size, seq_length, prefill):
    """
    Feature: Test preprocess performance.
    Description: Check the time delay for preprocess. ``__call__`` is patched with
        ``TestGenerationMixin.call_perf``. The first value unpacked from ``forward`` is therefore
        the timestamp taken when the model would have run, and subtracting a timestamp taken just
        before ``forward`` gives the preprocessing latency.
    Expectation: Mean latency over the timed runs is below 0.4 ms (prefill) / 0.5 ms (non-prefill).
    """
    # Seeded generator so that a failing parametrization can be reproduced.
    rng = np.random.default_rng(seed=0)
    input_ids = np.arange(batch_size * seq_length).reshape(batch_size, seq_length)
    # Random valid lengths in [1, seq_length - 1].
    valid_length_each_example = rng.integers(1, seq_length, size=(batch_size,), dtype=np.int32)
    block_tables = np.arange(32 * 256).reshape(32, 256).astype(np.int32)
    slot_mapping = np.arange(32 * 256).astype(np.int32)
    # Built once, outside the timed region.
    forward_kwargs = {
        "valid_length_each_example": valid_length_each_example,
        "block_tables": block_tables,
        "slot_mapping": slot_mapping,
        "prefill": prefill,
        "use_past": True,
    }
    generation_mixin = GenerationMixin()

    # Warm-up runs, excluded from the measurement.
    for _ in range(3):
        _, _ = generation_mixin.forward(input_ids=input_ids, **forward_kwargs)

    time_list = []
    for _ in range(10):
        time_start = time.time()
        time_end, _ = generation_mixin.forward(input_ids=input_ids, **forward_kwargs)
        time_list.append((time_end - time_start) * 1000.0)  # seconds -> milliseconds

    mean_ms = np.mean(time_list)
    threshold_ms = 0.4 if prefill else 0.5
    assert mean_ms < threshold_ms, (
        f"preprocess took {mean_ms:.4f} ms on average (limit {threshold_ms} ms) for "
        f"batch_size={batch_size}, seq_length={seq_length}, prefill={prefill}")
def get_expected_outputs(input_ids, valid_length_each_example, prefill):
    """
    Build the reference preprocessing outputs that ``forward`` is compared against.

    Args:
        input_ids: array of token ids indexed as ``[row][column]``, one row per example.
        valid_length_each_example: per-example valid lengths; entry ``i`` is used for row ``i``.
        prefill: flag mirroring ``forward``'s ``prefill`` argument; ``init_reset`` is set to
            ``not prefill``.

    Returns:
        dict with ``input_ids``, ``input_position``, ``init_reset`` and ``batch_valid_length``.
        When not prefilling and ``input_ids`` has more than one column, ``input_ids`` is reduced
        to the single token at column ``valid_length - 1`` of each row (as int32). Otherwise the
        input array is returned as-is.
    """
    num_rows, row_len = input_ids.shape[0], input_ids.shape[1]
    # Flattened (row-major) index of the last valid token of every example.
    flat_positions = [valid_length_each_example[row] - 1 + row * row_len for row in range(num_rows)]

    if prefill or input_ids.shape[-1] == 1:
        ids_out = input_ids
    else:
        # Map each flat index back to a column inside its own row and keep just that token.
        columns = [int(flat) - row * row_len for row, flat in enumerate(flat_positions)]
        ids_out = np.array(
            [input_ids[row][col: col + 1] for row, col in enumerate(columns)],
            dtype=np.int32,
        )

    return {
        "input_ids": ids_out,
        "input_position": np.array(flat_positions).astype(np.int32),
        "init_reset": np.array([not prefill], dtype=np.bool_),
        "batch_valid_length": np.array([valid_length_each_example], dtype=np.int32),
    }
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@mock.patch('mindformers.generation.text_generator.GenerationMixin.__init__', TestGenerationMixin.__init__)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.prepare_inputs_for_generation',
            TestGenerationMixin.prepare_inputs_for_generation)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.add_flags_custom',
            TestGenerationMixin.add_flags_custom)
@mock.patch('mindformers.generation.text_generator.GenerationMixin.__call__', TestGenerationMixin.call_accuracy)
@pytest.mark.parametrize('batch_size', [1, 8, 16])
@pytest.mark.parametrize('seq_length', [256, 512, 1024, 2048])
@pytest.mark.parametrize('prefill', [True, False])
def test_preprocess_accuracy(batch_size, seq_length, prefill):
    """
    Feature: Test preprocess accuracy.
    Description: Check the accuracy for preprocess. ``__call__`` is patched with
        ``TestGenerationMixin.call_accuracy``. The first value unpacked from ``forward`` is
        therefore the dict of model inputs produced by preprocessing, which is compared with
        ``get_expected_outputs``.
    Expectation: Success.
    """
    # Seeded generator so that a failing parametrization can be reproduced.
    rng = np.random.default_rng(seed=0)
    input_ids = np.arange(batch_size * seq_length).reshape(batch_size, seq_length)
    # Random valid lengths in [1, seq_length - 1].
    valid_length_each_example = rng.integers(1, seq_length, size=(batch_size,), dtype=np.int32)
    block_tables = np.arange(32 * 256).reshape(32, 256).astype(np.int32)
    slot_mapping = np.arange(32 * 256).astype(np.int32)
    generation_mixin = GenerationMixin()
    outputs, _ = generation_mixin.forward(input_ids=input_ids, valid_length_each_example=valid_length_each_example,
                                          block_tables=block_tables, slot_mapping=slot_mapping,
                                          prefill=prefill, use_past=True)
    expected_outputs = get_expected_outputs(input_ids, valid_length_each_example, prefill)
    # NOTE(review): expected "input_position" and "init_reset" are computed but not asserted here.
    assert np.allclose(outputs["input_ids"].asnumpy(), expected_outputs["input_ids"])
    assert np.allclose(outputs["batch_valid_length"].asnumpy(), expected_outputs["batch_valid_length"])
    # Block tables and slot mapping should be passed through unchanged.
    assert np.allclose(outputs["block_tables"].asnumpy(), block_tables)
    assert np.allclose(outputs["slot_mapping"].asnumpy(), slot_mapping)