import os
from typing import Optional
import torch
import triton
import triton.language as tl
from mindspeed.ops.triton.utils import prepare_chunk_indices, make_tensor_descriptor, input_guard, is_amd
# Precision mode handed to the kernels as their DOT_PRECISION argument; the 32x32 merge kernel
# forwards it to tl.dot(input_precision=...). Read once at import time from the
# FLA_TRIL_PRECISION environment variable, defaulting to 'ieee'.
FLA_TRIL_PRECISION = os.environ.get('FLA_TRIL_PRECISION', 'ieee')
# Inverts 16x16 tiles of (I + A) by row-wise forward substitution and writes them to Ai.
#
# Launch grid (see solve_tril): axis 0 indexes groups of TPP consecutive tiles, axis 1 indexes
# batch * head.
#
# Args:
#   A:             input pointer; rows are addressed with stride H * BT, which matches the
#                  [B, T, H, BT] layout documented on solve_tril.
#   Ai:            output pointer; rows are addressed as 16 values with stride H * 16, so the
#                  addressing is only consistent with the base offset below when BT == 16
#                  (solve_tril selects this kernel only for BT == 16).
#   cu_seqlens:    cumulative sequence lengths, or None for fixed-length batches.
#   chunk_indices: per-tile index table; only read when IS_VARLEN is set.
#   T:             number of rows per batch entry (used on the fixed-length path).
#   H:             number of heads.
#   BT:            size of the last dimension of A.
#   TPP:           number of tiles processed by one program.
#   USE_TMA:       use tensor descriptors instead of block pointers for loads and stores.
#   IS_VARLEN:     filled in by the heuristic below: True when cu_seqlens is not None.
#   DOT_PRECISION: accepted so all kernels share one signature; not referenced in this kernel.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    # Split the flattened (batch, head) program index.
    i_b, i_h = pid_bh // H, pid_bh % H
    # First tile handled by this program; tiles base_t .. base_t + TPP - 1 follow.
    base_t = pid_t * TPP
    if IS_VARLEN:
        # NOTE(review): only the entry for the group's first tile is read here, and tile_row
        # below is derived from the global tile index rather than a per-sequence one.
        # merge_16x16_to_32x32_inverse_kernel reads both entries of each chunk_indices pair
        # instead. This looks like it would mis-address tiles outside the first sequence --
        # confirm against prepare_chunk_indices and the callers.
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        # Fixed-length path: every batch entry owns T consecutive rows.
        bos, eos = i_b * T, i_b * T + T
        T_eff = T
    # NOTE(review): o_i is not referenced anywhere below in this kernel.
    o_i = tl.arange(0, 16)
    # The masks are built from a float32 copy of the index vector.
    o_i_fp32 = tl.arange(0, 16).to(tl.float32)
    # Strictly-lower-triangular mask and identity mask.
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
    # Advance both base pointers to the first row of this sequence/batch entry and head.
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * BT
    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 16
        # Column of the tile inside the BT-wide row; always 0 when BT == 16.
        offset = (tile_t * 16) % BT
        if not USE_TMA:
            p_A = tl.make_block_ptr(
                A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0)
            )
            # boundary_check guards rows/columns that fall outside (T_eff, BT).
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
        # b_A_neg keeps the full negated tile (source of the per-row vectors);
        # b_A is the running result, initialised to the strictly-lower part of -A.
        b_A_neg = -b_A_raw
        b_A = b_A_neg * m_A
        # Rows 0 and 1 are left as initialised. The upper bound skips rows past the end of
        # the sequence and makes the loop empty for tiles that start beyond it.
        for i in range(2, min(16, T_eff - tile_row)):
            # Row i of the negated input tile as a length-16 vector.
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
            b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
            # Vector-matrix product: dot_prod[j] = sum_k b_a_val[k] * b_A[k, j].
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
            b_a_update = b_a_val + dot_prod
            # Overwrite row i of the running result, leaving the other rows untouched.
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
        # Add the identity to put ones on the diagonal.
        b_A += m_I
        if not USE_TMA:
            p_Ai = tl.make_block_ptr(
                Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0)
            )
            # Cast to the output element type with round-to-nearest-even.
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
# Inverts one 32x32 block of (I + A) per program.
#
# The two 16x16 diagonal sub-blocks are inverted by row-wise forward substitution, and the
# lower-left sub-block is then obtained as Ai_21 = -(Ai_22 @ A_21 @ Ai_11). The upper-right
# sub-block is never written by this kernel.
#
# Launch grid (see solve_tril): axis 0 indexes 32-row blocks, axis 1 indexes batch * head.
# Autotuned over num_warps in {1, 2, 4, 8} and num_stages in {2, 3, 4, 5}, keyed on
# (H, BT, IS_VARLEN).
#
# Args:
#   A:             input pointer; rows are addressed with stride H * BT, which matches the
#                  [B, T, H, BT] layout documented on solve_tril.
#   Ai:            output pointer, addressed with the same strides as A.
#   cu_seqlens:    cumulative sequence lengths, or None for fixed-length batches.
#   chunk_indices: table of index pairs, read at 2 * i_t and 2 * i_t + 1 when IS_VARLEN is set.
#   T:             number of rows per batch entry; replaced by the sequence length on the
#                  variable-length path.
#   H:             number of heads.
#   BT:            size of the last dimension of A; the sub-block offsets (+16) hard-code a
#                  32-wide block (solve_tril selects this kernel only for BT == 32).
#   TPP:           accepted so all kernels share one signature; not referenced in this kernel.
#   USE_TMA:       use tensor descriptors instead of block pointers for loads and stores.
#   IS_VARLEN:     filled in by the heuristic below: True when cu_seqlens is not None.
#   DOT_PRECISION: forwarded to tl.dot as input_precision.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    # Split the flattened (batch, head) program index.
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # First entry of the pair selects the sequence; the second replaces i_t and is used
        # below as the block index relative to that sequence's first row.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        # Fixed-length path: every batch entry owns T consecutive rows.
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    # Strictly-lower-triangular mask and identity mask for a 16x16 sub-block.
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    # Advance both base pointers to the first row of this sequence/batch entry and head.
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT
    # Load the two diagonal 16x16 sub-blocks (rows/cols 0-15 and 16-31 of the block).
    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
    # Running results start as the negated strictly-lower part of each diagonal sub-block.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    # Forward substitution on the upper-left sub-block, starting at row 2. The upper bound
    # keeps row i inside the sequence, which is what makes the unmasked load below safe.
    for i in range(2, min(16, T - i_t * BT)):
        # Negated row i of A, columns 0-15, read through a raw pointer (no mask).
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        # Add the vector-matrix product sum_k b_a_11[k] * b_Ai_11[k, :].
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        # Overwrite row i of the running result.
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    # Same procedure for the lower-right sub-block (block rows 18-31, columns 16-31).
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        # i - 16 converts the block row into a row index of the 16x16 sub-block.
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    # Add the identity to put ones on both diagonals.
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    # Load the lower-left off-diagonal sub-block of A.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
    # Lower-left block of the inverse: -(Ai_22 @ A_21 @ Ai_11).
    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
    # Store the three computed sub-blocks, cast to the output type with round-to-nearest-even.
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
# Inverts 64x64 tiles of (I + A) by row-wise forward substitution and writes them to Ai.
#
# Launch grid (see solve_tril): axis 0 indexes groups of TPP consecutive tiles, axis 1 indexes
# batch * head.
#
# Args:
#   A:             input pointer; rows are addressed with stride H * BT, which matches the
#                  [B, T, H, BT] layout documented on solve_tril.
#   Ai:            output pointer; rows are addressed as 64 values with stride H * 64, so the
#                  addressing is only consistent with the base offset below when BT == 64
#                  (solve_tril selects this kernel only for BT == 64).
#   cu_seqlens:    cumulative sequence lengths, or None for fixed-length batches.
#   chunk_indices: per-tile index table; only read when IS_VARLEN is set.
#   T:             number of rows per batch entry (used on the fixed-length path).
#   H:             number of heads.
#   BT:            size of the last dimension of A.
#   TPP:           number of tiles processed by one program.
#   USE_TMA:       use tensor descriptors instead of block pointers for loads and stores.
#   IS_VARLEN:     filled in by the heuristic below: True when cu_seqlens is not None.
#   DOT_PRECISION: accepted so all kernels share one signature; not referenced in this kernel.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_64x64_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    # Split the flattened (batch, head) program index.
    i_b, i_h = pid_bh // H, pid_bh % H
    # First tile handled by this program; tiles base_t .. base_t + TPP - 1 follow.
    base_t = pid_t * TPP
    if IS_VARLEN:
        # NOTE(review): only the entry for the group's first tile is read here, and tile_row
        # below is derived from the global tile index rather than a per-sequence one.
        # merge_16x16_to_32x32_inverse_kernel reads both entries of each chunk_indices pair
        # instead. This looks like it would mis-address tiles outside the first sequence --
        # confirm against prepare_chunk_indices and the callers.
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        # Fixed-length path: every batch entry owns T consecutive rows.
        bos, eos = i_b * T, i_b * T + T
        T_eff = T
    # The masks are built from a float32 index vector.
    o_i_fp32 = tl.arange(0, 64).to(tl.float32)
    # Strictly-lower-triangular mask and identity mask.
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]
    # Advance both base pointers to the first row of this sequence/batch entry and head.
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * BT
    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 64
        # Column of the tile inside the BT-wide row; always 0 when BT == 64.
        offset = (tile_t * 64) % BT
        if not USE_TMA:
            p_A = tl.make_block_ptr(
                A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (64, 64), (1, 0)
            )
            # boundary_check guards rows/columns that fall outside (T_eff, BT).
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [64, 64])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 64], [H * 64, 1], [64, 64])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)
        # b_A_neg keeps the full negated tile (source of the per-row vectors);
        # b_A is the running result, initialised to the strictly-lower part of -A.
        b_A_neg = -b_A_raw
        b_A = b_A_neg * m_A
        # Number of rows of this tile that lie inside the sequence; non-positive for tiles
        # that start beyond it, which makes the loop below empty.
        limit = min(64, T_eff - tile_row)
        # Rows 0 and 1 are left as initialised.
        for i in range(2, limit):
            # Row i of the negated input tile as a length-64 vector.
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 64], [1, 1])
            b_a_val = tl.reshape(slice_res, (64,), can_reorder=True)
            # Vector-matrix product: dot_prod[j] = sum_k b_a_val[k] * b_A[k, j].
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0)
            b_a_update = b_a_val + dot_prod
            # Overwrite row i of the running result, leaving the other rows untouched.
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
        # Add the identity to put ones on the diagonal.
        b_A += m_I
        if not USE_TMA:
            p_Ai = tl.make_block_ptr(
                Ai, (T_eff, 64), (H * 64, 1), (tile_row, 0), (64, 64), (1, 0)
            )
            # Cast to the output element type with round-to-nearest-even.
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A
    A should be strictly lower triangular, i.e., A.triu() == 0.
    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
            When given, the kernels take their row offsets from `cu_seqlens` alone and
            do not use the batch index.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.
    Returns:
        (I + A)^-1 with the same shape as A
    Raises:
        ValueError: If `A` is not 4-dimensional or its last dimension is not 16, 32 or 64.
    """
    # Validate the rank first so that the shape unpacking below cannot fail with a
    # cryptic "not enough values to unpack" error.
    if A.dim() != 4:
        raise ValueError(
            f"A should be a 4-D tensor of shape [B, T, H, BT], but got shape {tuple(A.shape)}"
        )
    if A.shape[-1] not in [16, 32, 64]:
        raise ValueError(
            f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}"
        )
    output_dtype = A.dtype if output_dtype is None else output_dtype
    B, T, H, BT = A.shape
    # Zero-initialised so that entries no kernel writes stay zero.
    Ai = torch.zeros_like(A, dtype=output_dtype)
    # Nothing to invert: skip the launch instead of issuing a zero-sized grid.
    if A.numel() == 0:
        return Ai
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    # Number of BT-row tiles across all sequences.
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    # TPP is the number of tiles one program handles; the 16x16 and 64x64 kernels loop
    # over TPP tiles, so their grid shrinks accordingly.
    if BT == 16:
        TPP = 4
        grid0 = triton.cdiv(NT, TPP)
        merge_fn = solve_tril_16x16_kernel
    elif BT == 32:
        # The 32x32 kernel handles exactly one tile per program and does not read TPP.
        TPP = 4
        grid0 = NT
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    else:  # BT == 64, guaranteed by the validation above
        TPP = 22
        grid0 = triton.cdiv(NT, TPP)
        merge_fn = solve_tril_64x64_kernel
    merge_fn[grid0, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        TPP=TPP,
        USE_TMA=False,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai