NormRopeConcat

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:(多模态)transformer注意力机制中,针对query、key和Value实现归一化(Norm)、旋转位置编码(Rope)、特征拼接(Concat):

    • 归一化(Norm)当前支持层归一化(LayerNorm)、带仿射变换参数层归一化(AFFINE LayerNorm)、均方根归一化(RmsNorm)和带仿射变换参数均方根归一化(AFFINE RmsNorm)类型。
    • 旋转位置编码(Rope)支持Interleave和Half类型。
    • 特征拼接(Concat)支持在sequence维度上进行拼接,拼接有顺序区别。
  • 计算公式(以Query(视频)和EncoderQuery(文本)为例):

    $$
    

    hiddenState_q = \text{Norm}(query, normQueryWeight, normQueryBias, eps) \ hiddenState_{eq} = \text{Norm}(encoderQuery, normEncoderQueryWeight, normEncoderQueryBias, eps) \ concatedHiddenState = \text{Concat}(hiddenState_q, hiddenState_{eq}) \ transposedHiddenState = \text{Transpose}(concatedHiddenState, (0, 2, 1, 3)) \ hiddenState = \text{RoPE}(concatedHiddenState, ropeSin, ropeCos)

    1. 输入输出布局如下:输入query的shape为(B, S, N, D),输出hiddenState的shape为(B, N, S, D),其中 B为batch,S为sequenceLen,N为headNum,D为headDim。

    2. Norm有五种模式(normType):NONE(0), LAYER_NORM(1), LAYER_NORM_AFFINE(2), RMS_NORM(3), RMS_NORM_AFFINE(4),其中: 当normType = NONE时:

      hiddenStateq=queryhiddenState_q = query

      normType = LAYER_NORM

      queryMeanb,s,n=1D∑i=0Dqueryb,s,nqueryVarb,s,n=1D∑i=0D(query−queryMeanb,s,n)2queryRstdb,s,n=1queryVarb,s,n+ϵhiddenStateq=(query−queryMean)∗queryRstdqueryMean_{b,s,n} = \frac{1}{D}\sum_{i=0}^{D}query_{b,s,n} \\ queryVar_{b,s,n} = \frac{1}{D}\sum_{i=0}^{D}(query-queryMean_{b,s,n})^2 \\ queryRstd_{b,s,n}= \frac{1}{\sqrt{queryVar_{b,s,n}+\epsilon}} \\ hiddenState_q = (query-queryMean)*queryRstd

      normType = LAYER_NORM_AFFINE时,在上面的基础上

      hiddenStateq=normQueryWeight∗hiddenStateq+normQueryBiashiddenState_q = normQueryWeight*hiddenState_q + normQueryBias

      normType = RMS_NORM时:

      queryMs=1D∑i=0D(queryb,s,n)2queryRms=1queryMs+ϵhiddenStateq=query∗queryRmsqueryMs = \frac{1}{D}\sum_{i=0}^{D}(query_{b,s,n})^2 \\ queryRms = \frac{1}{\sqrt{queryMs+\epsilon}} \\ hiddenState_q = query * queryRms

      normType = RMS_NORM_AFFINE时,在上面的基础上

      hiddenStateq=normQueryWeight∗hiddenStateqhiddenState_q = normQueryWeight*hiddenState_q

    3. Concat指在sequence维度上进行拼接,拼接有顺序区别(concatOrder),当concatOrder=0时,hiddenStateqhiddenState_qhiddenStateeqhiddenState_{eq}前,当concatOrder=1时,hiddenStateqhiddenState_qhiddenStateeqhiddenState_{eq}后。

    4. RoPE有三种模式(ropeType):NONE(0), INTERLEAVE(1), HALF(2),其中当ropeType=NONE时直接输出不做变换,其余情况参考如下:

        def image_rotary_emb(hidden_states, rope_sin, rope_cos, mode=1):
            out = torch.empty_like(hidden_states)
            if mode == 1: # interleave
                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
                x1, x2 = x[..., 0], x[..., 1]
                rotated_x = torch.stack([-x2, x1],dim=-1).flatten(3)
                out = hidden_states.float() * rope_cos + rotated_x.float()*rope_sin
                return out.type_as(hidden_states)
            else: # half
                x1, x2 = hidden_states.reshape(*hidden_states.shape[:-1], 2, -1).unbind(-2)
                rotated_x = torch.cat([-x2, x1],dim=-1)
                out = hidden_states.float() * rope_cos + rotated_x.float()*rope_sin
                return out.type_as(hidden_states)
      
    5. RoPE的输入ropeSin的shape为(seqRope, D),其中

    seqRope<=min(seqQuery+seqEncoderQuery,seqKey+seqEncoderKey)seqRope <= min(seqQuery+seqEncoderQuery, seqKey+seqEncoderKey)

    1. 当场景为训练时,会输出queryMean, queryRstd, encoderQueryMean, encoderQueryRstd供后续反向使用。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
query 输入 表示注意力机制中的Query FLOAT16、BFLOAT16、FLOAT ND
key 输入 表示注意力机制中的Key FLOAT16、BFLOAT16、FLOAT ND
value 输入 表示注意力机制中的Value FLOAT16、BFLOAT16、FLOAT ND
encoderQuery 输入 表示注意力机制中的Query,来自EncoderHiddenState,可以为空指针 FLOAT16、BFLOAT16、FLOAT ND
encoderKey 输入 表示注意力机制中的Key,来自EncoderHiddenState,可以为空指针 FLOAT16、BFLOAT16、FLOAT ND
encoderValue 输入 表示注意力机制中的Value,来自EncoderHiddenState,可以为空指针 FLOAT16、BFLOAT16、FLOAT ND
normQueryWeight 输入 表示LayerNorm的仿射变换参数,作用在Query上 FLOAT16、BFLOAT16、FLOAT ND
normQueryBias 输入 表示LayerNorm的仿射变换参数,作用在Query上 FLOAT16、BFLOAT16、FLOAT ND
normKeyWeight 输入 表示LayerNorm的仿射变换参数,作用在Key上 FLOAT16、BFLOAT16、FLOAT ND
normKeyBias 输入 表示LayerNorm的仿射变换参数,作用在Key上 FLOAT16、BFLOAT16、FLOAT ND
normAddedQueryWeight 输入 表示LayerNorm的仿射变换参数,作用在encoderQuery上 FLOAT16、BFLOAT16、FLOAT ND
normAddedQueryBias 输入 表示LayerNorm的仿射变换参数,作用在encoderQuery上 FLOAT16、BFLOAT16、FLOAT ND
normAddedKeyWeight 输入 表示LayerNorm的仿射变换参数,作用在encoderKey上 FLOAT16、BFLOAT16、FLOAT ND
normAddedKeyBias 输入 表示LayerNorm的仿射变换参数,作用在encoderKey上 FLOAT16、BFLOAT16、FLOAT ND
ropeSin 输入 表示RoPE的正弦编码 FLOAT16、BFLOAT16、FLOAT ND
ropeCos 输入 表示RoPE的余弦编码 FLOAT16、BFLOAT16、FLOAT ND
normType 属性 表示作用在q,k上的正则化类型,0: 不做正则化,1: LayerNorm, 2: LayerNormAffine, 3: RmsNorm, 4: RmsNormAffine int64 ND
normAddedType 属性 表示作用在encoderQuery,encoderKey上的正则化类型,0: 不做正则化,1: LayerNorm, 2: LayerNormAffine, 3: RmsNorm, 4: RmsNormAffine int64 ND
ropeType 属性 表示RoPE的模式,int64类型,0: 不做RoPE,1: Interleave, 2: Half int64 ND
concatOrder 属性 表示拼接的顺序,int64类型,0: query在前,1: query在后 int64 ND
eps 属性 表示正则化中的epsilon值 float32 ND
isTraining 属性 表示是否为训练阶段,决定是否输出反向使用的值 bool ND
queryOutput 输出 表示注意力机制中的query输出 FLOAT16、BFLOAT16、FLOAT ND
keyOutput 输出 表示注意力机制中的key输出 FLOAT16、BFLOAT16、FLOAT ND
valueOutput 输出 表示注意力机制中的value输出 FLOAT16、BFLOAT16、FLOAT ND
normQueryMean 输出 LayerNorm中的query均值输出,用于反向 FLOAT ND
normQueryRstd 输出 LayerNorm中的query标准差输出,用于反向 FLOAT ND
normKeyMean 输出 LayerNorm中的key均值输出,用于反向 FLOAT ND
normKeyRstd 输出 LayerNorm中的key标准差输出,用于反向 FLOAT ND
normEncoderQueryMean 输出 LayerNorm中的encoderQuery均值输出,用于反向 FLOAT ND
normEncoderQueryRstd 输出 LayerNorm中的encoderQuery标准差输出,用于反向 FLOAT ND
normEncoderKeyMean 输出 LayerNorm中的encoderKey均值输出,用于反向 FLOAT ND
normEncoderKeyRstd 输出 LayerNorm中的encoderKey标准差输出,用于反向 FLOAT ND

约束说明

  • 确定性计算:
    • aclnnNormRopeConcat默认确定性实现。
  • query、key、value、encoderQuery、encoderKey、encoderValue数据类型需一致。
  • headDim长度在[1~1024]间,且为偶数。
  • seqRope长度大小在[1~Min(seqQuery+seqEncoderQuery, seqKey+seqEncoderKey)]之间。

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_norm_rope_concat.cpp 通过aclnnNormRopeConcat接口方式调用NormRopeConcat算子。