import os
import time
import argparse
import logging
import copy
import numpy as np
import torch
import torch_npu
from transformers import AutoTokenizer, AutoProcessor
from executor.utils import get_default_group, process_infer_time, remove_padding_left, detokenize_outputs
from executor.model_loader.default_loader import DefaultModelLoader
from executor.model_loader.dummy_loader import DummyModelLoader
from module.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
QuantizeMethodBase,
get_quantization_config,
get_quant_config)
from executor.utils.logging_config import setup_logging
# Module-level setup: everything below runs once, at import time.
# Switch on torch_npu's `allow_internal_format` option.
# NOTE(review): presumably this lets NPU kernels keep tensors in device-private
# layouts -- confirm against the torch_npu documentation.
torch.npu.config.allow_internal_format = True
# Configure logging through the project helper, then create this module's logger.
setup_logging()
logger = logging.getLogger(__name__)
# Seed the torch and NPU random number generators with the fixed value 42.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)
class FakeContextManager:
    """No-op stand-in used where a profiler object is expected.

    It can be used in a ``with`` statement and exposes a ``step()`` call,
    but neither does anything.
    """

    def __enter__(self):
        """Enter the ``with`` block and hand back this object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Leave the ``with`` block; returns None, so exceptions propagate."""
        return None

    @staticmethod
    def step():
        """Do nothing and return None."""
        return None
class ModelRunner:
    """Base runner: loads a model onto an NPU, optionally compiles it, and
    drives a prefill step followed by decode steps.

    Several methods (``save_model``, ``scale_dtype_adapter``, ``cast_format``,
    ``model_input_prepare``, ``model_output_process``, ``check_model_cfg``,
    ``update_model_cfg``) are empty hooks in this class.
    NOTE(review): presumably they are meant to be overridden by model-specific
    subclasses -- confirm against the concrete runners.
    """

    def __init__(self, runner_settings):
        """Read settings and environment variables, then initialise the device.

        Args:
            runner_settings: mapping that must contain the ``"model_config"``
                and ``"data_config"`` sub-mappings; ``"model_name"``,
                ``"model_path"`` and ``"exe_mode"`` are read from the top level.
        """
        self.runner_settings = runner_settings
        self.model_name = runner_settings.get("model_name", "default_model_name")
        model_path = self.runner_settings.get("model_path")

        model_config = runner_settings.get("model_config")
        self.dtype = model_config.get("dtype", torch.bfloat16)
        self.enable_online_split_weight = model_config.get("enable_online_split_weight", False)

        data_config = runner_settings.get("data_config")
        self.max_position_embeddings = data_config.get("max_position_embeddings", 131072)
        self.input_max_len = data_config.get("input_max_len", 1024)
        self.max_new_tokens = data_config.get("max_new_tokens", 32)
        self.batch_size = data_config.get("batch_size", 16)

        self.tokenizer = None
        self.model = None
        self.device = None

        # Rank information comes from the launcher's environment variables.
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
        self.global_rank = self.local_rank + self.rank_offset
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))

        # With several ranks and no online splitting, each rank reads its own
        # "rank_<global_rank>" sub-directory of the model path.
        if self.world_size == 1 or self.enable_online_split_weight:
            self.model_path = model_path
        else:
            self.model_path = os.path.join(model_path, f"rank_{self.global_rank}")
        self.res_path = os.getenv("RES_PATH", "./")

        self.enable_profiler = model_config.get("enable_profiler", False)
        self.enable_prefill_profiler = self.enable_profiler
        self.enable_decode_profiler = self.enable_profiler
        self.use_pretrained_model = True
        self.execute_mode = runner_settings.get("exe_mode", "ge_graph")
        self.tokenizer_mode = model_config.get("tokenizer_mode", "default")
        self.platform_version = model_config.get("platform_version", "A3")
        self.profiler = FakeContextManager()
        self.hf_config = None
        self.quantization = None
        self.init_device()

    @staticmethod
    def define_profiler(enable_profiler=False, profile_save_path="prof", active=10, repeat=1, skip_first=10):
        """Build a torch_npu profiler, or a no-op stand-in when disabled.

        Args:
            enable_profiler: when false, a ``FakeContextManager`` is returned
                and nothing is created on disk.
            profile_save_path: directory (created if missing) handed to the
                tensorboard trace handler.
            active, repeat, skip_first: forwarded to the profiler schedule;
                ``wait`` and ``warmup`` are fixed at 0.

        Returns:
            An object usable as a context manager that exposes ``step()``.
        """
        if not enable_profiler:
            return FakeContextManager()

        os.makedirs(profile_save_path, exist_ok=True)
        experimental_config = torch_npu.profiler._ExperimentalConfig(
            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
            aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization
        )
        return torch_npu.profiler.profile(
            activities=[
                torch_npu.profiler.ProfilerActivity.NPU,
                torch_npu.profiler.ProfilerActivity.CPU,
            ],
            with_stack=False,
            record_shapes=False,
            profile_memory=False,
            experimental_config=experimental_config,
            schedule=torch_npu.profiler.schedule(
                wait=0,
                warmup=0,
                active=active,
                repeat=repeat,
                skip_first=skip_first
            ),
            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(profile_save_path)
        )

    def init_device(self):
        """Select this rank's NPU and, for multi-rank runs, start the HCCL group.

        The process group is only initialised when more than one rank is
        configured and ``get_default_group()`` reports no existing group.
        """
        logger.info("Set execution using npu index: %s, global: %s", self.local_rank, self.global_rank)
        self.device = torch.device(f"npu:{self.local_rank}")
        torch.npu.set_device(self.device)

        if not (torch.npu.is_available() and self.world_size > 1):
            return
        if get_default_group() is not None:
            return

        options = torch_npu._C._distributed_c10d.ProcessGroupHCCL.Options()
        if self.platform_version == "950":
            options.hccl_config = {"hccl_op_expansion_mode": 5}
        torch.distributed.init_process_group(
            backend="hccl", world_size=self.world_size, rank=self.global_rank, pg_options=options)

    def init_model(self, model, config=None, **kwargs):
        """Load the model, post-process its weights, compile it and load the tokenizer.

        Args:
            model: model class to instantiate/load.
            config: configuration class; required by the online-split path and
                by the from-config path.
            **kwargs: forwarded to the selected loading routine.
        """
        if self.enable_online_split_weight:
            self._init_model_with_online_splited_weight(model, config, **kwargs)
        else:
            self._init_model_with_offline_splited_weight(model, config, **kwargs)
        self._process_weight_after_loading()
        self.compile_model()
        self.init_tokenizer()

    def _process_weight_after_loading(self):
        """Move the model to the device and run post-load weight processing."""
        self.to_device()
        if self.enable_online_split_weight:
            # Let every module that carries a quantization method finalise its weights.
            for _, module in self.model.named_modules():
                quant_method = getattr(module, "quant_method", None)
                if isinstance(quant_method, QuantizeMethodBase):
                    quant_method.process_weights_after_loading(module)
        else:
            self.scale_dtype_adapter()
            # NOTE(review): cast_format() is treated as part of this branch; the
            # nesting was not recoverable from the flattened source -- confirm.
            self.cast_format()

    def _init_model_with_online_splited_weight(self, model, config, **kwargs):
        """Load the model through a loader object, splitting weights at load time.

        Raises:
            ValueError: if ``config`` is None, or if the configured
                quantization method is not supported.
        """
        if config is None:
            raise ValueError("config cannot be None")
        if self.use_pretrained_model:
            logger.info("Try to load pretrained model in path: %s", self.model_path)
            loader = DefaultModelLoader()
        else:
            loader = DummyModelLoader()

        self.hf_config = config.from_pretrained(
            self.model_path,
            low_cpu_mem_usage=True,
            ignore_mismatched_sizes=True,
            runner_settings=self.runner_settings
        )
        self._verify_quantization()
        if self.quantization is not None:
            self.hf_config.quant_config = get_quant_config(self.hf_config, self.quantization, self.model_path)
        self.check_model_cfg()
        self.update_model_cfg()
        self.model = loader.load_model(config=self.hf_config, model_cls=model,
                                       runner_settings=self.runner_settings, model_path=self.model_path, **kwargs)

    def _init_model_with_offline_splited_weight(self, model, config, **kwargs):
        """Load pretrained weights from ``model_path``, or build the model from config only."""
        if self.use_pretrained_model:
            self._load_model_with_manual_splited_weight(model, **kwargs)
        else:
            self._init_model_from_config(model, config=config, **kwargs)

    def _init_model_from_config(self, model, config, **kwargs):
        """Instantiate the model from ``config.json`` without loading pretrained weights.

        Raises:
            ValueError: if ``config`` is None.
        """
        if config is None:
            raise ValueError("config cannot be None")
        config_file = os.path.join(self.model_path, "config.json")
        model_config = config.from_pretrained(config_file, torch_dtype=self.dtype)
        self.model = model(model_config, runner_settings=self.runner_settings, **kwargs).to(self.dtype)

    def _load_model_with_manual_splited_weight(self, model, **kwargs):
        """Load the model with ``model.from_pretrained`` from ``self.model_path``."""
        logger.info("Try to load pretrained model in path: %s", self.model_path)
        self.model = model.from_pretrained(self.model_path,
                                           low_cpu_mem_usage=True,
                                           ignore_mismatched_sizes=True,
                                           torch_dtype="auto",
                                           runner_settings=self.runner_settings,
                                           **kwargs)

    def save_model(self):
        """Hook; does nothing in this class."""

    def scale_dtype_adapter(self):
        """Hook; does nothing in this class."""

    def to_device(self):
        """Move the model to ``self.device``."""
        self.model.to(self.device)
        logger.info("Model weights H2D finished.")

    def cast_format(self):
        """Hook; does nothing in this class."""

    def init_tokenizer(self):
        """Create ``self.tokenizer`` from ``self.model_path``.

        Models whose name contains ``"deepseek_vl2"`` get an ``AutoProcessor``;
        all others get an ``AutoTokenizer`` with right-side padding and
        truncation, falling back to the EOS token when no pad token is defined.
        """
        if "deepseek_vl2" in self.model_name:
            self.tokenizer = AutoProcessor.from_pretrained(
                self.model_path, trust_remote_code=True, use_fast=True
            )
            # Use the module logger (not the root logger) like the rest of the class.
            logger.info("use deepseek_vl2 tokenizer.")
            return

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            padding_side="right",
            truncation_side='right',
            trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

    def tokenize_prompts(self, prompts):
        """Tokenize ``prompts`` padded/truncated to ``input_max_len`` and move them to the device.

        When the configured dataset is not ``"default"``, the prompts are first
        rewritten by ``build_dataset_input``.
        """
        kwargs = {
            "return_tensors": "pt", "truncation": True, "padding": "max_length", "max_length": self.input_max_len
        }
        if self.runner_settings.get("data_config").get("dataset", "default") != "default":
            from executor.utils.data_utils import build_dataset_input
            prompts = build_dataset_input(self.tokenizer, prompts, self.input_max_len,
                                          self.max_new_tokens, is_chat=False)
        return self.tokenizer(prompts, **kwargs).to(self.device)

    def compile_model(self):
        """Log the model structure and compile it when the execute mode contains ``"graph"``."""
        logger.info("The final model structure is: \n %s", self.model)
        if "graph" in self.execute_mode:
            logger.info("Try to compile model")
            self.graph_compile()

    def graph_compile(self):
        """Wrap ``self.model`` with ``torch.compile`` using the torchair NPU backend."""
        import torchair as tng
        # Imported for its side effect only; the name is not used below.
        import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
        from torchair.configs.compiler_config import CompilerConfig

        compiler_config = CompilerConfig()
        compiler_config.experimental_config.frozen_parameter = True
        compiler_config.experimental_config.tiling_schedule_optimize = True
        npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
        self.model = torch.compile(self.model, dynamic=True, fullgraph=True, backend=npu_backend)

    @staticmethod
    def mark_detail(model_inputs, item_key, is_cache=False):
        """Mark the tensor(s) stored under ``item_key`` as static for dynamo.

        Args:
            model_inputs: mapping of model input names to values.
            item_key: key to look up; a missing or None value is ignored.
            is_cache: when true the value is walked as a two-level nested
                iterable and every tensor found in it is marked; otherwise the
                value itself is marked if it is a tensor.
        """
        item = model_inputs.get(item_key)
        if item is None:
            return
        if is_cache:
            for group in item:
                for entry in group:
                    if isinstance(entry, torch.Tensor):
                        torch._dynamo.mark_static(entry)
        elif isinstance(item, torch.Tensor):
            torch._dynamo.mark_static(item)

    def mark_inputs(self, model_inputs, loop_list=None):
        """Call ``mark_detail`` for every key in ``model_inputs``.

        Args:
            model_inputs: mapping of model input names to values.
            loop_list: keys whose values are handled as nested cache
                structures; defaults to no keys.
        """
        if loop_list is None:
            loop_list = []
        for input_key in model_inputs:
            self.mark_detail(model_inputs, input_key, is_cache=input_key in loop_list)

    def model_input_prepare(self, input_dict):
        """Hook that turns ``input_dict`` into model inputs; returns None in this class."""
        return None

    def model_inference(self, model_inputs, is_prefill, warm_up=False):
        """Run one forward pass under ``torch.no_grad()`` and time it.

        The device is synchronised before and after the call so the measured
        interval covers the whole forward pass.

        Args:
            model_inputs: keyword arguments for the model call.
            is_prefill: selects the stage name used in the log line.
            warm_up: accepted for interface compatibility; not used here.

        Returns:
            ``(logits, inference_time)`` with the time in seconds.
        """
        torch.npu.synchronize()
        # perf_counter is monotonic, so clock adjustments cannot distort the interval.
        start_time = time.perf_counter()
        with torch.no_grad():
            logits = self.model(**model_inputs)
        torch.npu.synchronize()
        inference_time = time.perf_counter() - start_time

        inference_stage = "prefill" if is_prefill else "decode"
        logger.info("%s inference time cost of %s is %.2f ms",
                    self.model_name, inference_stage, inference_time * 1000)
        return (logits, inference_time)

    def _parse_quant_hf_config(self):
        """Return ``quantization_config`` from the HF config, else ``compression_config``, else None."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    def _verify_quantization(self) -> None:
        """Adapted from vllm: read and validate the quantization method.

        Sets ``self.quantization`` from the HF config when a non-empty
        quantization config is present.

        Raises:
            ValueError: if the resulting method is not in ``QUANTIZATION_METHODS``.
        """
        supported_quantization = QUANTIZATION_METHODS
        quant_cfg = self._parse_quant_hf_config()
        if quant_cfg:
            quant_method = quant_cfg.get("quant_method", "").lower()
            # Normalise the underscore spelling to the hyphenated name.
            quant_method = quant_method.replace("compressed_tensors",
                                                "compressed-tensors")
            self.quantization = quant_method

        if self.quantization is not None and self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")

    def _generate_step(self, input_dict, warm_up):
        """Run one prepare -> inference -> post-process -> profiler-step cycle; return its time."""
        model_inputs = self.model_input_prepare(input_dict)
        logits, inference_time = self.model_inference(
            model_inputs, is_prefill=input_dict['is_prefill'], warm_up=warm_up)
        self.model_output_process(model_inputs, logits, input_dict)
        self.profiler.step()
        return inference_time

    def model_generate(self, input_dict, input_lens, warm_up=False):
        """Run one prefill step, then decode steps until ``get_jump_flag`` stops the loop.

        Args:
            input_dict: mutable generation state; ``'is_prefill'`` and
                ``"generate_ids"`` are read here, and the input/output hooks
                receive the whole dict.
            input_lens: prompt length, logged and passed to ``detokenize_outputs``.
            warm_up: disables profiling and the average-time log, and limits
                the number of steps (see ``get_jump_flag``).

        Returns:
            The result of ``detokenize_outputs`` for the generated ids.
        """
        logger.info("Prompt lens is : %d", input_lens)
        infer_time_rec = []

        # Prefill: exactly one step, profiled on its own schedule.
        self.profiler = self.define_profiler(
            enable_profiler=self.enable_prefill_profiler and not warm_up,
            profile_save_path=os.path.join(self.res_path, "prof", "prefill"),
            active=1, skip_first=0, repeat=1)
        with self.profiler:
            infer_time_rec.append(self._generate_step(input_dict, warm_up))
        cnt = 1

        # Decode: repeat until the step budget is used up.
        self.profiler = self.define_profiler(
            enable_profiler=self.enable_decode_profiler and not warm_up,
            profile_save_path=os.path.join(self.res_path, "prof", "decode"))
        with self.profiler:
            while not self.get_jump_flag(cnt, warm_up):
                infer_time_rec.append(self._generate_step(input_dict, warm_up))
                cnt += 1

        if not warm_up:
            avg_infer_time = process_infer_time(infer_time_rec, cnt)
            logger.info("%s average inference time cost is %.2f ms", self.model_name, avg_infer_time * 1000)

        # NOTE(review): detokenisation is treated as running for warm-up calls too;
        # the nesting was not recoverable from the flattened source -- confirm.
        # Clamp ids into the vocabulary range before decoding.
        generate_ids = input_dict["generate_ids"].clip(0, self.model.config.vocab_size - 1)
        generate_ids_list = remove_padding_left(generate_ids, self.tokenizer.pad_token_id)
        return detokenize_outputs(generate_ids_list, self.tokenizer, input_lens)

    def get_jump_flag(self, cnt, warm_up):
        """Tell whether generation should stop after ``cnt`` steps.

        Stops once ``cnt`` reaches ``max_new_tokens``, or after 2 steps when
        warming up.
        """
        default_decode_dump = 2
        return cnt >= self.max_new_tokens or (warm_up and cnt >= default_decode_dump)

    def model_output_process(self, model_inputs, outputs, input_dict):
        """Hook called after each inference step; does nothing in this class."""

    def check_model_cfg(self):
        """Hook; does nothing in this class."""

    def update_model_cfg(self):
        """Hook; does nothing in this class."""