NormRopeConcat
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | √ |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | √ |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:(多模态)transformer注意力机制中,针对query、key和Value实现归一化(Norm)、旋转位置编码(Rope)、特征拼接(Concat):
- 归一化(Norm)当前支持层归一化(LayerNorm)、带仿射变换参数层归一化(AFFINE LayerNorm)、均方根归一化(RmsNorm)和带仿射变换参数均方根归一化(AFFINE RmsNorm)类型。
- 旋转位置编码(Rope)支持Interleave和Half类型。
- 特征拼接(Concat)支持在sequence维度上进行拼接,拼接有顺序区别。
-
计算公式(以Query(视频)和EncoderQuery(文本)为例):
$$hiddenState_q = \text{Norm}(query, normQueryWeight, normQueryBias, eps) \ hiddenState_{eq} = \text{Norm}(encoderQuery, normEncoderQueryWeight, normEncoderQueryBias, eps) \ concatedHiddenState = \text{Concat}(hiddenState_q, hiddenState_{eq}) \ transposedHiddenState = \text{Transpose}(concatedHiddenState, (0, 2, 1, 3)) \ hiddenState = \text{RoPE}(concatedHiddenState, ropeSin, ropeCos)
-
输入输出布局如下:输入
query的shape为(B, S, N, D),输出hiddenState的shape为(B, N, S, D),其中 B为batch,S为sequenceLen,N为headNum,D为headDim。 -
Norm有五种模式(
normType):NONE(0), LAYER_NORM(1), LAYER_NORM_AFFINE(2), RMS_NORM(3), RMS_NORM_AFFINE(4),其中: 当normType = NONE时:hiddenStateq=queryhiddenState_q = query
当
normType = LAYER_NORM时queryMeanb,s,n=1D∑i=0Dqueryb,s,nqueryVarb,s,n=1D∑i=0D(query−queryMeanb,s,n)2queryRstdb,s,n=1queryVarb,s,n+ϵhiddenStateq=(query−queryMean)∗queryRstdqueryMean_{b,s,n} = \frac{1}{D}\sum_{i=0}^{D}query_{b,s,n} \\ queryVar_{b,s,n} = \frac{1}{D}\sum_{i=0}^{D}(query-queryMean_{b,s,n})^2 \\ queryRstd_{b,s,n}= \frac{1}{\sqrt{queryVar_{b,s,n}+\epsilon}} \\ hiddenState_q = (query-queryMean)*queryRstd
当
normType = LAYER_NORM_AFFINE时,在上面的基础上hiddenStateq=normQueryWeight∗hiddenStateq+normQueryBiashiddenState_q = normQueryWeight*hiddenState_q + normQueryBias
当
normType = RMS_NORM时:queryMs=1D∑i=0D(queryb,s,n)2queryRms=1queryMs+ϵhiddenStateq=query∗queryRmsqueryMs = \frac{1}{D}\sum_{i=0}^{D}(query_{b,s,n})^2 \\ queryRms = \frac{1}{\sqrt{queryMs+\epsilon}} \\ hiddenState_q = query * queryRms
当
normType = RMS_NORM_AFFINE时,在上面的基础上hiddenStateq=normQueryWeight∗hiddenStateqhiddenState_q = normQueryWeight*hiddenState_q
-
Concat指在sequence维度上进行拼接,拼接有顺序区别(
concatOrder),当concatOrder=0时,hiddenStateqhiddenState_q在hiddenStateeqhiddenState_{eq}前,当concatOrder=1时,hiddenStateqhiddenState_q在hiddenStateeqhiddenState_{eq}后。 -
RoPE有三种模式(
ropeType):NONE(0), INTERLEAVE(1), HALF(2),其中当ropeType=NONE时直接输出不做变换,其余情况参考如下:def image_rotary_emb(hidden_states, rope_sin, rope_cos, mode=1): out = torch.empty_like(hidden_states) if mode == 1: # interleave x = hidden_states.view(*hidden_states.shape[:-1], -1, 2) x1, x2 = x[..., 0], x[..., 1] rotated_x = torch.stack([-x2, x1],dim=-1).flatten(3) out = hidden_states.float() * rope_cos + rotated_x.float()*rope_sin return out.type_as(hidden_states) else: # half x1, x2 = hidden_states.reshape(*hidden_states.shape[:-1], 2, -1).unbind(-2) rotated_x = torch.cat([-x2, x1],dim=-1) out = hidden_states.float() * rope_cos + rotated_x.float()*rope_sin return out.type_as(hidden_states) -
RoPE的输入
ropeSin的shape为(seqRope, D),其中
seqRope<=min(seqQuery+seqEncoderQuery,seqKey+seqEncoderKey)seqRope <= min(seqQuery+seqEncoderQuery, seqKey+seqEncoderKey)
- 当场景为训练时,会输出
queryMean, queryRstd, encoderQueryMean, encoderQueryRstd供后续反向使用。
-
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| query | 输入 | 表示注意力机制中的Query | FLOAT16、BFLOAT16、FLOAT | ND |
| key | 输入 | 表示注意力机制中的Key | FLOAT16、BFLOAT16、FLOAT | ND |
| value | 输入 | 表示注意力机制中的Value | FLOAT16、BFLOAT16、FLOAT | ND |
| encoderQuery | 输入 | 表示注意力机制中的Query,来自EncoderHiddenState,可以为空指针 | FLOAT16、BFLOAT16、FLOAT | ND |
| encoderKey | 输入 | 表示注意力机制中的Key,来自EncoderHiddenState,可以为空指针 | FLOAT16、BFLOAT16、FLOAT | ND |
| encoderValue | 输入 | 表示注意力机制中的Value,来自EncoderHiddenState,可以为空指针 | FLOAT16、BFLOAT16、FLOAT | ND |
| normQueryWeight | 输入 | 表示LayerNorm的仿射变换参数,作用在Query上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normQueryBias | 输入 | 表示LayerNorm的仿射变换参数,作用在Query上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normKeyWeight | 输入 | 表示LayerNorm的仿射变换参数,作用在Key上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normKeyBias | 输入 | 表示LayerNorm的仿射变换参数,作用在Key上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normAddedQueryWeight | 输入 | 表示LayerNorm的仿射变换参数,作用在encoderQuery上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normAddedQueryBias | 输入 | 表示LayerNorm的仿射变换参数,作用在encoderQuery上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normAddedKeyWeight | 输入 | 表示LayerNorm的仿射变换参数,作用在encoderKey上 | FLOAT16、BFLOAT16、FLOAT | ND |
| normAddedKeyBias | 输入 | 表示LayerNorm的仿射变换参数,作用在encoderKey上 | FLOAT16、BFLOAT16、FLOAT | ND |
| ropeSin | 输入 | 表示RoPE的正弦编码 | FLOAT16、BFLOAT16、FLOAT | ND |
| ropeCos | 输入 | 表示RoPE的余弦编码 | FLOAT16、BFLOAT16、FLOAT | ND |
| normType | 属性 | 表示作用在q,k上的正则化类型,0: 不做正则化,1: LayerNorm, 2: LayerNormAffine, 3: RmsNorm, 4: RmsNormAffine | int64 | ND |
| normAddedType | 属性 | 表示作用在encoderQuery,encoderKey上的正则化类型,0: 不做正则化,1: LayerNorm, 2: LayerNormAffine, 3: RmsNorm, 4: RmsNormAffine | int64 | ND |
| ropeType | 属性 | 表示RoPE的模式,int64类型,0: 不做RoPE,1: Interleave, 2: Half | int64 | ND |
| concatOrder | 属性 | 表示拼接的顺序,int64类型,0: query在前,1: query在后 | int64 | ND |
| eps | 属性 | 表示正则化中的epsilon值 | float32 | ND |
| isTraining | 属性 | 表示是否为训练阶段,决定是否输出反向使用的值 | bool | ND |
| queryOutput | 输出 | 表示注意力机制中的query输出 | FLOAT16、BFLOAT16、FLOAT | ND |
| keyOutput | 输出 | 表示注意力机制中的key输出 | FLOAT16、BFLOAT16、FLOAT | ND |
| valueOutput | 输出 | 表示注意力机制中的value输出 | FLOAT16、BFLOAT16、FLOAT | ND |
| normQueryMean | 输出 | LayerNorm中的query均值输出,用于反向 | FLOAT | ND |
| normQueryRstd | 输出 | LayerNorm中的query标准差输出,用于反向 | FLOAT | ND |
| normKeyMean | 输出 | LayerNorm中的key均值输出,用于反向 | FLOAT | ND |
| normKeyRstd | 输出 | LayerNorm中的key标准差输出,用于反向 | FLOAT | ND |
| normEncoderQueryMean | 输出 | LayerNorm中的encoderQuery均值输出,用于反向 | FLOAT | ND |
| normEncoderQueryRstd | 输出 | LayerNorm中的encoderQuery标准差输出,用于反向 | FLOAT | ND |
| normEncoderKeyMean | 输出 | LayerNorm中的encoderKey均值输出,用于反向 | FLOAT | ND |
| normEncoderKeyRstd | 输出 | LayerNorm中的encoderKey标准差输出,用于反向 | FLOAT | ND |
约束说明
- 确定性计算:
- aclnnNormRopeConcat默认确定性实现。
- query、key、value、encoderQuery、encoderKey、encoderValue数据类型需一致。
- headDim长度在[1~1024]间,且为偶数。
- seqRope长度大小在[1~Min(seqQuery+seqEncoderQuery, seqKey+seqEncoderKey)]之间。
调用说明
| 调用方式 | 调用样例 | 说明 |
|---|---|---|
| aclnn调用 | test_aclnn_norm_rope_concat.cpp | 通过aclnnNormRopeConcat接口方式调用NormRopeConcat算子。 |