NsaCompressGrad
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品 | √ |
| Atlas A2 推理系列产品 | × |
功能说明
-
算子功能:aclnnNsaCompress算子的反向计算。
-
计算公式: 选择注意力的正向计算公式如下:
dw=dk_cmp⋅K⊤\text{dw} = \text{dk\_cmp} \cdot K^\top
dk=W⊤⋅dk_cmp\text{dk} = W^\top \cdot \text{dk\_cmp}
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| outputGrad | 输入 | 正向算子输出的反向梯度,shape支持[T, N, D]。 | BFLOAT16、FLOAT16 | ND |
| input | 输入 | 待压缩张量,shape支持[T, N, D]。 | BFLOAT16、FLOAT16 | ND |
| weight | 输入 | 压缩权重,shape为[compressBlockSize, N],与input满足broadcast关系。 | BFLOAT16、FLOAT16 | ND |
| actSeqLenOptional | 输入 | 每个Batch对应的S大小,batch序列长度不等时需输入。 | INT64 | ND |
| compressBlockSize | 输入 | 压缩滑窗大小。 | INT64 | - |
| compressStride | 输入 | 两次压缩滑窗间隔大小。 | INT64 | - |
| actSeqLenType | 输入 | 序列长度类型,0表示cumsum结果,1表示每个batch序列大小,当前仅支持0。 | INT64 | - |
| layoutOptional | 输入 | 输入数据排布格式,支持TND。 | String | - |
| inputGrad | 输出 | input的梯度,shape与input保持一致。 | BFLOAT16、FLOAT16 | ND |
| weightGrad | 输出 | weight的梯度,shape与weight保持一致。 | BFLOAT16、FLOAT16 | ND |
约束说明
- compressBlockSize和compressStride要是16的整数倍,且compressBlockSize > compressStride。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_nsa_compress_grad | 非TND场景,通过aclnnNsaCompressGrad接口方式调用NsaCompressGrad算子。 |