import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_npu
from torch_npu.contrib import transfer_to_npu
from torch.profiler import ProfilerActivity, tensorboard_trace_handler
from torch.profiler import profile, schedule
from torch.autograd.profiler import record_function
from torch.cuda.amp import autocast
from einops import rearrange, repeat
# ---------------------------------------------------------------------------
# Process-wide runtime configuration for this script.
# ---------------------------------------------------------------------------
# Environment switches; their exact meaning is defined by the Ascend /
# torch_npu runtime and is not visible in this file.
# NOTE(review): these are assigned after `import torch_npu` has already run
# above -- confirm the runtime still picks them up at this point.
os.environ["COMBINED_ENABLE"] = "1"
os.environ["INF_NAN_MODE_ENABLE"] = "1"
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
# Fixed seed so the random inputs generated in main() are reproducible.
torch.manual_seed(0)
# Print tensors with 16 digits so small numeric differences are visible.
torch.set_printoptions(precision=16)
# torch_npu settings: JIT compile mode off and internal formats disallowed.
torch_npu.npu.set_compile_mode(jit_compile=False)
torch_npu.npu.config.allow_internal_format = False
# Turn off HF32 for convolution and matmul (presumably to avoid reduced
# precision in the reference computation -- TODO confirm).
torch.npu.conv.allow_hf32 = False
torch.npu.matmul.allow_hf32 = False
# Every tensor in this script is placed on this hard-coded device index.
device = torch.device('npu:6')
def get_interleave_matrix(n):
    """Build the n x n matrix that rotates adjacent feature pairs.

    For every even index ``i`` the matrix holds ``M[i, i + 1] = 1`` and
    ``M[i + 1, i] = -1``; all other entries are zero.  Right-multiplying a
    row vector ``x`` by it gives ``(-x1, x0, -x3, x2, ...)``.

    The matrix is assembled on the host with two vectorized indexed
    assignments and moved to ``device`` once, instead of issuing one scalar
    write per entry against a device tensor from a Python loop.

    Args:
        n: Side length of the square matrix.  Expected to be even; an odd
            value makes the pair index ``i + 1`` run past the last
            row/column and raises ``IndexError``.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(n, n)`` on the module-level
        ``device``.
    """
    matrix = torch.zeros(n, n, dtype=torch.bfloat16)
    even = torch.arange(0, n, 2)
    odd = even + 1
    matrix[even, odd] = 1
    matrix[odd, even] = -1
    return matrix.to(device)
def get_half_matrix(n):
    """Build the n x n matrix that swaps and negates the two feature halves.

    With ``h = n // 2`` the upper-right ``h x h`` block is the identity and
    the lower-left ``h x h`` block is the negated identity, so ``x @ M``
    equals ``cat(-x[h:], x[:h])``.

    Args:
        n: Side length of the square matrix.

    Returns:
        A ``torch.bfloat16`` tensor of shape ``(n, n)`` on the module-level
        ``device``.
    """
    split = n // 2
    identity = torch.eye(split)
    # Assembled on the host first, then transferred in a single copy.
    out = torch.zeros(n, n, dtype=torch.bfloat16)
    out[:split, split:] = identity
    out[split:, :split] = -identity
    return out.to(device)
def compose_3matrix(matrix_a, matrix_b, matrix_c):
    """Place three matrices along the diagonal of one larger zero matrix.

    The inputs are laid out top-left to bottom-right in argument order;
    everything outside the three diagonal blocks stays zero.

    Args:
        matrix_a: First block (its ``size(0)`` x ``size(1)`` is used).
        matrix_b: Second block.
        matrix_c: Third block.

    Returns:
        A ``torch.bfloat16`` tensor whose row/column counts are the sums of
        the inputs' row/column counts, moved to the module-level ``device``.
    """
    blocks = (matrix_a, matrix_b, matrix_c)
    n_rows = sum(blk.size(0) for blk in blocks)
    n_cols = sum(blk.size(1) for blk in blocks)
    composed = torch.zeros((n_rows, n_cols), dtype=torch.bfloat16)

    # Walk down the diagonal, advancing the offsets by each block's extent.
    row, col = 0, 0
    for blk in blocks:
        height, width = blk.size(0), blk.size(1)
        composed[row:row + height, col:col + width] = blk
        row += height
        col += width
    return composed.to(device)
def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` along the last dimension.

    The last dimension is split into two chunks with ``chunk(2)``; the
    second chunk is negated and placed in front of the first.
    """
    front, back = x.chunk(2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent pair of the last dimension.

    The last dimension is viewed as pairs ``(a, b)`` and every pair is
    replaced by ``(-b, a)``, giving ``(-x1, x0, -x3, x2, ...)`` with the
    original shape.
    """
    pairs = rearrange(x, '... (d j) -> ... d j', j=2)
    first, second = pairs.chunk(2, dim=-1)
    rotated = torch.cat((second.neg(), first), dim=-1)
    # Merge the (d, 2) pair axes back into a single feature axis.
    return rotated.flatten(-2)
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, mode: str) -> torch.Tensor:
    """Apply a rotary position embedding to ``tensor``.

    Computes ``tensor * cos + rotate(tensor) * sin`` where ``rotate`` is
    chosen by ``mode``.  ``sin`` and ``cos`` get two leading singleton
    dimensions so they broadcast against ``tensor``.

    Args:
        tensor: Input to rotate.
        sin: Sine table.
        cos: Cosine table.
        mode: ``'interleave'`` (rotate adjacent pairs) or ``'half'``
            (rotate the two halves of the last dimension).

    Raises:
        NotImplementedError: If ``mode`` is neither of the two values.
    """
    sin = sin[None, None]
    cos = cos[None, None]

    if mode == 'interleave':
        rotated = rotate_every_two(tensor)
    elif mode == 'half':
        rotated = rotate_half(tensor)
    else:
        raise NotImplementedError("mode error, only support half or interleave")
    return (tensor * cos) + (rotated * sin)
def apply_3drotary_pos_v3(q, k, freqs_cis, mat, high_precision=None):
    """Apply 3D rotary embedding using a single merged rotation matrix.

    This variant expresses the rotation as a matmul: the per-axis sin/cos
    tables are concatenated along the last dimension and the result is
    ``x * cos + (x @ mat) * sin`` for both ``q`` and ``k``.

    Args:
        q: Query tensor.
        k: Key tensor.
        freqs_cis: Three ``(sin, cos)`` pairs (named h, w, t in the code);
            the three sin tables, and likewise the three cos tables, are
            concatenated along the last dimension.
        mat: Rotation matrix right-multiplied onto ``q`` and ``k``.
        high_precision: When truthy, the multiply-add is carried out in
            float32 and each result is cast back to the dtype of its
            input.  Previously the result was always cast to bfloat16,
            regardless of the input dtype; bfloat16 inputs behave exactly
            as before.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """
    (sin_h, cos_h), (sin_w, cos_w), (sin_t, cos_t) = freqs_cis
    sin = torch.cat((sin_h, sin_w, sin_t), dim=-1)
    cos = torch.cat((cos_h, cos_w, cos_t), dim=-1)

    if not high_precision:
        return q * cos + (q @ mat) * sin, k * cos + (k @ mat) * sin

    # Upcast the tables once; they are shared by q and k.
    sin32, cos32 = sin.float(), cos.float()

    def _rotate(x):
        # The matmul runs in the input dtype; only the multiply-add is
        # done in float32, then the result returns to the input dtype.
        rotated = x.float() * cos32 + (x @ mat).float() * sin32
        return rotated.to(x.dtype)

    return _rotate(q), _rotate(k)
def apply_3drotary_pos_v5(q, k, freqs_cis, mat):
    """Apply 3D rotary embedding through the custom fused operator.

    This variant delegates the whole computation to
    ``torch.ops.npu_ops_transformer_ext.rope_matrix``.

    Args:
        q: Query tensor.
        k: Key tensor.
        freqs_cis: Three ``(sin, cos)`` pairs (named h, w, t in the code);
            the sin tables and the cos tables are each concatenated along
            the last dimension before being handed to the operator.
        mat: Rotation matrix forwarded to the operator.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` as produced by the operator.
    """
    # Imported locally; presumably this import is what registers the custom
    # op under torch.ops -- TODO confirm.
    import npu_ops_transformer_ext

    (sin_h, cos_h), (sin_w, cos_w), (sin_t, cos_t) = freqs_cis
    sin = torch.cat((sin_h, sin_w, sin_t), dim=-1)
    cos = torch.cat((cos_h, cos_w, cos_t), dim=-1)

    # The operator is called as (x, mat, sin, cos).
    # NOTE(review): verify this argument order against the op's schema.
    rope_matrix = torch.ops.npu_ops_transformer_ext.rope_matrix
    q = rope_matrix(q, mat, sin, cos)
    k = rope_matrix(k, mat, sin, cos)
    return q, k
class ROPE3D(nn.Module):
    """Parameter-free module wrapper around ``apply_3drotary_pos_v3``."""

    def __init__(self):
        super().__init__()

    def forward(self, q, k, freqs_cis, mat):
        """Rotate ``q`` and ``k`` with the matrix-based 3D RoPE (default precision)."""
        rotated_q, rotated_k = apply_3drotary_pos_v3(q, k, freqs_cis, mat)
        return rotated_q, rotated_k
def main():
    """Compare the fused-operator RoPE (v5) against the matmul reference (v3).

    Random bfloat16 ``q``/``k`` tensors and sin/cos tables are generated on
    the module-level ``device``, both implementations are run, and the
    results are required to match element for element.

    Changes relative to the earlier version of this check:
      * the error is ``|ref - out|`` rather than ``| |ref| - |out| |``, so a
        result that differs only in sign is no longer treated as a match;
      * the ``k`` outputs are verified as well as the ``q`` outputs;
      * unused locals (1D/2D shape lists, the interleave matrix) are gone;
      * the final check raises explicitly, so it also runs under ``-O``.

    Raises:
        AssertionError: If any element of the v5 output differs from v3.
    """
    b, n, s, d = 1, 24, 28800, 128
    # Per-axis feature widths of the three rotary tables; they sum to d.
    shape_lists_3d = [44, 44, 40]

    q = torch.randn((b, n, s, d), dtype=torch.bfloat16, device=device)
    k = torch.randn((b, n, s, d), dtype=torch.bfloat16, device=device)

    freqs_cis_3d = []
    for shape in shape_lists_3d:
        sincos = torch.randn((s, shape), dtype=torch.bfloat16, device=device)
        # NOTE(review): the same tensor is used as both sin and cos, so a
        # swapped sin/cos argument in either implementation would go
        # unnoticed here.
        freqs_cis_3d.append([sincos, sincos])

    half_mats = [get_half_matrix(size) for size in shape_lists_3d]
    half_mat_44_44_40 = compose_3matrix(*half_mats)

    with record_function("3d rope v5"):
        outq5, outk5 = apply_3drotary_pos_v5(q, k, freqs_cis_3d, half_mat_44_44_40)
        torch.cuda.synchronize()
    with record_function("3d rope v3"):
        outq3, outk3 = apply_3drotary_pos_v3(
            q, k, freqs_cis_3d, half_mat_44_44_40, high_precision=True)
        torch.cuda.synchronize()

    max_abs_error = 0.0
    mismatches = 0
    for ref, out in ((outq3, outq5), (outk3, outk5)):
        # Compare in float32 so the difference itself is not rounded.
        ref32, out32 = ref.float(), out.float()
        abs_error = (ref32 - out32).abs()
        denominator = torch.maximum(ref32.abs(), out32.abs())
        rel_error = abs_error / (denominator + 1e-8)
        max_abs_error = max(max_abs_error, abs_error.max().item())
        # An element counts as a mismatch when both its absolute and its
        # relative error are non-zero (zero tolerance, as before).
        mismatches += int((torch.minimum(abs_error, rel_error) > 0).sum().item())

    print(f'max error: {max_abs_error}')
    if mismatches:
        raise AssertionError(
            f"{mismatches} element(s) differ between v3 and v5 outputs")
# Script entry point: run the v3/v5 comparison only when executed directly.
if __name__ == '__main__':
    main()