
Compressor

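Product support

| Product | Supported |
| --- | --- |
| Ascend 950PR / Ascend 950DT | √ |
| Atlas A3 training series / Atlas A3 inference series | √ |
| Atlas A2 training series / Atlas A2 inference series | × |
| Atlas 200I/500 A2 inference products | × |
| Atlas inference series accelerator cards | × |
| Atlas training series | × |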

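Function description

  • API function: Compressor is a preprocessing operator for SAS and QLI in inference scenarios. It compresses the KV cache of every 4 or 128 tokens into a single entry, and each token then performs the DSA computation against these compressed KV cache entries. For long sequences, Compressor effectively reduces the computation cost.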

  • Formulas:

    Compression stage:

    1. Matrix multiplication:
      • C4A: $\left[kv\_state^a, score\_state^a\right] = X @ \left[W^{aKV}, W^{aGate}\right],\quad \left[kv\_state^b, score\_state^b\right] = X @ \left[W^{bKV}, W^{bGate}\right]$
      • C128A: $\left[kv\_state, score\_state\right] = X @ \left[W^{KV}, W^{Gate}\right]$
    2. Grouped addition:
      • C4A: $score\_state_i^\prime = \left[score\_state_{\left[4(i-1)+1:4i,:\right]}^a; score\_state_{\left[4i+1:4(i+1),:\right]}^b\right] + Ape,~i=1,2,\cdots,\frac{s}{4}$
      • C128A: $score\_state_i^\prime = score\_state_{\left[128(i-1)+1:128i,:\right]} + Ape,~i=1,2,\cdots,\frac{s}{128}$
    3. Grouped softmax:
      • C4A: $S_i^\prime = softmax(score\_state_i^\prime),~i=1,2,\cdots,\frac{s}{4}$
      • C128A: $S_i^\prime = softmax(score\_state_i^\prime),~i=1,2,\cdots,\frac{s}{128}$
    4. Hadamard product:
      • C4A: $(S_H)_i = S_i^\prime \odot \left[kv\_state^a_{\left[4(i-1)+1:4i,:\right]}; kv\_state^b_{\left[4i+1:4(i+1),:\right]}\right],~i=1,2,\cdots,\frac{s}{4}$
      • C128A: $(S_H)_i = S_i^\prime \odot kv\_state_{\left[128(i-1)+1:128i,:\right]},~i=1,2,\cdots,\frac{s}{128}$
    5. Grouped sum along the compression axis:
      • C4A: $C_{i}^{\text{Comp}} = \left[1\right]_{1\times8} @ (S_H)_i,~i=1,2,\cdots,\frac{s}{4}$
      • C128A: $C_{i}^{\text{Comp}} = \left[1\right]_{1\times128} @ (S_H)_i,~i=1,2,\cdots,\frac{s}{128}$

    Post-processing stage:

    1. RMSNorm:
      • $\text{RMS}(C^{\text{Comp}}) = \sqrt{\frac{1}{N} \sum_{i=j*N}^{(j+1)*N} {(C_{i}^{\text{Comp}})}^{2} + norm\_eps},~N=head\_dim,~j=1,2,\cdots,\frac{s}{cmp\_ratio}$
      • $\text{RmsNorm}(C^{\text{Comp}}) = norm\_weight \cdot \frac{C_{i}^{\text{Comp}}}{\text{RMS}(C^{\text{Comp}})}$
    2. Rope.
  • The main computation flow is:

    1. Multiply the input $X$ by $W^{KV}$ (Matmul) to get $kv\_state$. Multiply the input $X$ by $W^{Gate}$ (Matmul) and then add $Ape$ (Add) to get $score\_state$. $kv\_state$ and $score\_state$ are updated according to the inputs start_pos and cu_seqlens.
    2. When coff is 2, rearrange the data of $kv\_state$ and $score\_state$.
    3. Apply softmax to $score\_state$, multiply the softmax result by $kv\_state$ (Mul), and then apply ReduceSum.
    4. Using the inputs norm_weight, rope_sin and rope_cos, apply RMSNorm and Rope to produce the output $cmp\_kv$.
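
The C128A path of the formulas above can be sketched in plain NumPy. This is only an illustrative reference, not the operator's implementation. It assumes coff=1, no history in state_cache, a sequence length that is a multiple of cmp_ratio, weights stored as [D, H] (so the product is `x @ w.T`), and that the grouped softmax runs over the compression axis (the 128 tokens of a group). Rope is omitted. The function name `compress_c128a` is hypothetical.

```python
import numpy as np


def compress_c128a(x, w_kv, w_gate, ape, norm_weight, norm_eps=1e-6, cmp_ratio=128):
    """Illustrative C128A reference (coff=1), without state_cache or Rope.

    x: [s, H]; w_kv, w_gate: [D, H]; ape: [cmp_ratio, D]; norm_weight: [D].
    Assumes s is a multiple of cmp_ratio.
    """
    s = x.shape[0]
    d = w_kv.shape[0]
    groups = s // cmp_ratio

    # Step 1: matrix multiplications, then split the tokens into groups.
    kv_state = (x @ w_kv.T).reshape(groups, cmp_ratio, d)
    score_state = (x @ w_gate.T).reshape(groups, cmp_ratio, d)

    # Step 2: grouped addition of the positional bias Ape.
    score = score_state + ape[None, :, :]

    # Step 3: grouped softmax, assumed to run over the compression axis.
    score = score - score.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(score)
    weights /= weights.sum(axis=1, keepdims=True)

    # Steps 4 and 5: Hadamard product, then sum along the compression axis.
    compressed = (weights * kv_state).sum(axis=1)  # [groups, D]

    # Post-processing: RMSNorm over head_dim.
    rms = np.sqrt((compressed ** 2).mean(axis=-1, keepdims=True) + norm_eps)
    return norm_weight * compressed / rms


rng = np.random.default_rng(0)
demo = compress_c128a(
    rng.standard_normal((256, 16)),   # x: s=256, H=16
    rng.standard_normal((8, 16)),     # w_kv: D=8
    rng.standard_normal((8, 16)),     # w_gate
    rng.standard_normal((128, 8)),    # ape
    np.ones(8),                       # norm_weight
)
print(demo.shape)  # one compressed row per group of 128 tokens
```

With a zero gate weight and zero Ape the softmax weights are uniform, so each compressed row reduces to the RMS-normalized mean of its group, which is a convenient sanity check.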

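Parameters

| Parameter | Input/Output/Attribute | Description | Data type | Data format |
| --- | --- | --- | --- | --- |
| x | Input | $X$ in the formulas; the original, uncompressed data. | FLOAT16, BFLOAT16 | ND |
| wkv | Input | $W^{KV}$ in the formulas; the KV compression weight. | FLOAT16, BFLOAT16 | ND |
| wgate | Input | $W^{Gate}$ in the formulas; the gate compression weight. | FLOAT16, BFLOAT16 | ND |
| state_cache | Input | $\left[kv\_state, score\_state\right]$ in the formulas; the historical kv_state and score_state data. | FLOAT32 | ND |
| ape | Input | $Ape$ in the formulas; the positional biases. | FLOAT32 | ND |
| norm_weight | Input | Weight coefficients used in the RmsNorm computation. | FLOAT32 | ND |
| rope_sin | Input | Sin coefficients used in the Rope computation. | FLOAT32 | ND |
| rope_cos | Input | Cos coefficients used in the Rope computation. | FLOAT32 | ND |
| rope_head_dim | Attribute | Size of the smallest hidden-layer unit of rope_cos and rope_sin. Only 64 is currently supported. | INT32 | - |
| cmp_ratio | Attribute | Data compression ratio, used for sparse computation. | INT32 | - |
| state_block_table | Optional input | Block mapping table used for storing state_cache. An element value of 0 means that state_cache does not need to be updated at that position. | INT32 | ND |
| cu_seqlens | Optional input | Number of valid tokens in each batch. | INT32 | ND |
| seqused | Optional input | Number of tokens in each batch that actually take part in compression. If None, it equals the Sequence Length of each batch. | INT32 | ND |
| start_pos | Optional input | Start position of the computation. | INT32 | ND |
| coff | Optional attribute | Default 1; 1 and 2 are supported. With coff=1 no overlap data rearrangement is needed; with coff=2 overlap data rearrangement is required. | INT32 | - |
| norm_eps | Optional attribute | The epsilon term added under the square root in the RmsNorm computation (norm_eps in the formulas). Default 1e-6. | FLOAT32 | - |
| rotary_mode | Optional attribute | Rope computation mode. Default 1; 1 and 2 are supported. 1 means half mode, 2 means interleave mode. | INT32 | - |
| cache_mode | Optional attribute | Storage mode of state_cache: 1 is a contiguous buffer, 2 is a circular buffer. Default 1. A3 currently does not support the value 2. | INT32 | - |
| cmp_kv | Output | The compressed data. | FLOAT16, BFLOAT16 | ND |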
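
The two rotary_mode values differ in which elements are paired for rotation. The sketch below shows the commonly used conventions: half mode pairs element i with element i + d/2, and interleave mode pairs elements 2i and 2i+1. It is an assumption that the operator follows exactly these conventions. The layout of the cos/sin tables, and which slice of head_dim Rope is applied to, are not stated in this document. The function names are hypothetical.

```python
import numpy as np


def rope_half(v, cos, sin):
    """Half mode: element i is rotated together with element i + d/2."""
    d = v.shape[-1]
    v1, v2 = v[..., : d // 2], v[..., d // 2:]
    rotated = np.concatenate([-v2, v1], axis=-1)
    return v * cos + rotated * sin


def rope_interleave(v, cos, sin):
    """Interleave mode: elements 2i and 2i+1 are rotated together."""
    rotated = np.empty_like(v)
    rotated[..., 0::2] = -v[..., 1::2]
    rotated[..., 1::2] = v[..., 0::2]
    return v * cos + rotated * sin


# One angle per rotated pair; the tables are expanded to full width d.
theta = np.array([np.pi / 2, 0.0])
v = np.array([1.0, 0.0, 0.0, 0.0])

# Half mode expects [cos(t0), cos(t1), cos(t0), cos(t1)].
half_out = rope_half(v, np.tile(np.cos(theta), 2), np.tile(np.sin(theta), 2))
# Interleave mode expects [cos(t0), cos(t0), cos(t1), cos(t1)].
inter_out = rope_interleave(v, np.repeat(np.cos(theta), 2), np.repeat(np.sin(theta), 2))
```

In this toy input the unit value at index 0 is rotated by 90 degrees into index 2 in half mode and into index 1 in interleave mode, which makes the different pairing visible.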

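Constraints

  • Dimension notation for x: B (Batch Size) is the batch size of the input samples, S (Sequence Length) is the sequence length of the input samples, H (Head Size) is the hidden-layer size, D (Head Dim) is the size of the smallest hidden-layer unit, and T is the sum of the input sequence lengths over all batches.
  • Input shape limits:
    • wkv: shape [coff*D, H].
    • wgate: shape [coff*D, H].
    • state_cache: shape [block_num, block_size, 2*coff*D], with block_num > 0.
    • ape: shape [cmp_ratio, coff*D].
    • norm_weight: shape [D,].
    • start_pos: shape [B,].
    • If x merges the B and S axes, i.e. the shape of x is [T, H]:
      • rope_sin and rope_cos must have shape [min(T, T//cmp_ratio+B), rope_head_dim].
      • cu_seqlens must have shape [B+1,]. Each element is the total number of tokens of the current batch and all previous batches, i.e. a prefix sum, so each element must be greater than or equal to the previous one, and the first element must be 0.
      • seqused: shape [B,]. The number of valid tokens of each batch must be between 0 and the corresponding Sequence Length, i.e. 0 <= seqused[n] <= cu_seqlens[n+1] - cu_seqlens[n].
      • state_block_table: shape [B, ceil(Smax/block_size)]. Smax is the largest Sequence Length over all batches, i.e. Smax = max(start_pos) + max(cu_seqlens[n+1] - cu_seqlens[n]).
      • cmp_kv: output shape [min(T, T//cmp_ratio+B), D], laid out as compressed_tokens + compressed_tokens + ... + compressed_tokens + pad.
    • If x does not merge the B and S axes, i.e. the shape of x is [B, S, H]:
      • rope_sin and rope_cos must have shape [B, ceil(S/cmp_ratio), rope_head_dim].
      • cu_seqlens must be empty.
      • seqused: shape [B,]. The number of valid tokens of each batch must be between 0 and the Sequence Length, i.e. 0 <= seqused[n] <= S.
      • state_block_table: shape [B, ceil(Smax/block_size)]. Smax is the largest Sequence Length over all batches, i.e. Smax = max(start_pos) + S.
      • cmp_kv: output shape [B, ceil(S/cmp_ratio), D], laid out as (compressed_tokens+pad0) + (compressed_tokens+pad1) + ... + (compressed_tokens+padN).
  • Input value range limits:
    • The interface generalizes over B and S, with the following scenario limits:
      • In some long-sequence scenarios, an excessive computation load may cause an error reporting that NPU memory has been exceeded. The computation load depends on the shape of x: the larger the values, the larger the load. Typical long-sequence scenarios (a large product of B and S, or a large T) include but are not limited to:

      | B | S | H |
      | --- | --- | --- |
      | 100 | 65525 | 4096 |
      | 25 | 261120 | 4096 |
      | 100 | 131072 | 4096 |
      | 100 | 261120 | 4096 |
  • Input attribute limits:
    • Supported D: 128 or 512.
    • Supported H: 1K to 10K, aligned to 512.
    • Supported cmp_ratio: 4 or 128, in the following three combinations:
      • C4A: D=512, coff=2, cmp_ratio=4;
      • C4Li: D=128, coff=2, cmp_ratio=4;
      • C128A: D=512, coff=1, cmp_ratio=128.
    • Supported rotary_mode: 2, i.e. the interleave Rope computation mode.
    • The interface supports aclgraph mode.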
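
The shape and attribute rules above can be checked on the host before calling the operator. The following is a minimal sketch that only restates the constraints listed in this section. The helper names (`expected_shapes`, `check_packed_lengths`, `SUPPORTED_CONFIGS`) are hypothetical and are not part of the operator's API.

```python
import math

# (D, coff, cmp_ratio) combinations listed under the attribute limits.
SUPPORTED_CONFIGS = {
    "C4A": (512, 2, 4),
    "C4Li": (128, 2, 4),
    "C128A": (512, 1, 128),
}


def expected_shapes(x_shape, B, D, cmp_ratio, rope_head_dim=64):
    """Return (rope_sin/rope_cos shape, cmp_kv shape) for a given x shape."""
    if len(x_shape) == 2:  # merged B and S axes: x is [T, H]
        T = x_shape[0]
        rows = min(T, T // cmp_ratio + B)
        return (rows, rope_head_dim), (rows, D)
    if len(x_shape) == 3:  # separate axes: x is [B, S, H]
        batch, S, _ = x_shape
        rows = math.ceil(S / cmp_ratio)
        return (batch, rows, rope_head_dim), (batch, rows, D)
    raise ValueError("x must have shape [T, H] or [B, S, H]")


def check_packed_lengths(cu_seqlens, B, seqused=None):
    """Validate cu_seqlens (and optionally seqused) for the [T, H] layout."""
    if len(cu_seqlens) != B + 1:
        raise ValueError("cu_seqlens must have shape [B+1,]")
    if cu_seqlens[0] != 0:
        raise ValueError("the first element of cu_seqlens must be 0")
    for n in range(B):
        length = cu_seqlens[n + 1] - cu_seqlens[n]
        if length < 0:
            raise ValueError("cu_seqlens must be non-decreasing")
        if seqused is not None and not 0 <= seqused[n] <= length:
            raise ValueError("seqused[n] must lie in [0, sequence length]")


# Values taken from the invocation example in this document: T=1, B=1, C128A.
D, coff, cmp_ratio = SUPPORTED_CONFIGS["C128A"]
rope_shape, cmp_kv_shape = expected_shapes((1, 4096), B=1, D=D, cmp_ratio=cmp_ratio)
check_packed_lengths([0, 1], B=1)
```

For the example values this yields a rope shape of (1, 64) and a cmp_kv shape of (1, 512), since min(1, 1//128 + 1) = 1.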

Invocation on Atlas A3 inference series products

  • Single-operator mode

    import torch
    import torch_npu
    import numpy as np
    import custom_ops
    import torch.nn as nn
    import math
    import random  # used by the cache_mode == 2 branches below
    
    def get_seq_used_by_batch(batch_idx, S, seqused, cu_seqlens):
        if seqused is not None:
            return seqused[batch_idx]
        else:
            if cu_seqlens is not None:
                return cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
            else:
                return S
    
    data_type = torch.bfloat16
    hidden_size = 4096
    rope_head_dim = 64
    norm_eps = 1e-6
    coff = 1 # 1:no overlap 2:overlap
    cmp_ratio = 128
    rotary_mode = 2
    cache_mode = 1
    head_dim = 512
    cu_seqlens = [0, 1]
    # ------------- 
    B = 1
    S = 1
    S_max = 0
    block_size = 128
    start_pos = [8191] * B # (B,)
    start_p=8191
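    seqused = None  # (B,); when None, all tokens described by cu_seqlens take part in the computation, otherwise the passed values are used
    
    # Whether the B and S axes of x are merged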
    bs_combine_flag = True
    update_flag = 1
    save_state_seqlens = None
    if seqused is not None:
        seqused = torch.tensor(seqused).to(torch.int32)
    if start_pos is not None:
        start_pos = torch.tensor(start_pos).to(torch.int32)
    else:
        start_pos = torch.full((B,), start_p, dtype=torch.int32)
    
    if bs_combine_flag:
        if cu_seqlens is None:
            T = B * S
            if T !=0:
                cu_seqlens = torch.arange(0, T + 1, S, dtype=torch.int32)
            else:
                cu_seqlens = torch.zeros((B+1), dtype=torch.int32)
        else:
            cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32)
        for i in range(B):
            if start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] > S_max:
                S_max = start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] 
    else:
        cu_seqlens = None
        S_max = max(start_pos) + S
    ### ======================== gen input data start =============================
    # page state
    if cache_mode == 1:
        max_block_num_per_batch = (S_max + block_size - 1) // block_size
        block_num = B * max_block_num_per_batch
        next_block_id = 1
        print(f"max_block_num_per_batch: {max_block_num_per_batch}")
        block_table = torch.zeros(size=(B, max_block_num_per_batch), dtype=torch.int32)
        for i in range(B):
            # Range of state that must be read
            cur_start = start_pos[i] // cmp_ratio * cmp_ratio - cmp_ratio
            cur_end = start_pos[i] // cmp_ratio * cmp_ratio + cmp_ratio
            if start_pos[i] % cmp_ratio == 0:
                cur_end = start_pos[i]
            cur_end = min(cur_end, start_pos[i] + S)
            cur_start_block_id = (cur_start // block_size) if cur_start >= 0 else 0
            cur_end_block_id = (cur_end - 1) // block_size
            for j in range(cur_start_block_id, cur_end_block_id + 1):
                block_table[i][j] = next_block_id
                next_block_id = next_block_id + 1
            # Range of state that must be written
            end_pos = get_seq_used_by_batch(i, S, seqused, cu_seqlens)
            if save_state_seqlens is not None:
                next_start = start_pos[i] + end_pos - save_state_seqlens[i]
                next_end = start_pos[i] + end_pos
            else:
                next_start = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio - cmp_ratio
                next_end = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio + cmp_ratio
                if (start_pos[i] + end_pos) % cmp_ratio == 0:
                    next_end = start_pos[i] + end_pos
            next_end = min(next_end, start_pos[i] + end_pos)
            next_start_block_id = (next_start // block_size) if next_start >= 0 else 0
            next_end_block_id = (next_end - 1) // block_size
            for j in range(next_start_block_id, next_end_block_id + 1):
                if block_table[i][j] == 0:
                    block_table[i][j] = next_block_id
                    next_block_id = next_block_id + 1
    
        if B==0:
            kv_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32)
            score_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32)
        else:
            kv_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32)
            score_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32)
    else:
        block_table = torch.tensor(random.sample(list(range(B)), B), dtype=torch.int32)
        token_size = (2 * cmp_ratio + S - 1) if coff == 2 else (cmp_ratio + S - 1)
        # The value ranges are not defined in the original example; they are
        # assumed here to match the cache_mode == 1 branch.
        kv_state_datarange = (-10, 10)
        score_state_datarange = (-10, 10)
        # With B == 0 the shapes below have a leading 0, giving empty tensors.
        kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32)
        score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32)
    
    # other input
    if bs_combine_flag:
        x_shape = (cu_seqlens[-1], hidden_size)
        rope_sin_shape = (min(x_shape[0], x_shape[0] // cmp_ratio + B), rope_head_dim)
        rope_cos_shape = rope_sin_shape
    else:
        x_shape = (B, S, hidden_size)
        rope_sin_shape = (B, (S + cmp_ratio - 1) // cmp_ratio, rope_head_dim)
        rope_cos_shape = rope_sin_shape
    
    x = torch.tensor(np.random.uniform(-10.0, 10.0, x_shape)).to(data_type).npu()
    wkv = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu()
    wgate = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu()
    ape = torch.tensor(np.random.uniform(-10, 10, (cmp_ratio, coff * head_dim))).to(torch.float32).npu()
    norm_weight = torch.tensor(np.random.uniform(-10, 10, (head_dim))).to(torch.float32).npu()
    rope_sin = torch.tensor(np.random.uniform(-1, 1, rope_sin_shape)).to(torch.float32).npu()
    rope_cos = torch.tensor(np.random.uniform(-1, 1, rope_cos_shape)).to(torch.float32).npu()
    if cache_mode == 1:  # contiguous buffer
        state_cache = torch.zeros((kv_state.shape[0], kv_state.shape[1], 2*kv_state.shape[2]))
        state_cache = state_cache.npu()
        state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone()
        state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone()
    else:
        layer_pad = random.randint(1, 50)
        layer_start_idx = random.randint(0, layer_pad-1)
        print(f"layer_pad: {layer_pad}")
        print(f"layer_start_idx: {layer_start_idx}")
        state_cache_pad = torch.zeros((kv_state.shape[0],kv_state.shape[1]*kv_state.shape[2]*2+layer_pad))
        print(f"state_cache_pad: shape {state_cache_pad.shape}")
        # DEVICE_ID is not defined in this example; use the current NPU device instead.
        state_cache_pad = state_cache_pad.npu()
        state_cache = state_cache_pad[:, layer_start_idx : layer_start_idx + kv_state.shape[1]*kv_state.shape[2]*2].view(-1, kv_state.shape[1], kv_state.shape[2]*2)
        state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone()
        state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone()
        print(f"state_cache: shape {state_cache.shape}, dtype: {state_cache.dtype}, is_contiguous: {state_cache.is_contiguous()}, stride0: {state_cache.stride(0)}")
    
    block_table = block_table.npu()
    # start_pos, cu_seqlens and seqused are already int32 tensors here; only move them to the NPU.
    start_pos = start_pos.npu()
    if cu_seqlens is not None:
        cu_seqlens = cu_seqlens.npu()
    if seqused is not None:
        seqused = seqused.npu()
    
    cmp_kv = (
        torch.ops.custom.compressor(
            x,
            wkv,
            wgate,
            state_cache,
            ape,
            norm_weight, 
            rope_sin,
            rope_cos,
            rope_head_dim = rope_head_dim,
            cmp_ratio = cmp_ratio,
            state_block_table = block_table,
            cu_seqlens = cu_seqlens,
            seqused = seqused,
            start_pos = start_pos,
            coff = coff,
            norm_eps = norm_eps,
            rotary_mode = rotary_mode,
            cache_mode = cache_mode
        )
    )
    
  • aclgraph mode

    import torch
    import torch_npu
    import numpy as np
    import torch.nn as nn
    import torchair
    import custom_ops
    import math
    import random  # used by the cache_mode == 2 branches below
    
    def get_seq_used_by_batch(batch_idx, S, seqused, cu_seqlens):
        if seqused is not None:
            return seqused[batch_idx]
        else:
            if cu_seqlens is not None:
                return cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]
            else:
                return S
    
    data_type = torch.bfloat16
    hidden_size = 4096
    rope_head_dim = 64
    norm_eps = 1e-6
    coff = 1 # 1:no overlap 2:overlap
    cmp_ratio = 128
    rotary_mode = 2
    cache_mode = 1
    head_dim = 512
    cu_seqlens = [0, 1]
    # ------------- 
    B = 1
    S = 1
    S_max = 0
    block_size = 128
    start_pos = [8191] * B # (B,)
    start_p=8191
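    seqused = None  # (B,); when None, all tokens described by cu_seqlens take part in the computation, otherwise the passed values are used
    
    # Whether the B and S axes of x are merged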
    bs_combine_flag = True
    update_flag = 1
    save_state_seqlens = None
    if seqused is not None:
        seqused = torch.tensor(seqused).to(torch.int32)
    if start_pos is not None:
        start_pos = torch.tensor(start_pos).to(torch.int32)
    else:
        start_pos = torch.full((B,), start_p, dtype=torch.int32)
    
    if bs_combine_flag:
        if cu_seqlens is None:
            T = B * S
            if T !=0:
                cu_seqlens = torch.arange(0, T + 1, S, dtype=torch.int32)
            else:
                cu_seqlens = torch.zeros((B+1), dtype=torch.int32)
        else:
            cu_seqlens = torch.tensor(cu_seqlens).to(torch.int32)
        for i in range(B):
            if start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] > S_max:
                S_max = start_pos[i] + cu_seqlens[i + 1] - cu_seqlens[i] 
    else:
        cu_seqlens = None
        S_max = max(start_pos) + S
    ### ======================== gen input data start =============================
    # page state
    if cache_mode == 1:
        max_block_num_per_batch = (S_max + block_size - 1) // block_size
        block_num = B * max_block_num_per_batch
        next_block_id = 1
        print(f"max_block_num_per_batch: {max_block_num_per_batch}")
        block_table = torch.zeros(size=(B, max_block_num_per_batch), dtype=torch.int32)
        for i in range(B):
            # Range of state that must be read
            cur_start = start_pos[i] // cmp_ratio * cmp_ratio - cmp_ratio
            cur_end = start_pos[i] // cmp_ratio * cmp_ratio + cmp_ratio
            if start_pos[i] % cmp_ratio == 0:
                cur_end = start_pos[i]
            cur_end = min(cur_end, start_pos[i] + S)
            cur_start_block_id = (cur_start // block_size) if cur_start >= 0 else 0
            cur_end_block_id = (cur_end - 1) // block_size
            for j in range(cur_start_block_id, cur_end_block_id + 1):
                block_table[i][j] = next_block_id
                next_block_id = next_block_id + 1
            # Range of state that must be written
            end_pos = get_seq_used_by_batch(i, S, seqused, cu_seqlens)
            if save_state_seqlens is not None:
                next_start = start_pos[i] + end_pos - save_state_seqlens[i]
                next_end = start_pos[i] + end_pos
            else:
                next_start = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio - cmp_ratio
                next_end = (start_pos[i] + end_pos) // cmp_ratio * cmp_ratio + cmp_ratio
                if (start_pos[i] + end_pos) % cmp_ratio == 0:
                    next_end = start_pos[i] + end_pos
            next_end = min(next_end, start_pos[i] + end_pos)
            next_start_block_id = (next_start // block_size) if next_start >= 0 else 0
            next_end_block_id = (next_end - 1) // block_size
            for j in range(next_start_block_id, next_end_block_id + 1):
                if block_table[i][j] == 0:
                    block_table[i][j] = next_block_id
                    next_block_id = next_block_id + 1
    
        if B==0:
            kv_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32)
            score_state = torch.tensor(np.random.uniform(-10, 10, (0, block_size, coff * head_dim))).to(torch.float32)
        else:
            kv_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32)
            score_state = torch.tensor(np.random.uniform(-10, 10, (torch.max(block_table) + 1, block_size, coff * head_dim))).to(torch.float32)
    else:
        block_table = torch.tensor(random.sample(list(range(B)), B), dtype=torch.int32)
        token_size = (2 * cmp_ratio + S - 1) if coff == 2 else (cmp_ratio + S - 1)
        # The value ranges are not defined in the original example; they are
        # assumed here to match the cache_mode == 1 branch.
        kv_state_datarange = (-10, 10)
        score_state_datarange = (-10, 10)
        # With B == 0 the shapes below have a leading 0, giving empty tensors.
        kv_state = torch.tensor(np.random.uniform(kv_state_datarange[0], kv_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32)
        score_state = torch.tensor(np.random.uniform(score_state_datarange[0], score_state_datarange[1], (B, token_size, coff * head_dim))).to(torch.float32)
    
    # other input
    if bs_combine_flag:
        x_shape = (cu_seqlens[-1], hidden_size)
        rope_sin_shape = (min(x_shape[0], x_shape[0] // cmp_ratio + B), rope_head_dim)
        rope_cos_shape = rope_sin_shape
    else:
        x_shape = (B, S, hidden_size)
        rope_sin_shape = (B, (S + cmp_ratio - 1) // cmp_ratio, rope_head_dim)
        rope_cos_shape = rope_sin_shape
    
    x = torch.tensor(np.random.uniform(-10.0, 10.0, x_shape)).to(data_type).npu()
    wkv = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu()
    wgate = torch.tensor(np.random.uniform(-10, 10, (coff * head_dim, hidden_size))).to(data_type).npu()
    ape = torch.tensor(np.random.uniform(-10, 10, (cmp_ratio, coff * head_dim))).to(torch.float32).npu()
    norm_weight = torch.tensor(np.random.uniform(-10, 10, (head_dim))).to(torch.float32).npu()
    rope_sin = torch.tensor(np.random.uniform(-1, 1, rope_sin_shape)).to(torch.float32).npu()
    rope_cos = torch.tensor(np.random.uniform(-1, 1, rope_cos_shape)).to(torch.float32).npu()
    if cache_mode == 1:  # contiguous buffer
        state_cache = torch.zeros((kv_state.shape[0], kv_state.shape[1], 2*kv_state.shape[2]))
        state_cache = state_cache.npu()
        state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone()
        state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone()
    else:
        layer_pad = random.randint(1, 50)
        layer_start_idx = random.randint(0, layer_pad-1)
        print(f"layer_pad: {layer_pad}")
        print(f"layer_start_idx: {layer_start_idx}")
        state_cache_pad = torch.zeros((kv_state.shape[0],kv_state.shape[1]*kv_state.shape[2]*2+layer_pad))
        print(f"state_cache_pad: shape {state_cache_pad.shape}")
        # DEVICE_ID is not defined in this example; use the current NPU device instead.
        state_cache_pad = state_cache_pad.npu()
        state_cache = state_cache_pad[:, layer_start_idx : layer_start_idx + kv_state.shape[1]*kv_state.shape[2]*2].view(-1, kv_state.shape[1], kv_state.shape[2]*2)
        state_cache[:, :, :state_cache.shape[2]//2] = kv_state.clone()
        state_cache[:, :, state_cache.shape[2]//2:] = score_state.clone()
        print(f"state_cache: shape {state_cache.shape}, dtype: {state_cache.dtype}, is_contiguous: {state_cache.is_contiguous()}, stride0: {state_cache.stride(0)}")
    
    block_table = block_table.npu()
    # start_pos, cu_seqlens and seqused are already int32 tensors here; only move them to the NPU.
    start_pos = start_pos.npu()
    if cu_seqlens is not None:
        cu_seqlens = cu_seqlens.npu()
    if seqused is not None:
        seqused = seqused.npu()
    
    class CompressorNetwork(nn.Module):
        def __init__(self):
            super(CompressorNetwork, self).__init__()
    
        def forward(self, x, wkv, wgate, state_cache, ape, norm_weight, rope_sin,         
                    rope_cos, rope_head_dim, cmp_ratio, state_block_table = None, cu_seqlens = None, 
                    seqused = None, start_pos = None, coff = 1, norm_eps = 1e-6, rotary_mode = 1, cache_mode = 1):
            cmp_kv = (
                torch.ops.custom.compressor(
                    x,
                    wkv,
                    wgate,
                    state_cache,
                    ape,
                    norm_weight, 
                    rope_sin,
                    rope_cos,
                    rope_head_dim = rope_head_dim,
                    cmp_ratio = cmp_ratio,
                    state_block_table = state_block_table,
                    cu_seqlens = cu_seqlens,
                    seqused = seqused,
                    start_pos = start_pos,
                    coff = coff,
                    norm_eps = norm_eps,
                    rotary_mode = rotary_mode,
                    cache_mode = cache_mode
                )
            )
            return cmp_kv
    
    from torchair.configs.compiler_config import CompilerConfig
    config = CompilerConfig()
    config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    torch._dynamo.reset()
    npu_mode = torch.compile(CompressorNetwork(), fullgraph=True, backend=npu_backend, dynamic=False)
    cmp_kv = npu_mode(
                    x,
                    wkv,
                    wgate,
                    state_cache,
                    ape,
                    norm_weight, 
                    rope_sin,
                    rope_cos,
                    rope_head_dim = rope_head_dim,
                    cmp_ratio = cmp_ratio,
                    state_block_table = block_table,
                    cu_seqlens = cu_seqlens,
                    seqused = seqused,
                    start_pos = start_pos,
                    coff = coff,
                    norm_eps = norm_eps,
                    rotary_mode = rotary_mode,
                    cache_mode = cache_mode)
    

For more usage examples, see the pytest examples.