from typing import Optional

import torch

from mindspeed.op_builder import RingAttentionUpdateOpBuilder
# Single module-level builder instance. npu_ring_attention_update() calls
# its load() method to obtain the ops module that provides the kernel.
op_builder = RingAttentionUpdateOpBuilder()

# Names exported by `from ... import *`.
__all__ = ["npu_ring_attention_update"]
def npu_ring_attention_update(
    prev_attn_out: torch.Tensor,
    prev_softmax_max: torch.Tensor,
    prev_softmax_sum: torch.Tensor,
    cur_attn_out: torch.Tensor,
    cur_softmax_max: torch.Tensor,
    cur_softmax_sum: torch.Tensor,
    actual_seq_qlen: Optional[torch.Tensor] = None,
    layout: str = "SBH",
):
    """Forward to the ``npu_ring_attention_update`` op loaded by ``op_builder``.

    This is a thin wrapper. It obtains the ops module from the module-level
    ``op_builder.load()`` and calls that module's
    ``npu_ring_attention_update``. All eight arguments are passed
    positionally and unchanged, and the result is returned as-is.

    NOTE(review): judging only by the parameter names, the op merges a
    previously accumulated partial attention result ("prev_*") with a newly
    computed one ("cur_*") using their softmax max/sum statistics, as in
    ring attention. This wrapper does not show the expected shapes, dtypes,
    or device. Confirm them against the kernel's documentation.

    Args:
        prev_attn_out: Attention output of the earlier step (per the name).
        prev_softmax_max: Softmax max statistic paired with ``prev_attn_out``
            (per the name).
        prev_softmax_sum: Softmax sum statistic paired with ``prev_attn_out``
            (per the name).
        cur_attn_out: Attention output of the current step (per the name).
        cur_softmax_max: Softmax max statistic paired with ``cur_attn_out``
            (per the name).
        cur_softmax_sum: Softmax sum statistic paired with ``cur_attn_out``
            (per the name).
        actual_seq_qlen: Optional tensor forwarded to the op unchanged.
            Defaults to ``None``. Presumably per-sequence query lengths for
            packed/varlen layouts; TODO confirm.
        layout: Layout string forwarded to the op; defaults to ``"SBH"``.
            It is not validated here, so the op decides which values it
            accepts.

    Returns:
        Exactly what the loaded op returns. Presumably the updated
        (attn_out, softmax_max, softmax_sum); TODO confirm.
    """
    # load() runs on every call. Presumably the builder caches the compiled
    # module after the first build; verify in RingAttentionUpdateOpBuilder.
    ops = op_builder.load()
    return ops.npu_ring_attention_update(
        prev_attn_out,
        prev_softmax_max,
        prev_softmax_sum,
        cur_attn_out,
        cur_softmax_max,
        cur_softmax_sum,
        actual_seq_qlen,
        layout,
    )