# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.

import os
from typing import List, Optional, Tuple

import pytest
import torch
import triton
import triton.language as tl

from mindspeed.ops.triton.utils import assert_close
from mindspeed.ops.triton.cumsum import chunk_local_cumsum
from mindspeed.ops.triton.utils import prepare_chunk_indices


# HAS_SCALE / IS_VARLEN are derived from the launch arguments, so callers do
# not pass them explicitly.
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
})
# Only the number of warps is tuned; the best config is cached per `key`.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'],
)
# `T` is excluded from value specialization.
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Chunk-local cumulative sum of one scalar per (time step, head).

    Each program handles one block of ``BT`` consecutive time steps for one
    (batch, head) pair: program axis 0 selects the block, program axis 1 is
    the fused ``batch * H + head`` index.

    Args:
        s: Base pointer of the input values.
        o: Base pointer of the output; the result is cast to its element type.
        scale: Optional multiplier applied to the result (used when not None).
        cu_seqlens: Optional sequence boundary offsets; when given, the
            program reads its sequence's start/end from entries ``i_n`` and
            ``i_n + 1``.
        chunk_indices: Flat pairs ``(sequence index, block index within that
            sequence)``, read at ``2 * pid`` and ``2 * pid + 1``. Only read
            when ``cu_seqlens`` is given.
        T: Time length; overwritten with the current sequence length in the
            variable-length case.
        B: Batch size (used as an autotune key only).
        H: Number of heads.
        BT: Block (chunk) length along time.
        REVERSE: If set, compute the within-block suffix sum instead of the
            prefix sum.
        HAS_SCALE: Set by the heuristic above.
        IS_VARLEN: Set by the heuristic above.
        HEAD_FIRST: Selects the addressing scheme: unit stride along time
            with a per-head offset of ``i_h * T``, otherwise stride ``H``
            along time with a per-head offset of ``i_h``.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    # Split the fused index into batch and head.
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat program id to (sequence, block-in-sequence).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        # From here on T is the length of this particular sequence.
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        # Time steps of one head are adjacent (stride 1).
        p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # Heads are interleaved, so consecutive time steps are H elements apart.
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))

    # Accumulate in float32 regardless of the input element type.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # suffix_sum[i] = total - prefix_sum[i] + x[i]
        # NOTE(review): the total covers the whole block, so a partial last
        # block relies on out-of-bounds lanes loading as zero; no
        # padding_option is passed to tl.load -- confirm for the target backend.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
    chunk_indices: torch.LongTensor | None = None,
) -> torch.Tensor:
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    if chunk_size != 2 ** (chunk_size.bit_length() - 1):
        raise ValueError(
            f"chunk_size must be a power of 2, chunk_size is{chunk_size}"
        )
    BT = chunk_size
    if chunk_indices is None and cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g


def origin_chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
    chunk_indices: torch.LongTensor | None = None,
    **kwargs,
) -> torch.Tensor:
    if cu_seqlens is not None:
        if g.shape[0] != 1:
            raise ValueError(
                f"Only batch size 1 is supported when cu_seqlens are provided, current size is{g.shape[0]}"
            )
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
            chunk_indices=chunk_indices,
        )
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise",
        )


@pytest.mark.parametrize(
    ('B', 'T', 'H', 'chunk_size', 'reverse', 'cu_seqlens'),
    [
        pytest.param(*test, id="B{}-T{}-H{}-chunk_size{}-reverse{}-cu_seqlens{}".format(*test))
        for test in [
            (1, 1024, 32, 64, False, None),
            (1, 4096, 32, 64, False, None),
            (1, 1024, 32, 64, True, None),
            (1, 4096, 32, 64, True, None),
            (1, 1024, 32, 64, False, [0, 175, 1024]),
            (1, 4096, 32, 64, False, [0, 175, 1024, 2764, 4096]),
            (1, 1024, 32, 64, True, [0, 175, 1024]),
            (1, 4096, 32, 64, True, [0, 175, 1024, 2764, 4096]),
            (2, 1024, 32, 64, False, None),
            (2, 4096, 32, 64, False, None),
            (2, 1024, 32, 64, True, None),
            (2, 4096, 32, 64, True, None),
            (1, 1024, 32, 16, False, None),  # unified test case
            (1, 1024, 32, 16, False, [0, 10, 66, 140, 229, 351, 401, 574, 684, 819, 874, 922, 1024]),
        ]
    ],
)
def test_cumsum(B, T, H, chunk_size, reverse, cu_seqlens):
    """Compare the two chunk-local cumsum implementations on random input.

    ``origin_chunk_local_cumsum`` (defined in this file) provides the
    reference; ``chunk_local_cumsum`` (imported from
    ``mindspeed.ops.triton.cumsum``) is the implementation under test. Both
    receive the same bfloat16 ``(B, T, H)`` tensor on ``npu:0``.
    """
    device = "npu:0"
    device_dtype = torch.bfloat16
    torch.manual_seed(42)
    torch.npu.manual_seed(42)  # also seed the NPU random number generator

    if cu_seqlens is not None:
        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.long).to(device)

    g = torch.randn((B, T, H), device=device, dtype=device_dtype)

    ref_g = origin_chunk_local_cumsum(
        g=g,
        chunk_size=chunk_size,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=False,
    )

    cur_g = chunk_local_cumsum(
        g=g,
        chunk_size=chunk_size,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=False,
    )

    # Diagnostic only; the pass/fail decision is made by assert_close below.
    print("g diff:", torch.max(torch.abs(ref_g - cur_g)))
    # NOTE(review): 0.001 is presumably the tolerance accepted by
    # assert_close -- confirm its exact meaning in mindspeed.ops.triton.utils.
    assert_close('g', ref_g, cur_g, 0.001)