import os
from typing import List, Optional, Tuple
import pytest
import torch
import triton
import triton.language as tl
from mindspeed.ops.triton.utils import assert_close
from mindspeed.ops.triton.cumsum import chunk_local_cumsum
from mindspeed.ops.triton.utils import prepare_chunk_indices
# Triton kernel: inclusive cumulative sum taken independently inside each
# length-BT chunk of the sequence axis, for one scalar per (step, head).
#
# Launch grid (see chunk_local_cumsum_scalar): program_id(0) indexes chunks,
# program_id(1) indexes the flattened (batch, head) pair, so each program
# processes one chunk of one head. Accumulation is done in float32 and the
# result is cast to the output pointer's element type on store.
@triton.heuristics({
    # Compile-time flags derived from whether the optional arguments were passed.
    'HAS_SCALE': lambda args: args['scale'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
@triton.autotune(
    # Only the warp count is tuned; the chunk length BT is fixed by the caller.
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,  # input pointer
    o,  # output pointer, same layout as s
    scale,  # optional multiplier applied after the cumsum (enabled via HAS_SCALE)
    cu_seqlens,  # varlen only: sequence boundary offsets, indexed by sequence id
    chunk_indices,  # varlen only: flat pairs (sequence id, chunk id within that sequence)
    T,  # sequence length; excluded from Triton's integer specialization
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,  # chunk length (power of two, checked by the host wrapper)
    REVERSE: tl.constexpr,  # compute suffix sums instead of prefix sums
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,  # True: (B, H, T) layout; False: (B, T, H) layout
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Translate the global chunk id into (sequence i_n, local chunk i_t) and
        # narrow T to that sequence. i_b is not used on this path, so every
        # batch row maps to the same positions — packed input must have B == 1.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        # Fixed-length batch: sequence i_b occupies steps [i_b * T, (i_b + 1) * T).
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        # Head-first layout: consecutive steps of one head are contiguous (stride 1).
        # NOTE(review): in varlen mode the head stride here is the per-sequence T,
        # which does not match a packed (1, H, total_T) tensor — confirm that
        # HEAD_FIRST together with IS_VARLEN is never used.
        p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # Time-major layout: consecutive steps of one head are H elements apart.
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # boundary_check guards steps past T in a trailing partial chunk.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # Inclusive suffix sum: chunk total - inclusive prefix + current element.
        # NOTE(review): no padding_option is given to the load above, so the
        # total may include unspecified out-of-bounds values in a partial chunk
        # — confirm the backend pads with zeros.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
chunk_indices: torch.LongTensor | None = None,
) -> torch.Tensor:
if head_first:
B, H, T = g.shape
else:
B, T, H = g.shape
if chunk_size != 2 ** (chunk_size.bit_length() - 1):
raise ValueError(
f"chunk_size must be a power of 2, chunk_size is{chunk_size}"
)
BT = chunk_size
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
chunk_local_cumsum_scalar_kernel[grid](
s=g_org,
o=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
)
return g
def origin_chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
scale: float = None,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
chunk_indices: torch.LongTensor | None = None,
**kwargs,
) -> torch.Tensor:
if cu_seqlens is not None:
if g.shape[0] != 1:
raise ValueError(
f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g=g,
chunk_size=chunk_size,
reverse=reverse,
scale=scale,
cu_seqlens=cu_seqlens,
head_first=head_first,
output_dtype=output_dtype,
chunk_indices=chunk_indices,
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}, "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise",
)
@pytest.mark.parametrize(
    ('B', 'T', 'H', 'chunk_size', 'reverse', 'cu_seqlens'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-chunk_size{}-reverse{}-cu_seqlens{}".format(*test))
        for test in [
            (1, 1024, 32, 64, False, None),
            (1, 4096, 32, 64, False, None),
            (1, 1024, 32, 64, True, None),
            (1, 4096, 32, 64, True, None),
            (1, 1024, 32, 64, False, [0, 175, 1024]),
            (1, 4096, 32, 64, False, [0, 175, 1024, 2764, 4096]),
            (1, 1024, 32, 64, True, [0, 175, 1024]),
            (1, 4096, 32, 64, True, [0, 175, 1024, 2764, 4096]),
            (2, 1024, 32, 64, False, None),
            (2, 4096, 32, 64, False, None),
            (2, 1024, 32, 64, True, None),
            (2, 4096, 32, 64, True, None),
            (1, 1024, 32, 16, False, None),
            (1, 1024, 32, 16, False, [0, 10, 66, 140, 229, 351, 401, 574, 684, 819, 874, 922, 1024]),
            # Reverse accumulation with many short, unaligned sequences.
            (1, 1024, 32, 16, True, [0, 10, 66, 140, 229, 351, 401, 574, 684, 819, 874, 922, 1024]),
        ]
    ]
)
def test_cumsum(B, T, H, chunk_size, reverse, cu_seqlens):
    """Check mindspeed's chunk_local_cumsum against the in-file Triton reference on NPU."""
    device = "npu:0"
    dtype = torch.bfloat16
    torch.manual_seed(42)
    torch.npu.manual_seed(42)
    if cu_seqlens is not None:
        # Packed variable-length input: boundaries as an int64 tensor on the device.
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long, device=device)
    g = torch.randn((B, T, H), device=device, dtype=dtype)

    ref_g = origin_chunk_local_cumsum(
        g=g,
        chunk_size=chunk_size,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=False,
    )
    cur_g = chunk_local_cumsum(
        g=g,
        chunk_size=chunk_size,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=False,
    )
    print("g diff:", torch.max(torch.abs(ref_g - cur_g)))
    assert_close('g', ref_g, cur_g, 0.001)