Compressor
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列加速卡产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
API功能:Compressor是推理场景下SAS和QLI的前处理算子,用于将每4或128个token的KV cache压缩成一个,然后每个token与这些压缩的KV cache进行DSA计算。在长序列的情况下,Compressor可以有效地减少计算开销。
-
计算公式:
压缩阶段:
- 计算矩阵乘法:
- C4A: [kv_statea,score_statea]=X@[WaKV,WaGate],[kv_stateb,score_stateb]=X@[WbKV,WbGate];\left[kv\_state^a, score\_state^a\right] = X @ \left[W^{aKV}, W^{aGate}\right], \left[kv\_state^b, score\_state^b\right] = X @ \left[W^{bKV}, W^{bGate}\right];
- C128A: [kv_state,score_state]=X@[WKV,WGate]\left[kv\_state, score\_state\right] = X @ \left[W^{KV}, W^{Gate}\right]
- 计算分组加法:
- C4A: score_statei′=[score_state[4(i−1)+1:4i,:]a;score_state[4i+1:4(i+1),:]b]+Ape, i=1,2,⋯ ,s4;score\_state_i^\prime = \left[score\_state_{\left[4(i-1)+1:4i,:\right]}^a; score\_state_{\left[4i+1:4(i+1),:\right]}^b\right] + Ape,~i=1,2,\cdots, \frac{s}{4};
- C128A: score_statei′=score_state[128(i−1)+1:128i,:]+Ape, i=1,2,⋯ ,s128;score\_state_i^\prime = score\_state_{\left[128(i-1)+1:128i,:\right]} + Ape,~i=1,2,\cdots, \frac{s}{128};
- 计算分组Softmax:
- C4A: Si′=softmax(score_statei′), i=1,2,⋯ ,s4;S_i^\prime = softmax(score\_state_i^\prime),~i=1,2,\cdots, \frac{s}{4};
- C128A: Si′=softmax(score_statei′), i=1,2,⋯ ,s128;S_i^\prime = softmax(score\_state_i^\prime),~i=1,2,\cdots, \frac{s}{128};
- 计算Hadamard乘积:
- C4A: (SH)i=Si′⊙[kv_state[4(i−1)+1:4i,:]a;kv_state[4i+1:4(i+1),:]b], i=1,2,⋯ ,s4;(S_H)_i = S_i^\prime \odot \left[kv\_state^a_{\left[4(i-1)+1:4i,:\right]} ; kv\_state^b_{\left[4i+1:4(i+1),:\right]}\right],~i=1,2,\cdots, \frac{s}{4};
- C128A: SH=Si′⊙kv_state;S_H = S_i^\prime \odot kv\_state;
- 沿着压缩轴分组求和:
- C4A: CiComp=[1]1×8@(SH)i, i=1,2,⋯ ,s4;C_{i}^{\text{Comp}} = \left[1\right]_{1\times8} @ (S_H)_i, ~i=1,2,\cdots, \frac{s}{4};
- C128A: CiComp=[1]1×128@(SH)i, i=1,2,⋯ ,s128;C_{i}^{\text{Comp}} = \left[1\right]_{1\times128} @ (S_H)_i, ~i=1,2,\cdots, \frac{s}{128};
后处理阶段:
- 计算RMSNorm:
- RMS(CComp)=1N∑i=j∗N(j+1)∗N(CiComp)2+norm_eps,N=head_dim, j=1,2,⋯ ,scmp_ratio\text{RMS}(C^{\text{Comp}}) = \sqrt{\frac{1}{N} \sum_{i=j* N}^{(j+1)* N} {(C_{i}^{\text{Comp}})}^{\text{2}} + norm\_eps} ,N=head\_dim, ~j=1,2,\cdots, \frac{s}{cmp\_ratio}
- RmsNorm(CComp)=norm_weight⋅CiCompRMS(CComp)\text{RmsNorm}(C^{\text{Comp}}) = norm\_weight \cdot \frac{C_{i}^{\text{Comp}}}{\text{RMS}(C^{\text{Comp}})}
- 计算Rope;
- 计算矩阵乘法:
-
主要计算过程为:
- 将输入XX与WKVW^{KV}做Matmul运算得到kv_statekv\_state,将输入XX与WGateW^{Gate}做Matmul运算后再与ApeApe做Add运算得到score_statescore\_state,kv_statekv\_state与score_statescore\_state根据输入的start_pos及cu_seqlens完成更新。
- 在coff为2的情况下对kv_statekv\_state和score_statescore\_state进行数据重排。
- 对score_statescore\_state进行softmax运算将softmax结果与kv_statekv\_state做Mul计算,后进行ReduceSum运算。
- 根据输入数据norm_weight、rope_sin、rope_cos,进行RMSNorm和Rope运算,得到cmp_kvcmp\_kv结果输出。
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 公式中的XX,表示原始不经压缩的数据。 | FLOAT16、BFLOAT16 | ND |
| wkv | 输入 | 公式中的WKVW^{KV},表示kv压缩权重。 | FLOAT16、BFLOAT16 | ND |
| wgate | 输入 | 公式中的WGateW^{Gate},表示gate压缩权重。 | FLOAT16、BFLOAT16 | ND |
| state_cache | 输入 | 公式中的[kv_state,score_state]\left[kv\_state, score\_state\right], 表示kv_state和score_state的历史数据。 | FLOAT32 | ND |
| ape | 输入 | 公式中的ApeApe,表示positional biases。 | FLOAT32 | ND |
| norm_weight | 输入 | 表示计算RmsNorm时的权重系数。 | FLOAT32 | ND |
| rope_sin | 输入 | 表示Rope计算时sin的权重系数。 | FLOAT32 | ND |
| rope_cos | 输入 | 表示Rope计算时cos的权重系数。 | FLOAT32 | ND |
| rope_head_dim | 属性 | 表示rope_cos和rope_sin的hidden层最小单元大小,当前仅支持64。 | INT32 | - |
| cmp_ratio | 属性 | 用于稀疏计算,表示数据压缩率。 | INT32 | - |
| state_block_table | 可选输入 | 表示state_cache存储使用的block映射表。当其中元素的值为0时,表示当前位置无需进行更新state_cache操作。 | INT32 | ND |
| cu_seqlens | 可选输入 | 表示不同Batch中的有效token数。 | INT32 | ND |
| seqused | 可选输入 | 表示不同Batch中实际参与压缩的token数,如果指定为None时,表示和每个Batch上的Sequence Length长度相同。 | INT32 | ND |
| start_pos | 可选输入 | 表示计算起始位置。 | INT32 | ND |
| coff | 可选属性 | 默认值1,支持1/2。当coff=1时,无需进行overlap数据重排。当coff=2时,需要进行overlap数据重排。 | INT32 | - |
| norm_eps | 可选属性 | 表示RmsNorm计算的权重系数。默认值1e-6。 | FLOAT32 | - |
| rotary_mode | 可选属性 | 表示Rope计算的模式。默认值1,支持1/2。rotary_mode为1时,代表half模式。rotary_mode为2时,代表interleave模式。 | INT32 | - |
| cache_mode | 可选属性 | 表示state_cache的存储模式,1表示连续buffer,2表示循环buffer。默认值1。目前A3暂不支持输入2。 | INT32 | - |
| cmp_kv | 输出 | 表示压缩后的数据。 | FLOAT16、BFLOAT16 | ND |
约束说明
- x参数维度含义:B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、D(Head Dim)表示hidden层的最小单元大小、T表示所有Batch输入样本序列长度的累加和。
- 输入shape限制:
- wkv支持输入shape[coff* D,H]
- wgate支持输入shape[coff* D,H]
- state_cache支持输入shape[block_num,block_size,2* coff* D],要求block_num>0。
- ape支持输入shape[cmp_ratio,coff* D]
- norm_weight支持输入shape[D,]
- start_pos支持输入shape[B,]
- 若x的维度采用BS合轴,即x的输入shape为[T,H]
- rope_sin、rope_cos要求输入shape为[min(T,T//cmp_ratio+B),rope_head_dim]。
- cu_seqlens输入shape必须为[B+1,]。该参数中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值,且第一位必须位0。
- seqused,支持输入shape[B,],要求每个Batch的有效token数要求小于等于对应Sequence Length长度,即seqused[n] <= cu_seqlens[n+1] - cu_seqlens[n],且不小于0。
- state_block_table支持输入shape[B,ceil(Smax/block_size)]。Smax为每个Batch中最大的Sequence Length,即Smax=max(start_pos)+max(cu_seqlens[n+1] - cu_seqlens[n])。
- cmp_kv,输出shape为[min(T,T//cmp_ratio+B),D]:compressed_tokens + compressed_tokens + ... + compressed_tokens + pad。
- 若x的维度不采用BS合轴,即x的输入shape为[B,S,H]
- rope_sin、rope_cos要求输入shape为[B,ceil(S/cmp_ratio),rope_head_dim]。
- cu_seqlens,参数必须为空。
- seqused,支持输入shape[B,],要求每个Batch的有效token数要求小于等于对应Sequence Length长度,即要求seqused[n] <= S,且不小于0。
- state_block_table支持输入shape[B,ceil(Smax/block_size)]。Smax为每个Batch中最大的Sequence Length,即Smax=max(start_pos)+S。
- cmp_kv,输出shape为[B,ceil(S/cmp_ratio),D]:(compressed_tokens+pad0) + (compressed_tokens+pad1) + ... + (compressed_tokens+padN)。
- 输入值域限制:
- 该接口支持B、S泛化,且存在如下场景限制:
- 部分长序列场景下,如果计算量过大可能会导致出现超过NPU内存的报错,注:这里计算量会受x输入shape的影响,值越大计算量越大。典型的长序列(即B、S的乘积或T较大)场景包括但不限于:
B S H 100 65525 4096 25 261120 4096 100 131072 4096 100 261120 4096
- 该接口支持B、S泛化,且存在如下场景限制:
- 输入属性限制:
- 支持D为128/512。
- 支持H为1K~10K,512对齐。
- 支持cmp_ratio为4/128。支持如下三种情况:
- C4A: D=512, coff=2, cmp_ratio=4;
- C4Li: D=128, coff=2, cmp_ratio=4;
- C128A: D=512, coff=1, cmp_ratio=128。
- 支持rotary_mode为2,Rope计算模式为interleave。
- 该接口支持aclgraph模式。
Atlas A3 推理系列产品 调用说明
-
单算子模式调用
import torch import torch_npu import numpy as np import custom_ops import torch.nn as nn import math def get_seq_used_by_batch(batch_idx, S, seqused, cu_seqlens): if seqused is not None: return seqused[batch_idx] else: if cu_seqlens is not None: return cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: return S data_type = torch.bfloat16 hidden_size = 4096 rope_head_dim = 64 norm_eps = 1e-6 coff = 1 # 1:no overlap 2:overlap cmp_ratio = 128 rotary_mode = 2 cache_mode = 1 head_dim = 512 cu_seqlens = [0, 1] # ------------- B = 1 S = 1 S_max = 0 block_size = 128 start_pos = [8191] * B # (B,) start_p=8191 seqused = None # (B,), None时cu_seqlens的数据全部参与计算,否则按传参实际值计算 # BS是否合轴 bs_combine_flag = True update_flag = 1 save_state_seqlens = None if seqused is not None: seqused = torch.tensor(seqused).to(torch.int32) if start_pos is not None: start_pos = torch.tensor(start_pos).to(torch.int32) else: start_pos = torch.full((B,), start_p, dtype=torch.int32) if bs_combine_flag: if cu_seqlens is None: T = B * S if T !=0: cu_seqlens = torch.arange(0, T + 1, S, dtype=torch.int32) else: cu_seqlens = torch.zeros((B+1), dtype=torch.int32) else: cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32) for i in range(B): if start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] > S_max: S_max = start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] else: cu_seqlens = None S_max = max(start_pos) + S ### ======================== gen input data start ============================= # page state if cache_mode == 1: max_block_num_per_batch = (S_max + block_size - 1) // block_size block_num = B * max_block_num_per_batch next_block_id = 1 print(f"max_block_num_per_batch: {max_block_num_per_batch}") block_table = torch.zeros(size=(B, max_block_num_per_batch), dtype=torch.int32) for i in range(B): # 需要读取state的范围 cur_start = start_pos[i] // cmp_ratio * cmp_ratio - cmp_ratio cur_end = start_pos[i] // cmp_ratio * cmp_ratio + cmp_ratio if start_pos[i] % cmp_ratio == 0: cur_end = start_pos[i] cur_end = min(cur_end, start_pos[i] + S) cur_start_block_id = (cur_start // block_size) if cur_start >= 0 else 0 cur_end_block_id = (cur_end - 1) // block_size for j in range(cur_start_block_id, cur_end_block_id + 1): block_table[i][j] = next_block_id next_block_id = next_block_id + 1 # 需要写入state的范围 end_pos = get_seq_used_by_batch(i, S, seqused, cu_seqlens) if save_state_seqlens is not None: next_start = start_pos[i] + end_pos - save_state_seqlens[i] next_end = start_pos[i] + end_pos else: next_start = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio - cmp_ratio next_end = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio + cmp_ratio if (start_pos[i] + end_pos) % cmp_ratio == 0: next_end = start_pos[i] + end_pos next_end = min(next_end, start_pos[i] + end_pos) next_start_block_id = (next_start // block_size) if next_start >= 0 else 0 next_end_block_id = (next_end - 1) // block_size for j in range(next_start_block_id, next_end_block_id + 1): if block_table[i][j] == 0: block_table[i][j] = next_block_id next_block_id = next_block_id + 1 if B==0: kv_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32) else: kv_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32) else: block_table = torch.tensor(random.sample(list(range(B)), B), dtype=torch.int32) token_size = (2 * cmp_ratio + S - 1) if coff == 2 else (cmp_ratio + S - 1) if B==0: kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (0, token_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (0, token_size, coff * head_dim))).to(torch.float32) else: kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32) # other input if bs_combine_flag: x_shape = (cu_seqlens[-1], hidden_size) rope_sin_shape = (min(x_shape[0], x_shape[0] // cmp_ratio + B), rope_head_dim) rope_cos_shape = rope_sin_shape else: x_shape = (B, S, hidden_size) rope_sin_shape = (B, (S + cmp_ratio - 1) // cmp_ratio, rope_head_dim) rope_cos_shape = rope_sin_shape x = torch.tensor(np.random.uniform(-10.0, 10.0, x_shape)).to(data_type).npu() wkv = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu() wgate = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu() ape = torch.tensor(np.random.uniform(-10, 10, (cmp_ratio, coff * head_dim))).to(torch.float32).npu() norm_weight = torch.tensor(np.random.uniform(-10, 10, (head_dim))).to(torch.float32).npu() rope_sin = torch.tensor(np.random.uniform(-1, 1, rope_sin_shape)).to(torch.float32).npu() rope_cos = torch.tensor(np.random.uniform(-1, 1, rope_cos_shape)).to(torch.float32).npu() if cache_mode == 1: # 连续buffer state_cache = torch.zeros((kv_state.shape[0], kv_state.shape[1], 2*kv_state.shape[2])) state_cache = state_cache.npu() state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone() state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone() else: layer_pad = random.randint(1, 50) layer_start_idx = random.randint(0, layer_pad-1) print(f"layer_pad: {layer_pad}") print(f"layer_start_idx: {layer_start_idx}") state_cache_pad = torch.zeros((kv_state.shape[0],kv_state.shape[1]*kv_state.shape[2]*2+layer_pad)) print(f"state_cache_pad: shape {state_cache_pad.shape}") state_cache_pad = state_cache_pad.to("npu:%s" % DEVICE_ID) state_cache = state_cache_pad[:, layer_start_idx : layer_start_idx + kv_state.shape[1]*kv_state.shape[2]*2].view(-1, kv_state.shape[1], kv_state.shape[2]*2) state_cache = state_cache.to("npu:%s" % DEVICE_ID) state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone() state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone() print(f"state_cache: shape {state_cache.shape}, dtype: {state_cache.dtype}, is_contiguous: {state_cache.is_contiguous()}, stride0: {state_cache.stride(0)}") block_table = block_table.npu() start_pos = torch.tensor(start_pos).to(torch.int32).npu() if cu_seqlens is not None: cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32).npu() if seqused is not None: seqused = torch.tensor(seqused).to(torch.int32).npu() cmp_kv = ( torch.ops.custom.compressor( x, wkv, wgate, state_cache, ape, norm_weight, rope_sin, rope_cos, rope_head_dim = rope_head_dim, cmp_ratio = cmp_ratio, state_block_table = block_table, cu_seqlens = cu_seqlens, seqused = seqused, start_pos = start_pos, coff = coff, norm_eps = norm_eps, rotary_mode = rotary_mode, cache_mode = cache_mode ) ) -
aclgraph调用
import torch import torch_npu import numpy as np import torch.nn as nn import torchair import custom_ops import math def get_seq_used_by_batch(batch_idx, S, seqused, cu_seqlens): if seqused is not None: return seqused[batch_idx] else: if cu_seqlens is not None: return cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx] else: return S data_type = torch.bfloat16 hidden_size = 4096 rope_head_dim = 64 norm_eps = 1e-6 coff = 1 # 1:no overlap 2:overlap cmp_ratio = 128 rotary_mode = 2 cache_mode = 1 head_dim = 512 cu_seqlens = [0, 1] # ------------- B = 1 S = 1 S_max = 0 block_size = 128 start_pos = [8191] * B # (B,) start_p=8191 seqused = None # (B,), None时cu_seqlens的数据全部参与计算,否则按传参实际值计算 # BS是否合轴 bs_combine_flag = True update_flag = 1 save_state_seqlens = None if seqused is not None: seqused = torch.tensor(seqused).to(torch.int32) if start_pos is not None: start_pos = torch.tensor(start_pos).to(torch.int32) else: start_pos = torch.full((B,), start_p, dtype=torch.int32) if bs_combine_flag: if cu_seqlens is None: T = B * S if T !=0: cu_seqlens = torch.arange(0, T + 1, S, dtype=torch.int32) else: cu_seqlens = torch.zeros((B+1), dtype=torch.int32) else: cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32) for i in range(B): if start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] > S_max: S_max = start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] else: cu_seqlens = None S_max = max(start_pos) + S ### ======================== gen input data start ============================= # page state if cache_mode == 1: max_block_num_per_batch = (S_max + block_size - 1) // block_size block_num = B * max_block_num_per_batch next_block_id = 1 print(f"max_block_num_per_batch: {max_block_num_per_batch}") block_table = torch.zeros(size=(B, max_block_num_per_batch), dtype=torch.int32) for i in range(B): # 需要读取state的范围 cur_start = start_pos[i] // cmp_ratio * cmp_ratio - cmp_ratio cur_end = start_pos[i] // cmp_ratio * cmp_ratio + cmp_ratio if start_pos[i] % cmp_ratio == 0: cur_end = start_pos[i] cur_end = min(cur_end, start_pos[i] + S) cur_start_block_id = (cur_start // block_size) if cur_start >= 0 else 0 cur_end_block_id = (cur_end - 1) // block_size for j in range(cur_start_block_id, cur_end_block_id + 1): block_table[i][j] = next_block_id next_block_id = next_block_id + 1 # 需要写入state的范围 end_pos = get_seq_used_by_batch(i, S, seqused, cu_seqlens) if save_state_seqlens is not None: next_start = start_pos[i] + end_pos - save_state_seqlens[i] next_end = start_pos[i] + end_pos else: next_start = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio - cmp_ratio next_end = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio + cmp_ratio if (start_pos[i] + end_pos) % cmp_ratio == 0: next_end = start_pos[i] + end_pos next_end = min(next_end, start_pos[i] + end_pos) next_start_block_id = (next_start // block_size) if next_start >= 0 else 0 next_end_block_id = (next_end - 1) // block_size for j in range(next_start_block_id, next_end_block_id + 1): if block_table[i][j] == 0: block_table[i][j] = next_block_id next_block_id = next_block_id + 1 if B==0: kv_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32) else: kv_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32) else: block_table = torch.tensor(random.sample(list(range(B)), B), dtype=torch.int32) token_size = (2 * cmp_ratio + S - 1) if coff == 2 else (cmp_ratio + S - 1) if B==0: kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (0, token_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (0, token_size, coff * head_dim))).to(torch.float32) else: kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32) score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32) # other input if bs_combine_flag: x_shape = (cu_seqlens[-1], hidden_size) rope_sin_shape = (min(x_shape[0], x_shape[0] // cmp_ratio + B), rope_head_dim) rope_cos_shape = rope_sin_shape else: x_shape = (B, S, hidden_size) rope_sin_shape = (B, (S + cmp_ratio - 1) // cmp_ratio, rope_head_dim) rope_cos_shape = rope_sin_shape x = torch.tensor(np.random.uniform(-10.0, 10.0, x_shape)).to(data_type).npu() wkv = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu() wgate = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu() ape = torch.tensor(np.random.uniform(-10, 10, (cmp_ratio, coff * head_dim))).to(torch.float32).npu() norm_weight = torch.tensor(np.random.uniform(-10, 10, (head_dim))).to(torch.float32).npu() rope_sin = torch.tensor(np.random.uniform(-1, 1, rope_sin_shape)).to(torch.float32).npu() rope_cos = torch.tensor(np.random.uniform(-1, 1, rope_cos_shape)).to(torch.float32).npu() if cache_mode == 1: # 连续buffer state_cache = torch.zeros((kv_state.shape[0], kv_state.shape[1], 2*kv_state.shape[2])) state_cache = state_cache.npu() state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone() state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone() else: layer_pad = random.randint(1, 50) layer_start_idx = random.randint(0, layer_pad-1) print(f"layer_pad: {layer_pad}") print(f"layer_start_idx: {layer_start_idx}") state_cache_pad = torch.zeros((kv_state.shape[0],kv_state.shape[1]*kv_state.shape[2]*2+layer_pad)) print(f"state_cache_pad: shape {state_cache_pad.shape}") state_cache_pad = state_cache_pad.to("npu:%s" % DEVICE_ID) state_cache = state_cache_pad[:, layer_start_idx : layer_start_idx + kv_state.shape[1]*kv_state.shape[2]*2].view(-1, kv_state.shape[1], kv_state.shape[2]*2) state_cache = state_cache.to("npu:%s" % DEVICE_ID) state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone() state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone() print(f"state_cache: shape {state_cache.shape}, dtype: {state_cache.dtype}, is_contiguous: {state_cache.is_contiguous()}, stride0: {state_cache.stride(0)}") block_table = block_table.npu() start_pos = torch.tensor(start_pos).to(torch.int32).npu() if cu_seqlens is not None: cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32).npu() if seqused is not None: seqused = torch.tensor(seqused).to(torch.int32).npu() class CompressorNetwork(nn.Module): def __init__(self): super(CompressorNetwork, self).__init__() def forward(self, x, wkv, wgate, state_cache, ape, norm_weight, rope_sin, rope_cos, rope_head_dim, cmp_ratio, state_block_table = None, cu_seqlens = None, seqused = None, start_pos = None, coff = 1, norm_eps = 1e-6, rotary_mode = 1, cache_mode = 1): cmp_kv = ( torch.ops.custom.compressor( x, wkv, wgate, state_cache, ape, norm_weight, rope_sin, rope_cos, rope_head_dim = rope_head_dim, cmp_ratio = cmp_ratio, state_block_table = state_block_table, cu_seqlens = cu_seqlens, seqused = seqused, start_pos = start_pos, coff = coff, norm_eps = norm_eps, rotary_mode = rotary_mode, cache_mode = cache_mode ) ) return cmp_kv from torchair.configs.compiler_config import CompilerConfig config = CompilerConfig() config.mode = "reduce-overhead" npu_backend = torchair.get_npu_backend(compiler_config=config) torch._dynamo.reset() npu_mode = torch.compile(CompressorNetwork(), fullgraph=True, backend=npu_backend, dynamic=False) cmp_kv = npu_mode( x, wkv, wgate, state_cache, ape, norm_weight, rope_sin, rope_cos, rope_head_dim = rope_head_dim, cmp_ratio = cmp_ratio, state_block_table = block_table, cu_seqlens = cu_seqlens, seqused = seqused, start_pos = start_pos, coff = coff, norm_eps = norm_eps, rotary_mode = rotary_mode, cache_mode = cache_mode)
更多使用示例见pytest示例。