import os
from typing import List, Optional, Tuple
import pytest
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from mindspeed.ops.triton.solve_tril import solve_tril
from mindspeed.ops.triton.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from mindspeed.ops.triton.utils import is_amd, is_tma_supported, assert_close, device, make_tensor_descriptor, prepare_chunk_indices, device_platform, input_guard
# Precision mode handed to the kernels as DOT_PRECISION (used there as the
# `input_precision` argument of tl.dot). It is read once, at import time,
# from the FLA_TRIL_PRECISION environment variable and defaults to 'ieee'.
FLA_TRIL_PRECISION = os.environ.get('FLA_TRIL_PRECISION', 'ieee')

# 'tf32x3' is only offered when `is_amd` is falsy.
if is_amd:
    ALLOWED_TRIL_PRECISIONS = ['ieee', 'tf32']
else:
    ALLOWED_TRIL_PRECISIONS = ['ieee', 'tf32', 'tf32x3']

# Reject an unsupported value as soon as the module is imported.
assert FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS, (
    f'FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}'
)
# IS_VARLEN is derived from the launch arguments: it is true whenever a
# `cu_seqlens` pointer is supplied.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
# Autotuning sweeps num_warps x num_stages; results are keyed on BT only.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['BT'],
)
@triton.jit(do_not_specialize=['T'])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Purpose: each program takes one 16x16 diagonal block of A, computes the
    # inverse of (I + strictly-lower part of that block) and writes it to Ai.
    #
    # Launch grid: axis 0 selects the 16-row block, axis 1 the flattened
    # (batch, head) pair. The addressing below treats A as rows of width BT
    # with row stride H * BT, and Ai as rows of width 16 with row stride H * 16.
    # DOT_PRECISION is accepted but unused here (this kernel performs no
    # tl.dot); the host wrapper passes the same keyword set to all three kernels.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices is read as pairs: (sequence index, block index inside
        # that sequence). cu_seqlens supplies the sequence's start/end rows.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # From here on T is the length of this sequence only.
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    # m_A selects the strictly lower triangle, m_I the diagonal.
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    # Move both base pointers to row `bos` of head `i_h`.
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * 16
    # Column at which this block's diagonal starts inside a BT-wide row.
    offset = (i_t * 16) % BT
    if not USE_TMA:
        p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
        b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    else:
        # NOTE(review): make_tensor_descriptor is a project helper; presumably
        # it builds TMA tensor descriptors -- confirm. desc_o is created here
        # and reused by the store at the end of the kernel.
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
        b_A = desc.load([i_t * 16, offset]).to(tl.float32)
    # Start from the negated strictly-lower part; rows 0 and 1 of this matrix
    # already equal the corresponding rows of the inverse minus the identity.
    b_A = -tl.where(m_A, b_A, 0)
    # Fill in the remaining rows one at a time. The upper bound is clamped so
    # rows beyond the end of the sequence are never loaded.
    for i in range(2, min(16, T - i_t * 16)):
        # Row i of A is loaded unmasked, so its entries on/above the diagonal
        # must already be zero in A for the result to be correct.
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        # row_i <- (-a_i) + (-a_i) @ b_A
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        # Overwrite row i of b_A with the freshly computed row.
        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
    # Put ones on the diagonal.
    b_A += m_I
    if not USE_TMA:
        p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
        tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
# IS_VARLEN is derived from the launch arguments: it is true whenever a
# `cu_seqlens` pointer is supplied.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
# Autotuning sweeps num_warps x num_stages; results are keyed on H, BT and
# IS_VARLEN.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Purpose: each program handles one chunk of BT rows, viewed as a 2x2 grid
    # of 16x16 blocks. It inverts the two diagonal blocks of (I + A), derives
    # the lower-left block of the inverse from them, and stores the three
    # blocks to Ai. The upper-right block of Ai is never written here.
    #
    # Launch grid: axis 0 selects the chunk, axis 1 the flattened (batch, head)
    # pair. Both A and Ai are addressed as rows of width BT with row stride
    # H * BT. The host wrapper dispatches to this kernel when BT == 32.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices is read as pairs: (sequence index, chunk index inside
        # that sequence). cu_seqlens supplies the sequence's start/end rows.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # From here on T is the length of this sequence only.
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    # m_A selects the strictly lower triangle, m_I the diagonal.
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    # Move both base pointers to row `bos` of head `i_h`.
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT
    # Load the two diagonal 16x16 blocks (block row/col 1 and 2).
    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        # NOTE(review): make_tensor_descriptor is a project helper; presumably
        # it builds TMA tensor descriptors -- confirm. desc / desc_o are
        # reused by the later loads and stores in the USE_TMA branches.
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
    # Start each diagonal block from its negated strictly-lower part.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    # Row-by-row update of the first diagonal block (chunk rows 2..15). The
    # bound is clamped so rows past the end of the sequence are never loaded.
    # Rows of A are loaded unmasked, so entries on/above the diagonal must
    # already be zero in A.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    # Same update for the second diagonal block (chunk rows 18..31, columns
    # 16..31); `i - 16` is the row index local to that block.
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    # Put ones on both diagonals.
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    # Load the off-diagonal (lower-left) block of A.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
    # Lower-left block of the inverse: -(Ai_22 @ A_21) @ Ai_11.
    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
    # Store the three computed blocks, cast to Ai's element type.
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
# IS_VARLEN is derived from the launch arguments: it is true whenever a
# `cu_seqlens` pointer is supplied.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
# Autotuning sweeps num_warps x num_stages; results are keyed on H, BT and
# IS_VARLEN.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=['T'])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    # Purpose: each program handles one chunk of BT rows, viewed as a 4x4 grid
    # of 16x16 blocks. It inverts the four diagonal blocks of (I + A), derives
    # the six blocks below the diagonal of the inverse from them, and stores
    # all ten blocks to Ai. Blocks above the diagonal are never written here.
    #
    # Launch grid: axis 0 selects the chunk, axis 1 the flattened (batch, head)
    # pair. Both A and Ai are addressed as rows of width BT with row stride
    # H * BT. The host wrapper dispatches to this kernel when BT == 64.
    # Naming: b_A_rc is block (r, c) of A, b_Ai_rc is block (r, c) of the
    # inverse, with r, c in 1..4.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices is read as pairs: (sequence index, chunk index inside
        # that sequence). cu_seqlens supplies the sequence's start/end rows.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # From here on T is the length of this sequence only.
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    # m_A selects the strictly lower triangle, m_I the diagonal.
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    # Move both base pointers to row `bos` of head `i_h`.
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT
    # Load the four diagonal 16x16 blocks.
    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0))
        p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
    else:
        # NOTE(review): make_tensor_descriptor is a project helper; presumably
        # it builds TMA tensor descriptors -- confirm. desc / desc_o are
        # reused by the later loads and stores in the USE_TMA branches.
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
    # Start each diagonal block from its negated strictly-lower part.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
    # Row-by-row update of each diagonal block. Every loop skips the first two
    # rows of its block and clamps its bound so rows past the end of the
    # sequence are never loaded. Rows of A are loaded unmasked, so entries
    # on/above the diagonal must already be zero in A.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    # Block 2: columns 16..31; `i - 16` is the block-local row index.
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    # Block 3: columns 32..47; `i - 32` is the block-local row index.
    for i in range(32 + 2, min(48, T - i_t * BT)):
        b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
    # Block 4: columns 48..63; `i - 48` is the block-local row index.
    for i in range(48 + 2, min(64, T - i_t * BT)):
        b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
    # Put ones on all four diagonals.
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    b_Ai_33 += m_I
    b_Ai_44 += m_I
    # Load the six blocks of A that lie below the block diagonal.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0))
        p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0))
        p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0))
        p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0))
        p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
    # First sub-diagonal of the inverse: Ai_rc = -(Ai_rr @ A_rc) @ Ai_cc.
    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)
    b_Ai_32 = -tl.dot(tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), b_Ai_22, input_precision=DOT_PRECISION)
    b_Ai_43 = -tl.dot(tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), b_Ai_33, input_precision=DOT_PRECISION)
    # Second sub-diagonal; each result depends on blocks computed just above.
    # Ai_31 = -Ai_33 @ (A_31 @ Ai_11 + A_32 @ Ai_21)
    b_Ai_31 = -tl.dot(
        b_Ai_33,
        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) +
        tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    # Ai_42 = -Ai_44 @ (A_42 @ Ai_22 + A_43 @ Ai_32)
    b_Ai_42 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) +
        tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    # Bottom-left corner:
    # Ai_41 = -Ai_44 @ (A_41 @ Ai_11 + A_42 @ Ai_21 + A_43 @ Ai_31)
    b_Ai_41 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) +
        tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) +
        tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    # Store the ten computed blocks, cast to Ai's element type.
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0))
        p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0))
        p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0))
        p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0))
        p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0))
        p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_33, b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_44, b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_31, b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_32, b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_41, b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_42, b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_43, b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@input_guard
def solve_tril_ori(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Invert ``I + A`` chunk by chunk with the Triton kernels in this module.

    ``A`` is expected to be strictly lower triangular inside every chunk
    (``A.triu() == 0``); the kernels rely on this and do not check it.

    Args:
        A (torch.Tensor):
            Tensor of shape ``[B, T, H, BT]``. ``BT`` selects the kernel and
            must be 16, 32 or 64.
        cu_seqlens (torch.Tensor):
            Cumulative sequence lengths for variable-length input. When given,
            the chunk list comes from ``prepare_chunk_indices(cu_seqlens, BT)``
            and one program is launched per listed chunk. Default: ``None``.
        output_dtype (torch.dtype):
            Dtype of the returned tensor. ``None`` means "same as ``A``".
            Default: ``torch.float``.

    Returns:
        A tensor shaped like ``A`` that holds ``(I + A)^-1`` for each chunk.
        It is allocated zero-filled, so entries the kernels do not write
        remain zero.

    Raises:
        ValueError: if the last dimension of ``A`` is not 16, 32 or 64.
    """
    # One kernel per supported chunk width.
    kernel_by_chunk = {
        16: solve_tril_16x16_kernel,
        32: merge_16x16_to_32x32_inverse_kernel,
        64: merge_16x16_to_64x64_inverse_kernel,
    }
    if A.shape[-1] not in kernel_by_chunk:
        raise ValueError(
            f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}"
        )

    if output_dtype is None:
        output_dtype = A.dtype

    B, T, H, BT = A.shape

    # Grid axis 0: one program per chunk. With cu_seqlens the chunks are
    # enumerated per sequence; otherwise T is simply split into BT-row pieces.
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = triton.cdiv(T, BT)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        num_chunks = len(chunk_indices)

    Ai = torch.zeros_like(A, dtype=output_dtype)

    # Grid axis 1: one program per (batch, head) pair.
    launch_grid = (num_chunks, B * H)
    kernel_by_chunk[BT][launch_grid](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai
@pytest.mark.parametrize(
    ('B', 'T', 'H', 'chunk_size'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-chunk_size{}".format(*test))
        for test in [
            (1, 63, 1, 16),
            (2, 500, 4, 32),
            (2, 1000, 5, 64),
            (3, 1024, 6, 64),
            (4, 2048, 8, 64),
            (1, 1024, 32, 16),
        ]
    ]
)
@pytest.mark.skipif(
    device_platform == 'intel',
    reason='Intel Pytorch Failure'
)
def test_solve_tril(B, T, H, chunk_size):
    """Check `solve_tril` on fixed-length input.

    The result is compared against a `torch.inverse` reference and against
    `solve_tril_ori`, the Triton implementation defined in this file.

    Args:
        B: batch size.
        T: sequence length; it need not be a multiple of `chunk_size`.
        H: number of heads.
        chunk_size: chunk width (the BT dimension of the kernel input).
    """
    k = F.normalize(torch.randn((B, H, T, 128), dtype=torch.float32, device=device), dim=-1)

    # Zero-pad T up to a whole number of chunks so it can be split into
    # [num_chunks, chunk_size] blocks.
    padding_size = (chunk_size - T % chunk_size) % chunk_size
    k_padded = F.pad(k, (0, 0, 0, padding_size, 0, 0, 0, 0))
    k_padded = k_padded.reshape(B, H, -1, chunk_size, 128)

    # Strictly lower-triangular per-chunk Gram matrices.
    A = (k_padded @ k_padded.transpose(-1, -2)).tril(-1)

    # Reference: dense inverse of (I + A) per chunk, trimmed back to T rows.
    ref = torch.inverse(A + torch.eye(A.shape[-1], device=A.device)[None, None, None, ...])
    ref = ref.reshape(B, H, -1, chunk_size)[:, :, :T, :]

    # Both implementations take [B, T, H, chunk_size]; build that view once
    # and feed the same input to each.
    A_in = A.reshape(B, H, -1, chunk_size)[:, :, :T, :].transpose(1, 2)
    tri_ori = solve_tril_ori(A_in).transpose(1, 2)
    tri = solve_tril(A_in).transpose(1, 2)

    # Label the diagnostics with this test's name (not the varlen test's) so
    # a failure report points at the right case.
    print("solve_tril ref and tri diff:", torch.max(torch.abs(ref - tri)))
    assert_close('solve_tril', ref, tri, 0.0001)
    print("solve_tril ori and tri diff:", torch.max(torch.abs(tri_ori - tri)))
    assert_close('solve_tril', tri_ori, tri, 0.0001)
@pytest.mark.parametrize(
    ('H', 'D', 'chunk_size', 'cu_seqlens'),
    [
        pytest.param(*test, id="H{}-D{}-chunk_size{}-cu_seqlens{}".format(*test))
        for test in [
            (4, 64, 16, [0, 15]),
            (4, 64, 32, [0, 256, 500, 1000]),
            (4, 100, 64, [0, 15, 100, 300, 1200, 2000]),
            (4, 64, 16, [0, 1, 100, 300, 1200, 2048]),
            (4, 128, 32, [0, 200, 512, 1200, 2048]),
            (32, 128, 16, [0, 10, 66, 140, 229, 351, 401, 574, 684, 819, 874, 922, 1024]),
        ]
    ],
)
@pytest.mark.skipif(
    os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1',
    reason='Test skipped:Variable-length chunk test is disabled by environment variable SKIP_TEST_CHUNK_VARLEN=1',
)
@pytest.mark.skipif(
    device_platform == 'intel',
    reason='Intel Pytorch Failure',
)
def test_solve_tril_varlen(
    H: int,
    D: int,
    chunk_size: int,
    cu_seqlens: list[int],
):
    """Check `solve_tril` on packed variable-length sequences.

    The result is compared against a per-chunk `torch.inverse` reference and
    against `solve_tril_ori`, the Triton implementation defined in this file.

    Args:
        H: number of heads.
        D: key feature dimension.
        chunk_size: chunk width passed to `chunk_scaled_dot_kkt_fwd`.
        cu_seqlens: cumulative sequence boundaries, starting at 0; the last
            entry is the total packed length.
    """
    # Keep the boundaries as Python ints for the reference loop below. Driving
    # range()/min()/slicing with 0-dim device tensors would force a host sync
    # per use and leave the slice sizes tensor-typed.
    seq_bounds = list(cu_seqlens)
    T = seq_bounds[-1]
    cu_seqlens = torch.tensor(seq_bounds, dtype=torch.int32, device=device)

    k = F.normalize(torch.randn((1, T, H, D), dtype=torch.bfloat16, device=device), dim=-1)
    g = torch.randn((1, T, H), device=device, dtype=torch.bfloat16)
    beta = torch.randn((1, T, H), dtype=torch.bfloat16, device=device).sigmoid()
    A = chunk_scaled_dot_kkt_fwd(k=k, g=g, beta=beta, cu_seqlens=cu_seqlens, chunk_size=chunk_size)

    # Reference: invert (I + A) densely, one chunk at a time. The last chunk
    # of a sequence may be shorter than chunk_size.
    ref = torch.zeros_like(A)
    for bos, eos in zip(seq_bounds[:-1], seq_bounds[1:]):
        for j in range(bos, eos, chunk_size):
            actual_size = min(chunk_size, eos - j)
            ref[:, j:j + actual_size, :, :actual_size] = torch.inverse(
                A[:, j:j + actual_size, :, :actual_size].transpose(1, 2) +
                torch.eye(actual_size, device=A.device, dtype=A.dtype)[None, None, ...],
            ).transpose(1, 2)

    tri_ori = solve_tril_ori(A, cu_seqlens=cu_seqlens)
    tri = solve_tril(A, cu_seqlens=cu_seqlens)
    print("solve_tril_varlen ref and tri diff:", torch.max(torch.abs(ref - tri)))
    assert_close('solve_tril_varlen', ref, tri, 0.0001)
    print("solve_tril_varlen ori and tri diff:", torch.max(torch.abs(tri_ori - tri)))
    assert_close('solve_tril_varlen', tri_ori, tri, 0.0001)