import math
import logging
from functools import wraps
from typing import Dict
from enum import Enum
import yaml
import torch
import torch_npu
import numpy as np
import torchair as tng
logger = logging.getLogger(__name__)
def get_had_pow2(n, norm=True):
    """Build an ``n x n`` Hadamard-style matrix by repeated doubling.

    Starting from the 1x1 matrix ``[[1]]``, each step replaces the current
    matrix ``H`` with ``[[H, H], [H, -H]]`` until it has ``n`` rows. When
    ``norm`` is truthy the matrix is divided in place by ``sqrt(2)`` after
    every doubling step.

    Args:
        n: Target size; must be a positive power of two.
        norm: Whether to rescale by ``1 / sqrt(2)`` after each doubling.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(n, n)`` moved with ``.npu()``.

    Raises:
        ValueError: If ``n`` is not a positive power of two.
    """
    is_positive_pow2 = (n & (n - 1)) == 0 and n > 0
    if not is_positive_pow2:
        raise ValueError(f"n must be a positive power of 2, got{n}")

    hadamard = torch.ones(1, 1, dtype=torch.bfloat16).npu()
    step_scale = math.sqrt(2)
    while hadamard.shape[0] != n:
        upper_half = torch.cat([hadamard, hadamard], 1)
        lower_half = torch.cat([hadamard, -hadamard], 1)
        hadamard = torch.cat((upper_half, lower_half), 0)
        if norm:
            hadamard /= step_scale
    return hadamard
def read_yaml(yaml_file_path):
    """Load a YAML file and return its parsed content.

    Args:
        yaml_file_path: Path of the YAML file; it is opened as UTF-8 text.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.

    Raises:
        FileNotFoundError: If ``yaml_file_path`` does not exist.
        yaml.YAMLError: If the content cannot be parsed as YAML.
    """
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logger.error(f"No such yaml file: {yaml_file_path}")
        # Re-raise the real cause: there is no parsed data to return, and
        # falling through used to fail with an unrelated UnboundLocalError.
        raise
    except yaml.YAMLError as e:
        logger.error(f"Load yaml file failed: {e}")
        raise
class FakeContextManager:
    """No-op context manager used in place of a real scope.

    Entering yields ``None`` and exiting returns ``None``, so any exception
    raised inside the ``with`` body propagates unchanged.
    """

    def __init__(self) -> None:
        return None

    def __enter__(self):
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        return None
def superkernel_scope(enable: bool, scope: str, options: str = None):
    """Return a super-kernel scope, or a no-op context manager when disabled.

    Args:
        enable: When truthy, ``tng.scope.super_kernel(scope, options)`` is
            returned; otherwise a ``FakeContextManager`` instance.
        scope: Forwarded unchanged to ``tng.scope.super_kernel``.
        options: Forwarded unchanged to ``tng.scope.super_kernel``.

    Returns:
        An object intended to be used in a ``with`` statement.
    """
    if not enable:
        return FakeContextManager()
    return tng.scope.super_kernel(scope, options)
def align_up(a, b):
    """Round ``a`` up to the nearest multiple of ``b``.

    Args:
        a: Value to round up.
        b: Alignment step; must be greater than zero.

    Returns:
        ``(a + b - 1) // b * b``.

    Raises:
        ValueError: If ``b`` is zero or negative.
    """
    if b <= 0:
        raise ValueError("b should be larger then zero!")
    num_steps = (a + b - 1) // b
    return num_steps * b
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
    """Return a view of `tensor` whose `data_ptr()` is `alignment`-byte aligned.

    `alignment` is a byte count. The caller must over-allocate `tensor` by at
    least `ceil(alignment / tensor.element_size())` elements so the returned
    view can be safely narrowed to its target size.

    The view is produced with ``narrow`` along dimension 0 and spans
    ``tensor.numel() - offset`` elements, so `tensor` is expected to be
    one-dimensional (NOTE(review): confirm against callers).

    Args:
        tensor: Buffer to take the aligned view from.
        alignment: Required address alignment in bytes; must be positive.

    Returns:
        The tail of `tensor` starting at the first aligned element.

    Raises:
        ValueError: If `alignment` is not positive, or if the next aligned
            address does not fall on an element boundary of `tensor`.
    """
    if alignment <= 0:
        raise ValueError(f"alignment must be a positive byte count, got {alignment}")
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    pad_bytes = aligned_addr - data_ptr
    element_size = tensor.element_size()
    # A view can only start on an element boundary. Flooring the offset here
    # would silently hand back a pointer that is not actually aligned.
    if pad_bytes % element_size != 0:
        raise ValueError(
            f"cannot reach {alignment}-byte alignment in whole elements of "
            f"size {element_size} (padding needed: {pad_bytes} bytes)")
    offset = pad_bytes // element_size
    return tensor.narrow(0, offset, tensor.numel() - offset)
def ceil_div(a, b):
    """Return ``(a + b - 1) // b``, i.e. ``a / b`` rounded up for positive ``b``."""
    biased_numerator = a + b - 1
    return biased_numerator // b
def update_settings(runner_settings: Dict, module_name: str, key: str, value):
    """Set ``key`` to ``value`` inside one sub-dictionary of ``runner_settings``.

    Args:
        runner_settings: Settings mapping whose values are sub-dictionaries.
        module_name: Name of the sub-dictionary to modify; it must exist.
        key: Entry to add or overwrite in that sub-dictionary.
        value: New value for ``key``.

    Returns:
        The same ``runner_settings`` object, modified in place.

    Raises:
        Exception: If ``runner_settings`` has no (non-``None``) entry for
            ``module_name``.
    """
    submodule = runner_settings.get(module_name)
    if submodule is None:
        raise Exception(f"runner_settings doesn't have submodule ({module_name})!")
    submodule[key] = value
    logger.info(f"add ({key}: {value}) to runner_settings.")
    return runner_settings
def override(func):
    """Decorator that wraps ``func`` in a transparent pass-through.

    The returned wrapper forwards all positional and keyword arguments to
    ``func`` and returns its result; ``functools.wraps`` copies the metadata
    of ``func`` onto the wrapper.
    """
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wraps(func)(wrapper)
def get_init_attn_mask(mask_length, device, valid_len=None):
    """Build a square boolean mask that is ``True`` strictly above the diagonal.

    The mask is the logical inverse of a lower-triangular matrix of ones.
    When ``valid_len`` is given, the last ``valid_len`` rows are cleared to
    ``False`` entirely.

    Args:
        mask_length: Number of rows and columns of the mask.
        device: Device on which the mask is created.
        valid_len: Optional number of trailing rows to clear. ``None`` or
            ``0`` leaves the mask untouched.

    Returns:
        A ``torch.bool`` tensor of shape ``(mask_length, mask_length)``.

    Raises:
        ValueError: If ``valid_len`` is outside ``[0, mask_length]``.
    """
    share_mask_tril = ~torch.tril(
        torch.ones((mask_length, mask_length),
                   dtype=torch.bool, device=device))
    if valid_len is None:
        return share_mask_tril
    if not 0 <= valid_len <= mask_length:
        raise ValueError(
            f"valid_len must be in [0, {mask_length}], got {valid_len}")
    if valid_len > 0:
        # Use an explicit start index: ``[-valid_len:]`` with valid_len == 0
        # would select every row. A scalar fill also avoids allocating a
        # float CPU tensor just to write False into a bool mask.
        share_mask_tril[mask_length - valid_len:, :] = False
    return share_mask_tril
def get_decode_mask(mask_length, device, position):
    """Build a ``(1, mask_length)`` mask whose first ``position`` entries are 1.

    Args:
        mask_length: Number of columns of the mask.
        device: Device on which the mask is created.
        position: Exclusive end index of the leading run of ones.

    Returns:
        A tensor of zeros (default dtype) with columns ``[:position]`` set to 1.
    """
    mask = torch.zeros(1, mask_length, device=device)
    mask[:, :position] = 1
    return mask
def npu_wait_tensor(switch_flag: bool, out: torch.Tensor, wait_tensor: torch.Tensor):
    """Optionally route ``out`` through ``tng.scope.npu_wait_tensor``.

    Args:
        switch_flag: When falsy, ``out`` is returned untouched.
        out: Tensor to return.
        wait_tensor: Second argument forwarded to ``tng.scope.npu_wait_tensor``.

    Returns:
        ``out`` itself, or the result of
        ``tng.scope.npu_wait_tensor(out, wait_tensor)`` when enabled.
    """
    if not switch_flag:
        return out
    return tng.scope.npu_wait_tensor(out, wait_tensor)
def npu_stream_switch(switch_flag: bool, stream_tag: str, stream_priority: int = 0):
    """Return a stream-switch scope, or a no-op context manager when disabled.

    Args:
        switch_flag: When truthy, the real scope is returned.
        stream_tag: Forwarded to ``tng.scope.npu_stream_switch``.
        stream_priority: Forwarded to ``tng.scope.npu_stream_switch``.

    Returns:
        ``tng.scope.npu_stream_switch(stream_tag, stream_priority)`` when
        enabled, otherwise a ``FakeContextManager`` instance.
    """
    if not switch_flag:
        return FakeContextManager()
    return tng.scope.npu_stream_switch(stream_tag, stream_priority)
def limit_core_num(switch_flag: bool, aic_num: str, aiv_num: str):
    """Return a core-limit scope, or a no-op context manager when disabled.

    Args:
        switch_flag: When truthy, the real scope is returned.
        aic_num: First argument forwarded to ``tng.scope.limit_core_num``.
        aiv_num: Second argument forwarded to ``tng.scope.limit_core_num``.

    Returns:
        ``tng.scope.limit_core_num(aic_num, aiv_num)`` when enabled,
        otherwise a ``FakeContextManager`` instance.
    """
    if not switch_flag:
        return FakeContextManager()
    return tng.scope.limit_core_num(aic_num, aiv_num)
def record_event(switch_flag: bool, events: tuple[torch.npu.Event], idx: int):
    """Record ``events[idx]`` via ``tng.ops.npu_tagged_event_record`` if enabled.

    Args:
        switch_flag: When falsy, nothing happens.
        events: Indexable collection of events.
        idx: Index of the event to record.
    """
    if not switch_flag:
        return
    tng.ops.npu_tagged_event_record(events[idx])
def wait_event(switch_flag: bool, events: tuple[torch.npu.Event], idx: int):
    """Wait on ``events[idx]`` via ``tng.ops.npu_tagged_event_wait`` if enabled.

    Args:
        switch_flag: When falsy, nothing happens.
        events: Indexable collection of events.
        idx: Index of the event to wait for.
    """
    if not switch_flag:
        return
    tng.ops.npu_tagged_event_wait(events[idx])
def record_stream(switch_flag: bool, out: torch.Tensor, stream_id: str):
    """Call ``tng.ops.npu_record_tagged_stream(out, stream_id)`` if enabled.

    Args:
        switch_flag: When falsy, nothing happens.
        out: Tensor forwarded to ``tng.ops.npu_record_tagged_stream``.
        stream_id: Stream identifier forwarded alongside ``out``.
    """
    if not switch_flag:
        return
    tng.ops.npu_record_tagged_stream(out, stream_id)
def npu_prefetch(switch_flag, weight, depend, size, offset=0):
    """Forward to ``torch_npu.npu_prefetch`` when enabled.

    Args:
        switch_flag: When falsy, nothing is called and ``None`` is returned.
        weight: First argument forwarded to ``torch_npu.npu_prefetch``.
        depend: Second argument forwarded to ``torch_npu.npu_prefetch``.
        size: Third argument forwarded to ``torch_npu.npu_prefetch``.
        offset: Fourth argument forwarded to ``torch_npu.npu_prefetch``.

    Returns:
        The result of ``torch_npu.npu_prefetch(weight, depend, size, offset)``
        when enabled, otherwise ``None``.
    """
    if not switch_flag:
        return None
    return torch_npu.npu_prefetch(weight, depend, size, offset)
def process_infer_time(infer_time_rec, token_count):
    """Compute an average per-token time from per-round timings.

    The first record is dropped (together with one token), rounds slower than
    ``Q3 + 1.5 * IQR`` of the remaining records are discarded along with the
    average number of tokens per round, and the remaining time is divided by
    the remaining token count.

    Args:
        infer_time_rec: Sequence of per-round times.
        token_count: Number of tokens produced over all rounds.

    Returns:
        ``0`` for an empty record; the first record when there is only one
        record or ``token_count <= 1``; otherwise the outlier-filtered average.
    """
    num_rounds = len(infer_time_rec)
    if num_rounds == 0:
        logger.info(f"precoss infer time receives empty time record")
        return 0
    if num_rounds == 1 or (token_count <= 1):
        return infer_time_rec[0]

    tokens_per_round = token_count / num_rounds
    # Skip the first round and the single token attributed to it.
    later_rounds = infer_time_rec[1:]
    token_count -= 1

    lower_quartile = np.percentile(later_rounds, 25)
    upper_quartile = np.percentile(later_rounds, 75)
    outlier_cutoff = upper_quartile + 1.5 * (upper_quartile - lower_quartile)

    kept_time = 0
    for elapsed in later_rounds:
        if elapsed > outlier_cutoff:
            # Drop the slow round and the tokens it is assumed to have made.
            token_count -= tokens_per_round
        else:
            kept_time += elapsed

    if token_count == 0:
        return later_rounds[0]
    return kept_time / token_count
class MicroBatchMode(Enum):
    """Integer-valued identifiers for the available micro-batch modes."""
    # NOTE(review): presumably means micro-batching is turned off — confirm.
    DISABLE = 0
    # NOTE(review): the names suggest prefill-stage micro-batching under
    # DP+EP and SP+TP+EP parallel layouts respectively — confirm with users
    # of this enum.
    PREFILL_MICRO_BATCH_DP_EP = 1
    PREFILL_MICRO_BATCH_SP_TP_EP = 2
def remove_padding_left(tensor, pad_id):
    """Strip leading ``pad_id`` entries from every row of a 2-D tensor.

    Args:
        tensor: Batch tensor. A batch whose first dimension is 1 is returned
            as ``[tensor[0]]`` without any stripping.
        pad_id: Padding value to strip from the left of each row.

    Returns:
        A list with one tensor per row, each starting at the first entry that
        differs from ``pad_id``. Rows made only of ``pad_id`` are kept whole.

    Raises:
        ValueError: If the batch has more than one row and is not 2-D.
    """
    if tensor.shape[0] == 1:
        return [tensor[0]]
    if tensor.dim() != 2:
        raise ValueError("remove padding func input dim must be 2")

    stripped_rows = []
    for row in tensor:
        is_token = row != pad_id
        if not is_token.any():
            stripped_rows.append(row)
            continue
        # argmax over the 0/1 mask gives the index of the first non-pad entry.
        first_token_idx = torch.argmax(is_token.float())
        stripped_rows.append(row[first_token_idx:])
    return stripped_rows
def remove_eos_right(output_tensorlist: list[torch.Tensor], eos_id: int) -> list[list[int]]:
    """Cut each sequence at its first EOS token and end it with one EOS.

    Each tensor is converted to a Python list, everything from the first
    occurrence of ``eos_id`` onwards is dropped, and a single ``eos_id`` is
    appended. A sequence without ``eos_id`` is kept whole and still gets the
    trailing ``eos_id``.

    Args:
        output_tensorlist: One token-id tensor per sequence.
        eos_id: Token id marking the end of a sequence.

    Returns:
        One list of token ids per input tensor, each ending with ``eos_id``.
    """
    res = []
    for toks in output_tensorlist:
        # Work on a plain list: ``list.index`` locates the first EOS, which
        # the tensor object does not provide.
        token_ids = toks.cpu().tolist()
        if eos_id in token_ids:
            token_ids = token_ids[:token_ids.index(eos_id)]
        # ``list.append`` returns None, so append first and store the list.
        token_ids.append(eos_id)
        res.append(token_ids)
    return res
def detokenize_outputs(generate_ids_list, tokenizer, input_lens):
    """Decode the generated part of each sequence and cut it at the EOS text.

    Args:
        generate_ids_list: Iterable of token-id sequences (prompt + output).
        tokenizer: Object providing ``decode(ids, skip_special_tokens=...)``
            and an ``eos_token`` attribute.
        input_lens: Number of leading ids to skip in every sequence.

    Returns:
        A list with one decoded string per sequence, truncated before the
        first occurrence of ``tokenizer.eos_token`` when present. An empty
        input yields an empty list.
    """
    eos_token = tokenizer.eos_token
    res_list = []
    for generate_ids in generate_ids_list:
        res = tokenizer.decode(generate_ids[input_lens:], skip_special_tokens=False)
        # Skip truncation when the tokenizer has no usable EOS text: ``None``
        # cannot be searched for and an empty separator makes ``split`` fail.
        if eos_token and eos_token in res:
            res = res.split(eos_token, 1)[0]
        res_list.append(res)
    if res_list:
        logger.info("Inference decode result for batch 0: \n%s", res_list[0])
    else:
        # Nothing was decoded; indexing element 0 would raise IndexError.
        logger.info("Inference decode result: \n%s", res_list)
    return res_list
def check_common_parallel_settings(world_size, runner_settings):
    """Validate divisibility constraints between sizes in ``runner_settings``.

    Every ``parallel_config`` entry whose key contains ``tp_size``,
    ``ep_size`` or ``kvp_size`` must divide ``world_size``; every entry whose
    key contains ``dp_size`` must divide the configured batch size.

    Args:
        world_size: Total number of ranks; must be greater than 0.
        runner_settings: Settings mapping with an optional ``parallel_config``
            sub-dictionary and a ``data_config`` sub-dictionary
            (``batch_size`` defaults to 1).

    Raises:
        ValueError: If ``world_size`` is not positive or a divisibility check
            fails.
    """
    if world_size <= 0:
        raise ValueError(f"{world_size=} must greater than 0")

    parallel_config = runner_settings.get("parallel_config", {})
    batch_size = runner_settings.get("data_config").get("batch_size", 1)
    world_divisor_tags = ("tp_size", "ep_size", "kvp_size")

    for key, value in parallel_config.items():
        if any(tag in key for tag in world_divisor_tags):
            if world_size % value != 0:
                raise ValueError(f"{world_size=} is not divisible by {key}={value}")
        if "dp_size" in key:
            if batch_size % value != 0:
                raise ValueError(f"{batch_size=} is not divisible by {key}={value}")
def update_common_vars(world_size, runner_settings):
    """Derive common sizes from ``world_size`` and store them in the settings.

    Written into ``parallel_config``: ``attn_dp_size``, ``moe_dp_size``,
    ``moe_ep_size`` (equal to ``moe_dp_size``) and ``embed_dp_size``, each
    computed as ``world_size`` floor-divided by the matching ``*_tp_size``
    (default 1). Written into ``data_config``: ``batch_size_per_rank`` and
    ``max_position_embeddings``.

    Args:
        world_size: Total number of ranks.
        runner_settings: Settings mapping with ``parallel_config``,
            ``data_config`` and ``model_config`` sub-dictionaries; it is
            modified in place through ``update_settings``.
    """
    parallel_cfg = runner_settings.get("parallel_config")
    attn_dp_size = world_size // parallel_cfg.get("attn_tp_size", 1)
    moe_dp_size = world_size // parallel_cfg.get("moe_tp_size", 1)
    moe_ep_size = moe_dp_size
    embed_dp_size = world_size // parallel_cfg.get("embed_tp_size", 1)

    batch_size = runner_settings.get("data_config").get("batch_size", 1)
    batch_size_per_rank = batch_size // attn_dp_size

    runner_settings = update_settings(
        runner_settings, "data_config", "batch_size_per_rank", batch_size_per_rank)
    derived_parallel_sizes = (
        ("attn_dp_size", attn_dp_size),
        ("moe_dp_size", moe_dp_size),
        ("moe_ep_size", moe_ep_size),
        ("embed_dp_size", embed_dp_size),
    )
    for size_name, size_value in derived_parallel_sizes:
        runner_settings = update_settings(
            runner_settings, "parallel_config", size_name, size_value)

    data_cfg = runner_settings.get("data_config")
    input_max_len = data_cfg.get("input_max_len", 32)
    max_new_tokens = data_cfg.get("max_new_tokens", 32)
    next_n = runner_settings.get("model_config").get("next_n", 0)
    # Each new token reserves (next_n + 1) positions on top of the prompt.
    max_position_embeddings = max_new_tokens * (next_n + 1) + input_max_len
    runner_settings = update_settings(
        runner_settings, "data_config", "max_position_embeddings", max_position_embeddings)
def obtain_mtp_stats(next_n, model_name, total_accepted_num, cnt, infer_time_rec):
    """Log MTP statistics and return the average inference time.

    Args:
        next_n: Value reported in the log messages as the MTP depth.
        model_name: Model name used as a prefix in the log messages.
        total_accepted_num: Tensor of accepted-token counts; its mean is used.
        cnt: Number of loop steps.
        infer_time_rec: Per-step time records, passed to ``process_infer_time``.

    Returns:
        The average time from ``process_infer_time`` when the token count
        equals the number of time records.
    """
    avg_accepted_num = torch.mean(total_accepted_num)
    summary_msg = (
        f"Finished inference, number of loop step is {cnt}, "
        f"draft tokens per batch is {cnt}*{next_n}, "
        f"average accepted number per batch is {avg_accepted_num.to(torch.int32)}"
    )
    logger.info(summary_msg)

    # The equivalent latency counts accepted draft tokens as extra output.
    total_tokens = avg_accepted_num + cnt
    equivalent_infer_time = process_infer_time(infer_time_rec, total_tokens)
    avg_infer_time = process_infer_time(infer_time_rec, len(infer_time_rec))

    avg_ms = avg_infer_time * 1000
    equivalent_ms = equivalent_infer_time * 1000
    logger.info(
        f"{model_name} main and mtp model average inference time cost is {avg_ms:.2f} ms")
    logger.info(
        f"{model_name} model average equivalent latency of MTP{next_n}"
        f" is {equivalent_ms:.2f} ms")
    return avg_infer_time
def weight_dequant(weight: torch.Tensor, scale: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantize a block-quantized 2-D weight with one scale value per block.

    Every ``block_size x block_size`` tile of `weight` is multiplied by the
    matching entry of `scale`. Tiles at the bottom/right edge may be partial:
    the expanded scale is cropped to the shape of `weight`.

    Args:
        weight (torch.Tensor): The quantized weight tensor of shape (M, N).
        scale (torch.Tensor): The scale tensor of shape
            (ceil(M / block_size), ceil(N / block_size)).
        block_size (int, optional): Edge length of one quantization block. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight of shape (M, N), computed in float32 and
        converted to the default dtype.

    Raises:
        AssertionError: If the shape of `scale` does not match the number of blocks in `weight`.
    """
    num_rows, num_cols = weight.shape
    scale_rows, scale_cols = scale.shape
    assert scale_rows == (
        num_rows + block_size - 1) // block_size, "Mismatch in scale rows and weight rows."
    assert scale_cols == (
        num_cols + block_size - 1) // block_size, "Mismatch in scale columns and weight columns."

    weight_fp32 = weight.to(torch.float32)
    # Blow each scale entry up to a full block, rows first and then columns.
    per_element_scale = torch.repeat_interleave(scale.to(torch.float32), block_size, dim=0)
    per_element_scale = torch.repeat_interleave(per_element_scale, block_size, dim=1)
    # Crop the padding introduced by partial edge blocks.
    per_element_scale = per_element_scale[:num_rows, :num_cols]

    return (weight_fp32 * per_element_scale).to(torch.get_default_dtype())