文件最后提交记录最后更新时间
QLI支持HiF8 Co-authored-by: miaofangzheng<miaofangzheng@huawei.com> # message auto-generated for no-merge-commit merge: !5347 merge master into master QLI支持HiF8 Created-by: miaofangzheng Commit-by: miaofangzheng Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> QLI支持HiF8 ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #000--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> [#2448](https://gitcode.com/cann/ops-transformer/issues/2448) ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] 🐛 Bug 修复 - [x] ✨ 新特性 - [ ] ⚡ 性能优化 - [ ] ♻️ 重构 - [ ] 🧪 测试 - [ ] 📦 构建/CI - [ ] 🔧 配置变更 - [ ] 📝 文档更新 - [ ] ⬆️ 依赖升级 - [ ] 🔒 安全修复 - [ ] 🧹 代码清理 - [ ] ❓ 其他,请描述: See merge request: cann/ops-transformer!534715 天前
update sas&qsas&qli metadata ParamCheck + CostFunction Co-authored-by: qq_32807861<handongchen2@huawei.com> # message auto-generated for no-merge-commit merge: !5608 merge master into master update sas&qsas&qli metadata ParamCheck + CostFunction Created-by: qq_32807861 Commit-by: qq_32807861 Merged-by: cann-robot Description: ## 描述 sas、qsas、qli metadata参数校验迁移至aclnn接口; 新增scfa典型场景下性能优化处理 ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #000--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] 🐛 Bug 修复 - [ ] ✨ 新特性 - [ ] ⚡ 性能优化 - [ ] ♻️ 重构 - [ ] 🧪 测试 - [ ] 📦 构建/CI - [ ] 🔧 配置变更 - [ ] 📝 文档更新 - [ ] ⬆️ 依赖升级 - [ ] 🔒 安全修复 - [ ] 🧹 代码清理 - [ ] ❓ 其他,请描述: See merge request: cann/ops-transformer!56088 天前
fix exp pytest bug Co-authored-by: zhengwenhui0817<zhengwenhui7@huawei.com> # message auto-generated for no-merge-commit merge: !5799 merge master into master fix exp pytest bug Created-by: zhengwenhui0817 Commit-by: zhengwenhui0817 Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> fix exp pytest bug ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #000--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] 🐛 Bug 修复 - [ ] ✨ 新特性 - [ ] ⚡ 性能优化 - [ ] ♻️ 重构 - [ ] 🧪 测试 - [ ] 📦 构建/CI - [ ] 🔧 配置变更 - [ ] 📝 文档更新 - [ ] ⬆️ 依赖升级 - [ ] 🔒 安全修复 - [ ] 🧹 代码清理 - [ ] ❓ 其他,请描述: See merge request: cann/ops-transformer!57994 天前
新增稀疏Attention等高效融合算子支持DeepSeekV4 Co-authored-by: wangzhe123456789<wangzhe92@huawei.com> # message auto-generated for no-merge-commit merge: !4596 merge publish_dsv4 into master 新增稀疏Attention等高效融合算子支持DeepSeekV4 Created-by: songjionghui Commit-by: wangzhe123456789 Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> 针对DeepSeek V4新网络结构,在experimental目录下新增稀疏Attention等高效融合算子 新增算子包括: SparseAttnSharedkv KvQuantSparseAttnSharedkv Compressor KvQuantSparseAttnSharedkvMetadata QuantLightningIndexerMetadata SparseAttnSharedkvMetadata 适配算子包括: QuantLightningIndexer ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #000--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> 关联Issue [#2059](https://gitcode.com/cann/ops-transformer/issues/2059) ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] 🐛 Bug 修复 - [x] ✨ 新特性 - [ ] ⚡ 性能优化 - [ ] ♻️ 重构 - [ ] 🧪 测试 - [ ] 📦 构建/CI - [ ] 🔧 配置变更 - [ ] 📝 文档更新 - [ ] ⬆️ 依赖升级 - [ ] 🔒 安全修复 - [ ] 🧹 代码清理 - [ ] ❓ 其他,请描述: See merge request: cann/ops-transformer!45961 个月前
QLI支持HiF8 Co-authored-by: miaofangzheng<miaofangzheng@huawei.com> # message auto-generated for no-merge-commit merge: !5347 merge master into master QLI支持HiF8 Created-by: miaofangzheng Commit-by: miaofangzheng Merged-by: cann-robot Description: ## 描述 <!--在这里详细描述你的改动,包括改动的原因和所采取的方法。--> QLI支持HiF8 ## 关联的Issue <!-- 如果这个PR是为了解决特定的Issue,请在这里提供Issue链接。例如:关联Issue #000--> <!-- 如果这个PR是为了解决特定的问题单,请在这里描述问题单单号。--> [#2448](https://gitcode.com/cann/ops-transformer/issues/2448) ## 测试 <!--描述进行了哪些测试来验证你的改动。包括但不限于二级冒烟、算子泛化等。--> ## 文档更新 <!--如果这个PR包含文档的更新,请在这里指出。例如:更新了README.md文件。--> ## 类型标签 <!-- [x] 表示选中 --> - [ ] 🐛 Bug 修复 - [x] ✨ 新特性 - [ ] ⚡ 性能优化 - [ ] ♻️ 重构 - [ ] 🧪 测试 - [ ] 📦 构建/CI - [ ] 🔧 配置变更 - [ ] 📝 文档更新 - [ ] ⬆️ 依赖升级 - [ ] 🔒 安全修复 - [ ] 🧹 代码清理 - [ ] ❓ 其他,请描述: See merge request: cann/ops-transformer!534715 天前
README.md

QuantLightningIndexer

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列加速卡产品 ×
Atlas 训练系列产品 ×

功能说明

  • API功能:QuantLightningIndexer是推理场景下,稀疏attention前处理的计算,选出关键的稀疏token,并对输入query和key进行量化实现存8算8,获取最大收益。

  • 计算公式:

    out=Top-k{[1]1×g@[(W@[1]1×Sk)⊙ReLU((ScaleQ@ScaleKT)⊙(QindexQuant@(KindexQuant)T))]}out = \text{Top-}k\left\{[1]_{1\times g}@\left[(W@[1]_{1\times S_{k}})\odot\text{ReLU}\left(\left(Scale_Q@Scale_K^T\right)\odot\left(Q_{index}^{Quant}@{\left(K_{index}^{Quant}\right)}^T\right)\right)\right]\right\}

    主要计算过程为:

    1. 将某个token对应的输入参数queryQindexQuant∈Rg×dQ_{index}^{Quant}\in\R^{g\times d})乘以给定上下文keyKindexQuant∈RSk×dK_{index}^{Quant}\in\R^{S_{k}\times d}),得到相关性。
    2. 相关性结果与querykey对应的反量化系数query_dequant_scaleScaleQScale_Q)和key_dequant_scaleScaleKTScale_K^T)相乘,通过激活函数ReLUReLU过滤无效负相关信号后,得到当前Token与所有前序Token的相关性分数向量。
    3. 将其与权重系数weightsWW)相乘后,沿g的方向,选取前Top−kTop-k个索引值得到输出outout,作为Attention的输入。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 公式中的QindexQuant∈Rg×d,表示输入IndexQueryQ_{index}^{Quant}\in\R^{g\times d},表示输入Index Query,不支持非连续。 INT8、FLOAT8_e4m3fn、HIFLOAT8 ND
key 输入 公式的KindexQuant∈RSk×d,表示压缩后的输入IndexKeyK_{index}^{Quant}\in\R^{S_{k}\times d},表示压缩后的输入Index Key,支持0轴非连续。 INT8、FLOAT8_e4m3fn、HIFLOAT8 ND
weights 输入 公式中的WW,表示权重系数,不支持非连续。 FLOAT16、FLOAT32 ND
query_dequant_scale 输入 公式中的ScaleQScale_Q,表示Index Query的反量化系数,不支持非连续 FLOAT16、FLOAT32 ND
key_dequant_scale 输入 公式中的ScaleKScale_K,表示Index Key的反量化系数,不支持非连续 FLOAT16、FLOAT32 ND
actual_seq_lengths_query 可选输入 表示不同Batch中query的有效token数 INT32 ND
actual_seq_lengths_key 可选输入 表示不同Batch中key的有效token数 INT32 ND
block_table 可选输入 表示PageAttention中KV存储使用的block映射表 INT32 ND
metadata 可选输入 QuantLightningIndexerMetadata算子传入的分核信息,包含使用核数、分块大小以及每个核处理数据的起始点等内容,shape大小为[1024],当前不支持传空 INT32 ND
query_quant_mode 属性 用于标识输入query的量化模式,当前支持Per-Token-Head量化模式,当前仅支持传入0 INT32 -
key_quant_mode 属性 用于标识输入key的量化模式,当前支持Per-Token-Head量化模式,当前仅支持传入0 INT32 -
layout_query 可选属性 用于标识输入query的数据排布格式,当前支持BSND、TND,默认值"BSND" STRING -
layout_key 可选属性 用于标识输入key的数据排布格式,当前仅支持传入PA_BSND,默认值"PA_BSND" STRING -
sparse_count 可选属性 代表topK阶段需要保留的block数量,支持[1, 2048],默认值2048 INT32 -
sparse_mode 可选属性 表示sparse的模式,支持0/3,数据类型支持int32。 sparse_mode为0时,代表defaultMask模式。sparse_mode为3时,代表rightDownCausal模式的mask,对应以右顶点为划分的下三角场景。 INT32 -
pre_tokens 可选属性 预留参数,表示attention需要和前几个Token计算关联,仅支持默认值2^63-1 INT64 -
next_tokens 可选属性 预留参数,表示attention需要和前几个Token计算关联,仅支持默认值2^63-1 INT64 -
cmp_ratio 可选属性 用于稀疏计算,表示key的压缩倍数。数据类型支持int32。Atlas A3 推理系列产品支持1/2/4/8/16/32/64/128,Ascend 950PR/Ascend 950DT支持1/4/128,默认值1。 INT32 -
return_value 可选属性 表示是否输出sparse_values。True表示输出,False表示不输出;仅支持默认值False BOOL -
sparse_indices 输出 公式中的输出Out,参与稀疏attention计算的token索引值 INT32 ND
sparse_values 输出 公式中的Indices输出对应的value值,目前暂不支持返回sparse_values。 FLOAT32 ND
  • Ascend 950PR/Ascend 950DT:query、key不支持INT8;weights、query_dequant_scale和key_dequant_scale不支持FLOAT16。
  • Atlas A3 训练系列产品/Atlas A3 推理系列产品、Atlas A2 训练系列产品/Atlas A2 推理系列产品:query、key不支持FLOAT8_e4m3fn和HIFLOAT8;weights、query_dequant_scale和key_dequant_scale不支持FLOAT32。

约束说明

  • 该接口支持图模式。
  • 该接口要求W⊙ScaleQW \odot Scale_Q的结果在float16(Atlas A3)/float32(Ascend 950PR/Ascend 950DT)的表示范围内。
  • 该接口的TopK过程对NAN排序是未定义行为。
  • 参数query中的D轴和参数key中的D轴值相等为128。
  • 参数query和key中的N轴分别仅支持64和1。
  • layout_query为TND时,actual_seq_lengths_query必须传入,且以该入参元素的数量作为B值,该入参中每个元素的值表示当前batch与之前所有batch的token数总和,即前缀和,因此后一个元素的值必须大于等于前一个元素的值。不能出现负值。
  • layout_key为PA_BSND时,actual_seq_lengths_key该入参必须传入。
  • PageAttention场景下,block_table必须为二维,第一维长度需要等于B,第二维长度不能小于maxBlockNumPerSeq(maxBlockNumPerSeq为每个batch中最大actual_seq_lengths_key对应的block数量),支持block_size取值为16的整数倍,最大支持到1024。
  • query、key、weights、query_dequant_scale、key_dequant_scale数据排布格式支持从多种维度解读,其中B(Batch Size)表示输入样本批量大小、S(Sequence Length)表示输入样本序列长度、H(Head Size)表示hidden层的大小、N(Head Num)表示多头数、D(Head Dim)表示hidden层最小的单元尺寸,且满足D=H/N、T表示所有Batch输入样本序列长度的累加和。

Atlas A3 推理系列产品 调用说明

  • 单算子模式调用

    import torch
    import torch_npu
    import numpy as np
    import torch.nn as nn
    import math
    import custom_ops
    
    n1 = 64
    n2 = 1
    d = 128
    block_size = 128
    layout_key = "PA_BSND"
    layout_query = "BSND"
    query_quant_mode = 0
    key_quant_mode = 0
    np.random.seed(0)
    # -------------
    b = 24
    t = None
    s1 = 4
    s2 = 512
    act_seq_q = None
    act_seq_k = None
    sparse_mode = 0
    sparse_count = 512
    cmp_ratio = 1
    max_block_table_num = (s2 + block_size - 1) // block_size
    block_table = torch.tensor([range(b * max_block_table_num)], dtype = torch.int32).reshape(b, -1)
    key = torch.tensor(np.random.uniform(-128, 127, (b * max_block_table_num, block_size, n2, d))).to(torch.int8)
    key_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b * max_block_table_num, block_size, n2)))
    key_dequant_scale = key_dequant_scale.to(torch.float16)
    query = torch.tensor(np.random.uniform(-128, 127, (b, s1, n1, d))).to(torch.int8)
    query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b, s1, n1))).to(torch.float16)
    weights = torch.tensor(np.random.uniform(0, 0.01, (b, s1, n1))).to(torch.float16)
    actual_seq_lengths_query = torch.tensor(np.random.uniform(s1, s1, (b))).to(torch.int32) \
                                if act_seq_q is None else torch.tensor(act_seq_q).to(torch.int32)
    actual_seq_lengths_key = torch.tensor(np.random.uniform(s2, s2, (b))).to(torch.int32) \
                                if act_seq_k is None else torch.tensor(act_seq_k).to(torch.int32)
    max_seqlen_q = actual_seq_lengths_query.max().item()
    max_seqlen_k = actual_seq_lengths_key.max().item()
    metadata = torch.ops.custom.npu_quant_lightning_indexer_metadata (
                                    actual_seq_lengths_query = actual_seq_lengths_query.npu(),
                                    actual_seq_lengths_key = actual_seq_lengths_key.npu(),
                                    num_heads_q = n1,
                                    num_heads_k = n2,
                                    head_dim = d,
                                    query_quant_mode = query_quant_mode,
                                    key_quant_mode = key_quant_mode,
                                    batch_size = b,
                                    max_seqlen_q = max_seqlen_q,
                                    max_seqlen_k = max_seqlen_k,
                                    layout_query = layout_query,
                                    layout_key = layout_key,
                                    sparse_count = sparse_count,
                                    sparse_mode = sparse_mode,
                                    pre_tokens = (1<<63)-1,
                                    next_tokens = (1<<63)-1,
                                    cmp_ratio = cmp_ratio,
                                    device = 'npu:0')
    
    sparse_indices, sparse_values = torch.ops.custom.npu_quant_lightning_indexer(query.npu(), key.npu(), weights.npu(), query_dequant_scale.npu(),
                                                    key_dequant_scale.npu(),
                                                    actual_seq_lengths_query=actual_seq_lengths_query.npu(),
                                                    actual_seq_lengths_key=actual_seq_lengths_key.npu(),
                                                    block_table=block_table.npu(),
                                                    metadata = metadata,
                                                    query_quant_mode=query_quant_mode,
                                                    key_quant_mode=key_quant_mode,
                                                    layout_query=layout_query,
                                                    layout_key=layout_key, sparse_count=sparse_count,
                                                    sparse_mode=sparse_mode, pre_tokens=(1<<63)-1,
                                                    next_tokens=(1<<63)-1, cmp_ratio=cmp_ratio)
    
  • aclgarph调用

    import torch
    import torch_npu
    import numpy as np
    import torch.nn as nn
    import math
    import torchair
    import custom_ops
    from torchair.configs.compiler_config import CompilerConfig
    
    n1 = 64
    n2 = 1
    d = 128
    block_size = 128
    layout_key = "PA_BSND"
    layout_query = "BSND"
    query_quant_mode = 0
    key_quant_mode = 0
    np.random.seed(0)
    # -------------
    b = 24
    t = None
    s1 = 4
    s2 = 512
    act_seq_q = None
    act_seq_k = None
    sparse_mode = 3
    sparse_count = 512
    pre_tokens=(1<<63)-1
    next_tokens=(1<<63)-1
    cmp_ratio = 4
    max_block_table_num = (s2 + block_size - 1) // block_size
    block_table = torch.tensor([range(b * max_block_table_num)], dtype = torch.int32).reshape(b, -1).npu()
    key = torch.tensor(np.random.uniform(-128, 127, (b * max_block_table_num, block_size, n2, d))).to(torch.int8).npu()
    key_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b * max_block_table_num, block_size, n2))).npu()
    key_dequant_scale = key_dequant_scale.to(torch.float16).npu()
    query = torch.tensor(np.random.uniform(-128, 127, (b, s1, n1, d))).to(torch.int8).npu()
    query_dequant_scale = torch.tensor(np.random.uniform(0, 10, (b, s1, n1))).to(torch.float16).npu()
    weights = torch.tensor(np.random.uniform(0, 0.01, (b, s1, n1))).to(torch.float16).npu()
    actual_seq_lengths_query = torch.tensor(np.random.uniform(s1, s1, (b))).to(torch.int32).npu() \
                                if act_seq_q is None else torch.tensor(act_seq_q).to(torch.int32).npu()
    actual_seq_lengths_key = torch.tensor(np.random.uniform(s2, s2, (b))).to(torch.int32).npu() \
                                if act_seq_k is None else torch.tensor(act_seq_k).to(torch.int32).npu()
    max_seqlen_q = actual_seq_lengths_query.max().item()
    max_seqlen_k = actual_seq_lengths_key.max().item()
    
    class QLINetwork(nn.Module):
        def __init__(self):
            super(QLINetwork, self).__init__()
    
        def forward(self, query, key, weights, q_scale, k_scale, query_quant_mode, key_quant_mode,
                    batch_size, num_heads_q, num_heads_k, head_dim,
                    actual_seq_lengths_query=None, actual_seq_lengths_key=None,
                    block_table=None, layout_query='BSND', layout_key='BSND',
                    sparse_count=512, sparse_mode=3, pre_tokens=(1<<63)-1,
                    next_tokens=(1<<63)-1, cmp_ratio=cmp_ratio, return_value=False):
            metadata = torch.ops.custom.npu_quant_lightning_indexer_metadata(
                                    actual_seq_lengths_query = actual_seq_lengths_query,
                                    actual_seq_lengths_key = actual_seq_lengths_key,
                                    num_heads_q = num_heads_q,
                                    num_heads_k = num_heads_k,
                                    head_dim = head_dim,
                                    query_quant_mode = query_quant_mode,
                                    key_quant_mode = key_quant_mode,
                                    batch_size = batch_size,
                                    max_seqlen_q = max_seqlen_q,
                                    max_seqlen_k = max_seqlen_k,
                                    layout_query = layout_query,
                                    layout_key = layout_key,
                                    sparse_count = sparse_count,
                                    sparse_mode = sparse_mode,
                                    pre_tokens = (1<<63)-1,
                                    next_tokens = (1<<63)-1,
                                    cmp_ratio = cmp_ratio,
                                    device = 'npu:0')
    
            sparse_indices, sparse_values = torch.ops.custom.npu_quant_lightning_indexer(query, key, weights,
                                                        q_scale, k_scale,
                                                        actual_seq_lengths_query=actual_seq_lengths_query,
                                                        actual_seq_lengths_key=actual_seq_lengths_key,
                                                        block_table=block_table, metadata=metadata,
                                                        query_quant_mode=query_quant_mode,
                                                        key_quant_mode=key_quant_mode,
                                                        layout_query=layout_query,
                                                        layout_key=layout_key, sparse_count=sparse_count,
                                                        sparse_mode=sparse_mode,pre_tokens=pre_tokens,
                                                        next_tokens=next_tokens, cmp_ratio=cmp_ratio,
                                                        return_value=return_value)
            return sparse_indices
    
    
    config = CompilerConfig()
    config.mode = "reduce-overhead"
    npu_backend = torchair.get_npu_backend(compiler_config=config)
    torch._dynamo.reset()
    npu_mode = torch.compile(QLINetwork().npu(), fullgraph=True, backend=npu_backend, dynamic=False)
    sparse_indices = npu_mode( query, key, weights, query_dequant_scale, key_dequant_scale,
                        query_quant_mode, key_quant_mode, b, n1, n2, d,
                        actual_seq_lengths_query=actual_seq_lengths_query,
                        actual_seq_lengths_key=actual_seq_lengths_key,
                        block_table=block_table,
                        layout_query=layout_query, layout_key=layout_key,
                        sparse_count=sparse_count, sparse_mode=sparse_mode,
                        pre_tokens=pre_tokens, next_tokens=next_tokens,
                        cmp_ratio=cmp_ratio, return_value=False)
    

更多使用示例见pytest示例