# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.

import os
from typing import Optional

import torch
import triton
import triton.language as tl

from mindspeed.ops.triton.utils import prepare_chunk_indices, make_tensor_descriptor, input_guard, is_amd


# Value forwarded as DOT_PRECISION to the tril-solve kernels (used as the
# `input_precision` of tl.dot). Override via the FLA_TRIL_PRECISION env var.
FLA_TRIL_PRECISION = os.getenv('FLA_TRIL_PRECISION', 'ieee')


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    """
    Invert (I + A) for each 16x16 diagonal tile by row-wise forward substitution.

    Grid: axis 0 walks groups of TPP consecutive tiles (tiles
    pid_t * TPP ... pid_t * TPP + TPP - 1, unrolled with tl.static_range);
    axis 1 enumerates (batch, head) pairs as pid_bh = i_b * H + i_h.

    For a strictly lower-triangular tile L, writing (I + L)^-1 = I + X with X
    strictly lower gives X_i = -L_i + sum_k (-L_ik) X_k. The inner loop
    evaluates this row by row, then the identity is added back.

    Args:
        A: input; rows are addressed with stride H * BT starting at
            element (bos * H + i_h) * BT.
        Ai: output; rows are addressed with stride H * 16 from the same start.
        cu_seqlens, chunk_indices: only read when IS_VARLEN.
        T: per-batch sequence length for the dense layout (ignored when IS_VARLEN).
        TPP: number of 16-row tiles processed by one program.
        USE_TMA: use tensor descriptors instead of block pointers.
        DOT_PRECISION: accepted for signature parity; not used in this kernel.

    NOTE(review): in the IS_VARLEN branch, bos/eos come only from the sequence
    of the program's first tile (chunk_indices[base_t * 2]). The local chunk
    index (second entry of each chunk_indices pair) is never read, and tiles
    are placed at tile_t * 16 where tile_t is a global chunk index. This looks
    correct only when every chunk belongs to a single sequence that starts at
    chunk 0. Confirm before relying on the varlen path.
    NOTE(review): Ai's row stride H * 16 matches A's layout only when BT == 16,
    in which case `offset` below is always 0.
    """
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = pid_bh // H, pid_bh % H

    # Global index of the first 16-row tile handled by this program.
    base_t = pid_t * TPP

    if IS_VARLEN:
        # Sequence bounds taken from the first tile only (see NOTE(review) above).
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_eff = T

    o_i = tl.arange(0, 16)  # not used below
    o_i_fp32 = tl.arange(0, 16).to(tl.float32)
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]  # strictly-lower mask
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]  # identity mask

    # Move both base pointers to row `bos`, head `i_h`.
    A  = A  + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * BT

    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 16

        # Column offset of this tile within the BT-wide row (0 when BT == 16).
        offset = (tile_t * 16) % BT

        # Rows past T_eff (e.g. trailing tiles of the last program) load as zeros.
        if not USE_TMA:
            p_A = tl.make_block_ptr(
                A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (16, 16), (1, 0)
            )
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [16, 16])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 16], [H * 16, 1], [16, 16])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)

        b_A_neg = -b_A_raw
        # Start from -A restricted to its strictly-lower part. Rows 0 and 1 are
        # already final (row 0 is zero, row 1 has a single entry), so
        # substitution starts at row 2.
        b_A = b_A_neg * m_A
        # Only rows that exist in the sequence are solved; the range is empty
        # for tiles at or beyond T_eff.
        for i in range(2, min(16, T_eff - tile_row)):
            # Row i of -A, taken from the unmasked tile. This relies on A being
            # strictly lower triangular so that entries at columns >= i are zero.
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 16], [1, 1])
            b_a_val = tl.reshape(slice_res, (16,), can_reorder=True)
            # sum_k (-A[i, k]) * X[k, :] over the rows solved so far.
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0) 
            b_a_update = b_a_val + dot_prod
            # Overwrite row i with the solved row.
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)
        # (I + A)^-1 = I + X.
        b_A += m_I

        # Store rounded to nearest-even; out-of-range rows are dropped by the
        # boundary check.
        if not USE_TMA:
            p_Ai = tl.make_block_ptr(
                Ai, (T_eff, 16), (H * 16, 1), (tile_row, 0), (16, 16), (1, 0)
            )
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))


@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
})
# Tries every num_warps x num_stages combination below and re-tunes for each
# distinct (H, BT, IS_VARLEN).
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=['H', 'BT', 'IS_VARLEN'],
)
@triton.jit(do_not_specialize=["T", "TPP"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    """
    Invert (I + A) for one 32-row chunk by solving its two 16x16 diagonal
    blocks and combining them.

    With I + A = [[I + A_11, 0], [A_21, I + A_22]], the inverse is
    [[Ai_11, 0], [-Ai_22 @ A_21 @ Ai_11, Ai_22]]. Each diagonal inverse is
    computed by row-wise forward substitution on the strictly-lower part.

    Grid: axis 0 is the chunk index (for varlen, an index into chunk_indices);
    axis 1 enumerates (batch, head) pairs as i_bh = i_b * H + i_h.

    Args:
        A: input; rows addressed with stride H * BT starting at
            element (bos * H + i_h) * BT.
        Ai: output with the same layout as A. Only the 11, 21 and 22 blocks
            are written, so the upper-right 16x16 block must already hold the
            desired value (e.g. Ai pre-filled with zeros).
        cu_seqlens, chunk_indices: only read when IS_VARLEN; chunk_indices
            holds (sequence index, chunk index within sequence) pairs.
        T: per-batch sequence length; rebound to the sequence length when IS_VARLEN.
        TPP: accepted for signature parity; not used in this kernel.
        USE_TMA: use tensor descriptors instead of block pointers.
        DOT_PRECISION: `input_precision` for the two tl.dot calls.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the global chunk id to (sequence, local chunk) and use that
        # sequence's bounds.
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]  # strictly-lower mask
    m_I = o_i[:, None] == o_i[None, :]  # identity mask
    # Move both base pointers to row `bos`, head `i_h`.
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    # Load the two diagonal 16x16 blocks (rows beyond T read as zero).
    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)

    # Start each block as -A restricted to its strictly-lower part.
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)

    # Forward substitution for the first block (rows 2..15). Rows 0 and 1 are
    # already final. The unmasked row loads stay in bounds because the range is
    # capped at the number of rows left in the sequence.
    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    # Same for the second block (chunk rows 18..31, columns 16..31).
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)

    # (I + A_kk)^-1 = I + X_kk.
    b_Ai_11 += m_I
    b_Ai_22 += m_I

    # Off-diagonal block A_21.
    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)

    # Block inverse: Ai_21 = -Ai_22 @ A_21 @ Ai_11.
    b_Ai_21 = -tl.dot(tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), b_Ai_11, input_precision=DOT_PRECISION)

    # Store the three non-zero blocks, rounding to nearest-even.
    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
        tl.store(p_Ai_11, b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_22, b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        tl.store(p_Ai_21, b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    else:
        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "TPP"])
def solve_tril_64x64_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    TPP: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr
):
    """
    Invert (I + A) for each 64x64 diagonal tile by row-wise forward substitution.

    Grid: axis 0 walks groups of TPP consecutive tiles (tiles
    pid_t * TPP ... pid_t * TPP + TPP - 1, unrolled with tl.static_range);
    axis 1 enumerates (batch, head) pairs as pid_bh = i_b * H + i_h.

    For a strictly lower-triangular tile L, writing (I + L)^-1 = I + X with X
    strictly lower gives X_i = -L_i + sum_k (-L_ik) X_k. The inner loop
    evaluates this row by row on the loaded tile, then the identity is added back.

    Args:
        A: input; rows are addressed with stride H * BT starting at
            element (bos * H + i_h) * BT.
        Ai: output; rows are addressed with stride H * 64 from the same start.
        cu_seqlens, chunk_indices: only read when IS_VARLEN.
        T: per-batch sequence length for the dense layout (ignored when IS_VARLEN).
        TPP: number of 64-row tiles processed by one program.
        USE_TMA: use tensor descriptors instead of block pointers.
        DOT_PRECISION: accepted for signature parity; not used in this kernel.

    NOTE(review): in the IS_VARLEN branch, bos/eos come only from the sequence
    of the program's first tile (chunk_indices[base_t * 2]). The local chunk
    index (second entry of each chunk_indices pair) is never read, and tiles
    are placed at tile_t * 64 where tile_t is a global chunk index. This looks
    correct only when every chunk belongs to a single sequence that starts at
    chunk 0. Confirm before relying on the varlen path.
    NOTE(review): Ai's row stride H * 64 matches A's layout only when BT == 64,
    in which case `offset` below is always 0.
    """
    pid_t, pid_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = pid_bh // H, pid_bh % H

    # Global index of the first 64-row tile handled by this program.
    base_t = pid_t * TPP

    if IS_VARLEN:
        # Sequence bounds taken from the first tile only (see NOTE(review) above).
        i_n = tl.load(chunk_indices + base_t * 2).to(tl.int32)
        bos = tl.load(cu_seqlens + i_n).to(tl.int32)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T_eff = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
        T_eff = T

    o_i_fp32 = tl.arange(0, 64).to(tl.float32)
    m_A = o_i_fp32[:, None] > o_i_fp32[None, :]  # strictly-lower mask
    m_I = o_i_fp32[:, None] == o_i_fp32[None, :]  # identity mask

    # Move both base pointers to row `bos`, head `i_h`.
    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * BT

    for tpp in tl.static_range(0, TPP):
        tile_t = base_t + tpp
        tile_row = tile_t * 64

        # Column offset of this tile within the BT-wide row (0 when BT == 64).
        offset = (tile_t * 64) % BT

        # Rows past T_eff (e.g. trailing tiles of the last program) load as zeros.
        if not USE_TMA:
            p_A = tl.make_block_ptr(
                A, (T_eff, BT), (H * BT, 1), (tile_row, offset), (64, 64), (1, 0)
            )
            b_A_raw = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
        else:
            desc = make_tensor_descriptor(A, [T_eff, BT], [H * BT, 1], [64, 64])
            desc_o = make_tensor_descriptor(Ai, [T_eff, 64], [H * 64, 1], [64, 64])
            b_A_raw = desc.load([tile_row, offset]).to(tl.float32)

        b_A_neg = -b_A_raw
        # Start from -A restricted to its strictly-lower part. Rows 0 and 1 are
        # already final, so substitution starts at row 2.
        b_A = b_A_neg * m_A

        # Fully On-Chip: every row is taken from the already-loaded tile rather
        # than re-read from A.
        # Only rows that exist in the sequence are solved; the range is empty
        # for tiles at or beyond T_eff.
        limit = min(64, T_eff - tile_row)
        for i in range(2, limit):     
            # Row i of -A, taken from the unmasked tile. This relies on A being
            # strictly lower triangular so that entries at columns >= i are zero.
            slice_res = tl.extract_slice(b_A_neg, [i, 0], [1, 64], [1, 1])
            b_a_val = tl.reshape(slice_res, (64,), can_reorder=True)
            # sum_k (-A[i, k]) * X[k, :] over the rows solved so far.
            dot_prod = tl.sum(b_a_val[:, None] * b_A, 0) 
            b_a_update = b_a_val + dot_prod
            # Overwrite row i with the solved row.
            b_A = tl.where((o_i_fp32 == i)[:, None], b_a_update, b_A)

        # (I + A)^-1 = I + X.
        b_A += m_I

        # Store rounded to nearest-even; out-of-range rows are dropped by the
        # boundary check.
        if not USE_TMA:
            p_Ai = tl.make_block_ptr(
                Ai, (T_eff, 64), (H * 64, 1), (tile_row, 0), (64, 64), (1, 0)
            )
            tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
        else:
            desc_o.store([tile_row, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))


@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A, independently for every chunk of BT rows.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
            When given, B must be 1, T is the total packed length, and chunks
            start at each sequence's first token.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A

    Raises:
        ValueError: if BT is not one of 16, 32, 64, or if `cu_seqlens` is
            given while the batch size B is not 1.
    """
    if A.shape[-1] not in [16, 32, 64]:
        raise ValueError(
            f"A shape BT should in [16,32, 64], but current is {A.shape[-1]}"
        )
    output_dtype = A.dtype if output_dtype is None else output_dtype

    B, T, H, BT = A.shape
    if cu_seqlens is not None and B != 1:
        # The kernels ignore the batch index in varlen mode, so B > 1 would make
        # every batch program read and write batch 0 only.
        raise ValueError(
            f"The batch size is expected to be 1 rather than {B} when using `cu_seqlens`."
        )

    # Zero-initialised: the kernels only write the lower-triangular blocks.
    Ai = torch.zeros_like(A, dtype=output_dtype)

    if BT == 32:
        # This kernel maps each program to (sequence, local chunk) through
        # chunk_indices, so one launch covers both the dense and varlen layouts.
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
        NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
        merge_16x16_to_32x32_inverse_kernel[NT, B * H](
            A=A,
            Ai=Ai,
            cu_seqlens=cu_seqlens,
            chunk_indices=chunk_indices,
            T=T,
            H=H,
            BT=BT,
            TPP=4,
            USE_TMA=False,
            DOT_PRECISION=FLA_TRIL_PRECISION,
        )
        return Ai

    # BT == 16 or BT == 64: each program processes TPP consecutive chunks.
    if BT == 16:
        TPP, kernel = 4, solve_tril_16x16_kernel
    else:
        TPP, kernel = 22, solve_tril_64x64_kernel

    if cu_seqlens is None:
        segments = [(A, Ai)]
    else:
        # The varlen branches of the 16x16/64x64 kernels take bos/eos only from
        # each program's first chunk and place rows at a global chunk index, so
        # with several sequences they misplace or drop chunks. Run the dense
        # path on each sequence instead; with B == 1 these slices are
        # contiguous views, so the kernels write straight into Ai.
        bounds = cu_seqlens.tolist()
        segments = [
            (A[:, bos:eos], Ai[:, bos:eos])
            for bos, eos in zip(bounds[:-1], bounds[1:])
            if eos > bos
        ]

    for A_seg, Ai_seg in segments:
        T_seg = A_seg.shape[1]
        NT = triton.cdiv(T_seg, BT)
        kernel[triton.cdiv(NT, TPP), B * H](
            A=A_seg,
            Ai=Ai_seg,
            cu_seqlens=None,
            chunk_indices=None,
            T=T_seg,
            H=H,
            BT=BT,
            TPP=TPP,
            USE_TMA=False,
            DOT_PRECISION=FLA_TRIL_PRECISION,
        )
    return Ai