"""LLM boost model APIs: a MindFormers causal-LM model that delegates inference to a registered LLM boost backend."""
import numpy as np
from mindspore import Tensor
from mindspore.experimental.llm_boost.register import LlmBoostRegister, LlmBoostType
from research.llm_boost.llm_boost_config import LlmBoostConfig
from research.llm_boost.utils import is_support_lccl
from mindformers.models.modeling_utils import PreTrainedModel
from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister
from mindformers.version_control import need_nz
from mindformers.tools.utils import get_predict_run_mode
from mindformers.modules.layers import FreqsMgr
# Public API of this module.
__all__ = ["LlmBoostForCausalLM"]
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class LlmBoostForCausalLM(PreTrainedModel):
    r"""
    Causal language model whose inference is delegated to a third-party
    "LLM boost" backend obtained from ``LlmBoostRegister``.

    Args:
        config (LlmBoostConfig): The config of the llm boost model. Its
            ``llm_backend`` and ``boost_model_name`` select the backend
            instance that performs the forward pass.

    Returns:
        output: Tensor, whatever the llm boost backend's ``forward`` returns.
    """

    config_class = LlmBoostConfig

    def __init__(self, config):
        super().__init__(config, auto_prefix=True)
        self.use_past = config.use_past
        self.head_dim = config.hidden_size // config.num_heads
        self.is_first_iteration = True
        self.llm_backend = config.llm_backend
        self.predict_run_mode = get_predict_run_mode()
        # Rotary-embedding tables; their cos/sin are attached to the BUILDIN
        # backend inputs on every forward (see ``construct``).
        self.freqs_mgr = FreqsMgr(
            head_dim=self.head_dim,
            seq_length=config.seq_length,
            max_position_embedding=config.max_position_embedding,
            rotary_dtype=config.rotary_dtype,
            theta=config.theta,
            scaling_factor=config.scaling_factor,
            extend_method=config.extend_method,
            parallel_config=config.parallel_config,
        )
        # The config is updated in place before it is handed to the backend
        # below, so the backend sees these values.
        config.need_nz = need_nz()
        # Treat both "" and None as "not set": prefer LCCL when supported.
        if not config.communication_backend:
            config.communication_backend = "lccl" if is_support_lccl() else "hccl"
        self.llm_boost = LlmBoostRegister.get_instance(
            config.llm_backend, config.boost_model_name, config=config
        )
        self.llm_boost.init()
        # The backend's KV cache is set lazily, on the first forward call.
        self.is_set_kvcache = False
        # Only the BUILDIN backend consumes the dict built by
        # ``prepare_inputs_for_build_in``.
        self.need_prepare = config.llm_backend == LlmBoostType.BUILDIN

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the keyword inputs for ``construct`` from generation state.

        Args:
            input_ids: numpy array of token ids. Replaced by
                ``kwargs["origin_inputs"]`` when ``config.is_dynamic`` is set
                and that key is present.
            **kwargs: generation state. For the BUILDIN backend the keys
                ``valid_length_each_example``, ``position_ids``,
                ``block_tables``, ``slot_mapping``, ``prefill_head_indices``
                and ``prefill`` are read.

        Returns:
            dict with ``"input_ids"`` (int32 Tensor) and, for the BUILDIN
            backend, ``"llm_boost_inputs"``.
        """
        model_inputs = {}
        if self.config.is_dynamic and "origin_inputs" in kwargs:
            input_ids = kwargs["origin_inputs"]
        model_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int32))
        if self.llm_backend == LlmBoostType.BUILDIN:
            model_inputs["llm_boost_inputs"] = self.prepare_inputs_for_build_in(
                prefill=kwargs.get("prefill"),
                input_ids=input_ids,
                position_ids=kwargs.get("position_ids", None),
                batch_valid_length=kwargs.get("valid_length_each_example", None),
                block_tables=kwargs.get("block_tables", None),
                slot_mapping=kwargs.get("slot_mapping", None),
                lm_head_indices=kwargs.get("prefill_head_indices", None),
            )
        return model_inputs

    def prepare_inputs_for_build_in(
            self,
            prefill,
            input_ids,
            position_ids,
            batch_valid_length,
            block_tables,
            slot_mapping,
            lm_head_indices,
    ):
        """Build the input dict consumed by the BUILDIN llm boost backend.

        Array arguments are expected to be numpy arrays (they go through numpy
        operations and ``Tensor.from_numpy``).

        Args:
            prefill: truthy for the prefill step, falsy for a decode step.
            input_ids: token ids; in a padded batch, row ``i`` holds the
                tokens of example ``i``.
            position_ids: position ids, or None to derive them from
                ``batch_valid_length``.
            batch_valid_length: 1-D per-example valid lengths; its length is
                the batch size.
            block_tables: passed through as a Tensor.
            slot_mapping: slot indices; on prefill, ``-1`` entries (padding)
                are dropped and the result is flattened.
            lm_head_indices: prefill only; indices of the tokens fed to the LM
                head, or None to pick the last token of each example.

        Returns:
            dict with Tensors ``input_ids``, ``position_ids``,
            ``block_tables``, ``slot_mapping``, ``batch_valid_length``, the
            Python list ``seq_lens``, and on prefill ``lm_head_indices``.
        """
        llm_boost_inputs = {}
        seq_lens = batch_valid_length.tolist()
        bs = batch_valid_length.shape[0]
        input_ids_list = []
        if prefill:
            if bs > 1 and input_ids.shape[0] != 1:
                # Padded (bs, seq) input: keep each row's valid prefix; the
                # rows are concatenated into one flat token sequence below.
                input_ids_list = [
                    input_ids[i][:context_len]
                    for i, context_len in enumerate(batch_valid_length)
                ]
            elif bs == 1 and input_ids.ndim == 2:
                # A single padded sequence used to keep its padding, so its
                # token count disagreed with position_ids / slot_mapping /
                # lm_head_indices (all built for ``context_len`` tokens).
                # Trim the padding, keeping the original (1, n) layout; this
                # is a no-op when the row is already exactly context_len long.
                context_len = int(batch_valid_length[0])
                if input_ids.shape[-1] > context_len:
                    input_ids = input_ids[:, :context_len]
            # Drop padding slots (marked -1). Boolean masking returns a
            # flattened copy, as np.delete with axis=None did.
            slot_mapping = np.asarray(slot_mapping)
            slot_mapping = slot_mapping[slot_mapping != -1]
            if position_ids is None:
                # Positions restart at 0 for every example in the flat sequence.
                position_ids = np.concatenate(
                    [np.arange(context_len, dtype=np.int64) for context_len in batch_valid_length],
                    0,
                )
            if lm_head_indices is None:
                # Offset of each example's last token in the flat sequence.
                lm_head_indices = np.cumsum(batch_valid_length, dtype=np.int64) - 1
            llm_boost_inputs["lm_head_indices"] = Tensor.from_numpy(
                lm_head_indices.astype(np.int64)
            )
        else:
            if input_ids.shape[-1] != 1:
                # Decode step fed whole rows: keep only each example's last
                # valid token.
                input_ids_list = [
                    input_ids[i][context_len - 1:context_len]
                    for i, context_len in enumerate(batch_valid_length)
                ]
            if position_ids is None:
                position_ids = batch_valid_length - 1
        if input_ids_list:
            input_ids = np.concatenate(input_ids_list, 0)
        llm_boost_inputs["input_ids"] = Tensor.from_numpy(input_ids.astype(np.int64))
        llm_boost_inputs["position_ids"] = Tensor.from_numpy(
            position_ids.astype(np.int64)
        )
        llm_boost_inputs["block_tables"] = Tensor.from_numpy(block_tables)
        llm_boost_inputs["slot_mapping"] = Tensor.from_numpy(slot_mapping)
        llm_boost_inputs["batch_valid_length"] = Tensor.from_numpy(batch_valid_length)
        llm_boost_inputs["seq_lens"] = seq_lens
        return llm_boost_inputs

    def prepare_inputs_for_predict_layout(self, input_ids, **kwargs):
        """Return ``input_ids`` unchanged; ``kwargs`` are ignored."""
        return input_ids

    def set_dynamic_inputs(self, **kwargs):
        """No-op for this model."""

    def add_flags_custom(self, is_first_iteration):
        """Record ``is_first_iteration`` and forward it to the backend's ``add_flags``."""
        self.is_first_iteration = is_first_iteration
        self.llm_boost.add_flags(is_first_iteration=self.is_first_iteration)

    def construct(self, llm_boost_inputs=None, **kwargs):
        """llm boost forward.

        For the BUILDIN backend, ``llm_boost_inputs`` (as built by
        ``prepare_inputs_for_build_in``) gets the rotary ``cos_embed`` /
        ``sin_embed`` tables attached in place and is passed to the backend.
        The backend's KV cache is set once, on the first call.

        Other backends receive ``kwargs["input_ids"]`` and
        ``kwargs["batch_valid_length"]``. If the latter has more than one
        dimension, only its first row is passed.

        Raises:
            ValueError: the BUILDIN backend is used without
                ``llm_boost_inputs``.
        """
        if self.need_prepare:
            if llm_boost_inputs is None:
                raise ValueError(
                    "llm_boost_inputs is required for the BUILDIN llm boost backend; "
                    "build it with prepare_inputs_for_generation."
                )
            if not self.is_set_kvcache:
                self.llm_boost.set_kvcache()
                self.is_set_kvcache = True
            llm_boost_inputs["cos_embed"] = self.freqs_mgr.freqs_cos
            llm_boost_inputs["sin_embed"] = self.freqs_mgr.freqs_sin
            return self.llm_boost.forward(llm_boost_inputs)
        bvl = kwargs["batch_valid_length"]
        if bvl.ndim > 1:
            # Only the first row is forwarded to non-BUILDIN backends.
            bvl = Tensor(bvl[0].asnumpy())
        return self.llm_boost.forward(kwargs["input_ids"], bvl, None)