npu_ring_attention_update public interface

npu_ring_attention_update(
    prev_attn_out: torch.Tensor,
    prev_softmax_max: torch.Tensor,
    prev_softmax_sum: torch.Tensor,
    cur_attn_out: torch.Tensor,
    cur_softmax_max: torch.Tensor,
    cur_softmax_sum: torch.Tensor,
    actual_seq_qlen: torch.Tensor = None,
    layout: str = "SBH",
)

Equivalent computation built from elementary operators (reference logic):

import torch
from einops import rearrange


def forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
                   cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
    # update softmax_max
    origin_dtype = prev_attn_out.dtype
    softmax_max = torch.maximum(prev_softmax_max, cur_softmax_max)
    prev_scale = torch.exp(prev_softmax_max - softmax_max)
    cur_scale = torch.exp(cur_softmax_max - softmax_max)

    # update softmax_sum
    prev_softmax_sum_scaled = prev_softmax_sum * prev_scale
    cur_softmax_sum_scaled = cur_softmax_sum * cur_scale
    softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled

    # weights applied to the previous and current partial outputs
    prev_out_scale = prev_softmax_sum_scaled / softmax_sum
    cur_out_scale = cur_softmax_sum_scaled / softmax_sum

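    # Reshape the weights from [b, n, s, 8] to the [s, b, h] layout of attn_out (h = n * d).
    # Only the SBH layout is handled here; actual_seq_qlen (TND layout) is unused.
    n = prev_out_scale.shape[1]
    h = prev_attn_out.shape[-1]
    d = h // n
    # Keep column 0 of the trailing size-8 dimension and repeat it across head_dim d.
    prev_out_scale = prev_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
    prev_out_scale = rearrange(prev_out_scale, 'b n s d -> s b (n d)').contiguous()
    cur_out_scale = cur_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
    cur_out_scale = rearrange(cur_out_scale, 'b n s d -> s b (n d)').contiguous()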

    # update output
    attn_out = prev_attn_out * prev_out_scale + cur_attn_out * cur_out_scale
    attn_out = attn_out.to(origin_dtype)
    return attn_out, softmax_max, softmax_sum
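Written per query row (dropping the trailing size-8 dimension), the reference code above is the standard online-softmax merge of two partial attention results. Here $m_p, m_c$ are the previous and current row maxima, $\ell_p, \ell_c$ the row sums, and $O_p, O_c$ the partial outputs, each already normalized by its own sum:

```latex
\begin{aligned}
m    &= \max(m_p,\, m_c) \\
\ell &= \ell_p\, e^{m_p - m} + \ell_c\, e^{m_c - m} \\
O    &= \frac{\ell_p\, e^{m_p - m}}{\ell}\, O_p + \frac{\ell_c\, e^{m_c - m}}{\ell}\, O_c
\end{aligned}
```

The returned softmax_max and softmax_sum are $m$ and $\ell$, so the result can be merged again with the next block.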

Forward interface

Inputs:

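  • prev_attn_out: required input; dtype torch.bfloat16, torch.float, or torch.float16; shape [S, B, H] for the SBH layout
  • prev_softmax_max: required input; dtype torch.float; shape [B, N, S, 8] for the SBH layout
  • prev_softmax_sum: required input; dtype torch.float; shape [B, N, S, 8] for the SBH layout
  • cur_attn_out: required input; dtype torch.bfloat16, torch.float, or torch.float16; shape [S, B, H] for the SBH layout
  • cur_softmax_max: required input; dtype torch.float; shape [B, N, S, 8] for the SBH layout
  • cur_softmax_sum: required input; dtype torch.float; shape [B, N, S, 8] for the SBH layout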
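The reference code reads only `[..., 0]` of the trailing size-8 dimension of the softmax statistics. If those statistics start out as per-row tensors of shape [B, N, S], a hypothetical helper such as `expand_softmax_stat` below (an assumption for illustration, not part of MindSpeed) could replicate them into the [B, N, S, 8] format:

```python
import torch


def expand_softmax_stat(stat: torch.Tensor, width: int = 8) -> torch.Tensor:
    """Replicate a per-row softmax statistic from [B, N, S] to [B, N, S, width].

    Hypothetical helper: assumes every entry of the trailing dimension holds
    the same value, which is consistent with the reference reading index 0.
    """
    if stat.dtype != torch.float32:
        stat = stat.float()  # softmax_max / softmax_sum are documented as torch.float
    # unsqueeze adds a size-1 trailing dim; expand broadcasts it; contiguous materializes it
    return stat.unsqueeze(-1).expand(*stat.shape, width).contiguous()


# Usage: build a prev_softmax_max for B=1, N=12, S=2048
row_max = torch.randn(1, 12, 2048)
prev_softmax_max = expand_softmax_stat(row_max)
print(prev_softmax_max.shape)  # torch.Size([1, 12, 2048, 8])
```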

Outputs:

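  • attn_out: required output; dtype torch.bfloat16, torch.float, or torch.float16 (the reference casts it back to the dtype of prev_attn_out)
  • softmax_max: required output; dtype torch.float
  • softmax_sum: required output; dtype torch.float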

Attributes:

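  • actual_seq_qlen: optional attribute; dtype torch.int64; values must be monotonically increasing; used only when layout is TND
  • layout: str attribute; defaults to "SBH" in the signature above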
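The text above only states that actual_seq_qlen is int64 and monotonically increasing. One plausible reading, which is an assumption here, is the cumulative-length convention that torch_npu's npu_fusion_attention uses for actual_seq_qlen in TND mode: each entry is the end offset of a packed sequence along T, with no leading 0. Under that assumption, a hypothetical helper `build_actual_seq_qlen` could derive it from per-sequence lengths:

```python
import torch


def build_actual_seq_qlen(seq_lens):
    """Turn per-sequence query lengths into cumulative end offsets (int64).

    Hypothetical helper; assumes the cumulative convention without a leading 0,
    e.g. lengths [3, 5, 2] -> [3, 8, 10].
    """
    lens = torch.as_tensor(seq_lens, dtype=torch.int64)
    if (lens <= 0).any():
        raise ValueError("sequence lengths must be positive")
    # Strictly increasing because every length is positive
    return torch.cumsum(lens, dim=0)


qlen = build_actual_seq_qlen([3, 5, 2])
print(qlen.tolist())  # [3, 8, 10]
```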

Example

import torch
import torch_npu
from mindspeed.ops.npu_ring_attention_update import npu_ring_attention_update

# SBH layout: attn_out is [S, B, H] = [2048, 1, 12]; softmax stats are [B, N, S, 8] with N = 12,
# so head_dim = H // N = 1.
prev_attn_out = torch.randn(2048, 1, 12, dtype=torch.bfloat16).npu()
prev_softmax_max = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
# A softmax sum is at least 1, so use positive values instead of randn
prev_softmax_sum = (torch.rand(1, 12, 2048, 8, dtype=torch.float32) + 1.0).npu()
cur_attn_out = torch.randn(2048, 1, 12, dtype=torch.bfloat16).npu()
cur_softmax_max = torch.randn(1, 12, 2048, 8, dtype=torch.float32).npu()
cur_softmax_sum = (torch.rand(1, 12, 2048, 8, dtype=torch.float32) + 1.0).npu()

attn_out, softmax_max, softmax_sum = npu_ring_attention_update(
    prev_attn_out, prev_softmax_max, prev_softmax_sum,
    cur_attn_out, cur_softmax_max, cur_softmax_sum,
    layout="SBH",
)
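Without an NPU, the merge rule itself can be sanity-checked on CPU. The sketch below is not the MindSpeed kernel: it re-expresses the reference update in plain PyTorch (no einops), splits keys and values into ring blocks, computes each block's normalized partial output plus its row max and row sum in the [B, N, S, 8] format, folds the blocks in order, and compares with full softmax attention. The helper names `partial_attention` and `merge` and all sizes are illustrative assumptions.

```python
import torch


def partial_attention(q, k, v):
    """Attention of q [B, N, S, D] against one KV block k, v [B, N, Sk, D].

    Returns the block-normalized output in SBH layout plus the row max and
    row sum replicated to [B, N, S, 8], mimicking the documented inputs.
    """
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5  # [B, N, S, Sk]
    row_max = scores.amax(dim=-1, keepdim=True)             # [B, N, S, 1]
    p = torch.exp(scores - row_max)
    row_sum = p.sum(dim=-1, keepdim=True)                   # [B, N, S, 1]
    out = (p / row_sum) @ v                                 # [B, N, S, D]
    b, n, s, d = out.shape
    out_sbh = out.permute(2, 0, 1, 3).reshape(s, b, n * d)  # 'b n s d -> s b (n d)'
    return out_sbh, row_max.expand(b, n, s, 8), row_sum.expand(b, n, s, 8)


def merge(prev_out, prev_max, prev_sum, cur_out, cur_max, cur_sum):
    """Same math as forward_update for the SBH layout, written without einops."""
    m = torch.maximum(prev_max, cur_max)
    prev_w = prev_sum * torch.exp(prev_max - m)
    cur_w = cur_sum * torch.exp(cur_max - m)
    total = prev_w + cur_w
    b, n, s, _ = m.shape
    d = prev_out.shape[-1] // n

    def to_sbh(w):  # [B, N, S, 8] -> [S, B, N*D], using column 0
        return w[..., 0].permute(2, 0, 1).unsqueeze(-1).expand(s, b, n, d).reshape(s, b, n * d)

    out = prev_out * to_sbh(prev_w / total) + cur_out * to_sbh(cur_w / total)
    return out, m, total


torch.manual_seed(0)
B, N, S, D, BLOCK = 1, 4, 16, 8, 4
q = torch.randn(B, N, S, D)
k = torch.randn(B, N, S, D)
v = torch.randn(B, N, S, D)

# Fold the KV blocks in one at a time, as successive ring steps would.
out, mx, sm = partial_attention(q, k[:, :, :BLOCK], v[:, :, :BLOCK])
for start in range(BLOCK, S, BLOCK):
    blk = slice(start, start + BLOCK)
    out, mx, sm = merge(out, mx, sm, *partial_attention(q, k[:, :, blk], v[:, :, blk]))

# Reference: full softmax attention over all keys, converted to SBH.
ref = torch.softmax(q @ k.transpose(-1, -2) / D ** 0.5, dim=-1) @ v
ref = ref.permute(2, 0, 1, 3).reshape(S, B, N * D)
print(torch.allclose(out, ref, atol=1e-5))  # True
```

Changing BLOCK changes the number of fold steps but should not change the result beyond floating-point tolerance.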