from dataclasses import dataclass
import torch
import torch_npu
@dataclass
class TestArgs:
    """Inputs bundle returned by ``create_test_data``.

    The tensor fields are stored after being moved to the target device;
    the last two fields are plain Python values.
    """

    # Despite the ``Test`` prefix this is a data container, not a test case.
    # Opting out keeps pytest from trying to collect it (and warning about
    # the generated ``__init__``). It is un-annotated on purpose so that it
    # does not become a dataclass field.
    __test__ = False

    # Query / key / value inputs; create_test_data builds each one as a
    # float16 tensor of shape (b, n, s, d).
    query_layer: torch.Tensor
    key_layer: torch.Tensor
    value_layer: torch.Tensor
    # Positional key / query inputs; create_test_data builds each one as a
    # float16 tensor of shape (2 * s, n, d).
    pos_key_layer: torch.Tensor
    pos_query_layer: torch.Tensor
    # Output of build_relative_position(s, s, ...): an int64 (s, s) table.
    relative_pos: torch.Tensor
    # Output of create_binary_tensor(b, s): float16, shape (b, 1, s, s).
    atten_mask: torch.Tensor
    # "|"-separated list of positional attention types, passed through as-is.
    pos_att_type: str
    # Scalar produced by create_score_scale(pos_att_type, d).
    score_scale: float
@dataclass
class DataArgs:
    """Sizes and settings that ``create_test_data`` reads to build its inputs."""

    # First dimension of the query/key/value tensors and of the mask
    # (presumably the batch size -- confirm with the caller).
    b: int
    # Second dimension of query/key/value and of the positional tensors
    # (presumably the number of heads -- confirm with the caller).
    n: int
    # Third dimension of query/key/value; the relative-position table is
    # (s, s), the mask is (b, 1, s, s) and the positional tensors have
    # 2 * s rows.
    s: int
    # Last dimension of every generated layer tensor; also feeds the score
    # scale computation.
    d: int
    # Forwarded unchanged to build_relative_position.
    bucket_size: int
    max_position: int
    # "|"-separated attention-type string; forwarded to create_score_scale
    # and copied into the resulting TestArgs.
    pos_att_type: str
def build_relative_position(
    query_size: int,
    key_size: int,
    bucket_size: int = -1,
    max_position: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Return the ``(query_size, key_size)`` table of relative positions.

    Entry ``[i, j]`` starts out as ``i - j``. When both ``bucket_size`` and
    ``max_position`` are positive, offsets whose magnitude exceeds
    ``bucket_size // 2`` are replaced by a signed, logarithmically spaced
    bucket index; offsets at or below that magnitude are kept verbatim.

    Args:
        query_size: Number of rows (query positions).
        key_size: Number of columns (key positions).
        bucket_size: Bucket count for the log mapping; non-positive disables it.
        max_position: Largest position used to scale the log mapping;
            non-positive disables it.
        device: Device on which the index tensors are created.

    Returns:
        An int64 tensor of shape ``(query_size, key_size)``.

    Note:
        With ``bucket_size == 1`` (and bucketing enabled) the half width is
        zero and the scale computation raises ``ZeroDivisionError``.
    """
    rows = torch.arange(0, query_size, dtype=torch.long, device=device)
    cols = torch.arange(0, key_size, dtype=torch.long, device=device)
    # Broadcasting a column vector against a row vector gives the i - j grid.
    offsets = rows.unsqueeze(1) - cols.unsqueeze(0)

    if bucket_size <= 0 or max_position <= 0:
        return offsets[:query_size, :]

    half = bucket_size // 2
    direction = torch.sign(offsets)
    # Offsets strictly inside (-half, half) get a placeholder magnitude so
    # the logarithm below stays away from log(0) wherever half > 1; those
    # entries are overwritten with the raw offset in the final select anyway.
    near_zero = offsets.abs() < half
    magnitude = torch.where(
        near_zero,
        torch.full_like(offsets, half - 1),
        offsets.abs(),
    ).float()

    log_base = torch.log(torch.tensor((max_position - 1) / half))
    scaled = torch.log(magnitude / half) / log_base * (half - 1)
    log_bucket = torch.ceil(scaled).long() + half

    bucketed = torch.where(magnitude <= half, offsets, log_bucket * direction).long()
    return bucketed[:query_size, :]
def create_binary_tensor(b: int, s: int):
    """Build a random two-valued float16 tensor of shape ``(b, 1, s, s)``.

    One uniform sample is drawn per element. Elements whose sample exceeds
    0.5 are set to the most negative finite float16 value; all others are 0.

    Args:
        b: Size of the first dimension.
        s: Size of each of the last two dimensions.

    Returns:
        A float16 tensor containing only ``0`` and ``finfo(float16).min``.
    """
    samples = torch.rand((b, 1, s, s), dtype=torch.float16)
    lowest = torch.finfo(torch.float16).min
    # Start from zeros and overwrite the selected positions in one pass.
    return torch.zeros_like(samples).masked_fill(samples > 0.5, lowest)
def create_score_scale(pos_att_type: str, d: int):
    """Compute the attention score scale ``(d * (1 + k)) ** -0.5``.

    ``k`` is the number of non-blank entries in the ``"|"``-separated
    ``pos_att_type`` string. Blank entries (for example from a trailing
    ``"|"`` or a whitespace-only string) are not attention types and are
    therefore not counted.

    Args:
        pos_att_type: ``"|"``-separated attention types; an empty or falsy
            value means no positional types.
        d: Dimension used as the base of the scale.

    Returns:
        The scale as a Python float.
    """
    if pos_att_type:
        entries = (part.strip() for part in pos_att_type.split("|"))
        num_types = sum(1 for entry in entries if entry)
    else:
        num_types = 0
    # One share for the content term plus one per positional type.
    scale_factor = 1 + num_types
    return (d * scale_factor) ** -0.5
def create_test_data(args: DataArgs, npu_device):
    """Generate random inputs on the host and move them to ``npu_device``.

    Args:
        args: Sizes and settings; ``b``, ``n``, ``s``, ``d`` fix the tensor
            shapes, the remaining fields are forwarded to the helpers.
        npu_device: Target passed to ``Tensor.to`` for every tensor field.

    Returns:
        A ``TestArgs`` holding the device tensors, the original
        ``pos_att_type`` string and the computed score scale.
    """

    def _uniform_fp16(*shape: int) -> torch.Tensor:
        # The values come from uniform_(-1, 1); the preceding rand() only
        # allocates the tensor (and advances the RNG).
        return torch.rand(*shape, dtype=torch.float16).uniform_(-1, 1)

    qkv_shape = (args.b, args.n, args.s, args.d)
    pos_shape = (2 * args.s, args.n, args.d)

    # Draw order matters for reproducibility under a fixed seed.
    query = _uniform_fp16(*qkv_shape)
    key = _uniform_fp16(*qkv_shape)
    value = _uniform_fp16(*qkv_shape)
    pos_key = _uniform_fp16(*pos_shape)
    pos_query = _uniform_fp16(*pos_shape)

    rel_pos = build_relative_position(
        args.s, args.s, args.bucket_size, args.max_position, query.device
    )
    mask = create_binary_tensor(args.b, args.s)
    scale = create_score_scale(args.pos_att_type, args.d)

    return TestArgs(
        query_layer=query.to(npu_device),
        key_layer=key.to(npu_device),
        value_layer=value.to(npu_device),
        pos_key_layer=pos_key.to(npu_device),
        pos_query_layer=pos_query.to(npu_device),
        relative_pos=rel_pos.to(npu_device),
        atten_mask=mask.to(npu_device),
        pos_att_type=args.pos_att_type,
        score_scale=scale,
    )