/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include <cstdint>
#include "kernel_operator.h"
#include "simt_api/asc_simt.h"
#include "sgeqrf_batched_tiling_data.h"

using namespace AscendC;

// Upper bound on threads per SIMT launch: used both as the LAUNCH_BOUND of GeqrfSimtFp32
// and as the clamp applied to the thread count computed in sgeqrf_batched.
constexpr uint32_t GEQRF_SIMT_THREADS = 2048;
// Number of float elements in the g_smem scratch array below.
constexpr uint32_t GEQRF_SMEM_SIZE = 2048;

// Scratch buffer handed to the SIMT function as its `smem` argument. The helpers in this
// file use it for per-warp partial sums, broadcast scalars (sigma, alpha, tau) and
// per-column update coefficients.
__ubuf__ float g_smem[GEQRF_SMEM_SIZE];

/**
 * Sums `val` across a group of 32 lanes using a shuffle-down ladder.
 *
 * Offsets 16, 8, 4, 2, 1 are applied in that order with a shuffle width of 32.
 * NOTE(review): presumably only the lowest lane of the group ends up holding the
 * complete sum (usual shuffle-down behaviour) - confirm against asc_shfl_down semantics.
 *
 * @param val  This lane's contribution.
 * @return     This lane's accumulated value after the ladder.
 */
__simt_callee__ __aicore__ inline float WarpReduceSum(float val)
{
    int step = 16;
    while (step > 0) {
        val += asc_shfl_down(val, step, 32);
        step >>= 1;
    }
    return val;
}

/**
 * Sums `localVal` over every thread of the block.
 *
 * Stage 1: each group of 32 threads reduces its values with WarpReduceSum and lane 0
 *          stages the partial sum in smem[warpId].
 * Stage 2: the first 32 threads fold all staged partials and reduce them once more.
 *
 * The launch bound of this file permits more than 32 warps per block (2048 threads),
 * so stage 2 walks the staged partials with a stride of 32 instead of reading a single
 * slot per lane; otherwise partials of warps 32 and above would be dropped.
 *
 * Must be called by all threads of the block (contains block-wide barriers).
 *
 * @param localVal  This thread's contribution.
 * @param smem      Scratch buffer; slots [0, numWarps) are overwritten.
 * @return          The reduction result as produced by WarpReduceSum in the first warp;
 *                  callers in this file only consume it on thread 0. Threads outside
 *                  the first 32 receive 0.0f.
 */
__simt_callee__ __aicore__ inline float BlockReduceSum(float localVal, __ubuf__ float* smem)
{
    constexpr uint32_t kWarpSize = 32;

    uint32_t tid = threadIdx.x;
    uint32_t warpId = tid / kWarpSize;
    uint32_t laneId = tid % kWarpSize;
    uint32_t numWarps = (blockDim.x + kWarpSize - 1) / kWarpSize;

    float warpSum = WarpReduceSum(localVal);
    if (laneId == 0) {
        smem[warpId] = warpSum;
    }
    // Staged partials must be visible before the first warp reads them.
    asc_syncthreads();

    float result = 0.0f;
    if (tid < kWarpSize) {
        float val = (tid < numWarps) ? smem[tid] : 0.0f;
        // Fold in partials of warps beyond the first 32 (no-op when numWarps <= 32).
        for (uint32_t w = tid + kWarpSize; w < numWarps; w += kWarpSize) {
            val += smem[w];
        }
        result = WarpReduceSum(val);
    }
    // Keep the staging slots intact until the first warp has finished reading them.
    asc_syncthreads();
    return result;
}

/**
 * For every column c in (i, n), computes tau * (A[i, c] + sum_{row > i} A[row, i] * A[row, c])
 * and stores it in smem[numWarps + (c - i - 1)], where numWarps = ceil(blockDim.x / 32).
 *
 * A is indexed as A[row + col * lda]. Must be called by all threads of the block.
 *
 * NOTE(review): the highest slot written is numWarps + n - i - 2; this assumes the
 * scratch buffer is large enough for the given n - confirm against the host-side limits.
 *
 * @param m     Exclusive upper bound of the row loop.
 * @param n     Exclusive upper bound of the column loop.
 * @param lda   Element stride between consecutive columns of A.
 * @param i     Index of the reflector column; rows/columns after i are processed.
 * @param tau   Scalar multiplied into every dot product.
 * @param A     Matrix data in global memory.
 * @param smem  Scratch buffer (reduction staging + coefficient slots).
 */
__simt_callee__ __aicore__ inline void ComputeCoefficients(
    uint32_t m, uint32_t n, uint64_t lda, uint32_t i, float tau, __gm__ float* A, __ubuf__ float* smem)
{
    const uint32_t self = threadIdx.x;
    const uint32_t stride = blockDim.x;
    // Coefficient slots start right after the per-warp staging area used by BlockReduceSum.
    const uint32_t coeffBase = (stride + 31) / 32;
    const uint64_t reflOff = i * lda;

    uint32_t col = i + 1;
    while (col < n) {
        const uint64_t colOff = col * lda;

        // Thread 0 seeds its partial sum with the row-i entry of this column.
        float partial = 0.0f;
        if (self == 0) {
            partial = A[i + colOff];
        }
        for (uint32_t r = i + 1 + self; r < m; r += stride) {
            partial += A[r + reflOff] * A[r + colOff];
        }

        const float total = BlockReduceSum(partial, smem);
        if (self == 0) {
            smem[coeffBase + (col - i - 1)] = tau * total;
        }
        asc_syncthreads();
        ++col;
    }
}

/**
 * Applies the per-column coefficients found in smem[numWarps + (c - i - 1)] to columns
 * c in (i, n) of A:
 *   A[row, c] -= coeff * A[row, i]   for rows in (i, m)
 *   A[i, c]    = A[i, c] - coeff     (written by thread 0)
 *
 * A is indexed as A[row + col * lda]. smem[0] is overwritten with the original A[i, c]
 * of each column. Must be called by all threads of the block.
 *
 * @param m     Exclusive upper bound of the row loop.
 * @param n     Exclusive upper bound of the column loop.
 * @param lda   Element stride between consecutive columns of A.
 * @param i     Index of the reflector column.
 * @param A     Matrix data in global memory.
 * @param smem  Scratch buffer holding the coefficients.
 */
__simt_callee__ __aicore__ inline void ApplyUpdate(
    uint32_t m, uint32_t n, uint64_t lda, uint32_t i, __gm__ float* A, __ubuf__ float* smem)
{
    const uint32_t self = threadIdx.x;
    const uint32_t stride = blockDim.x;
    const uint32_t coeffBase = (stride + 31) / 32;
    const uint64_t reflOff = i * lda;

    uint32_t col = i + 1;
    while (col < n) {
        const uint64_t colOff = col * lda;
        const float scale = smem[coeffBase + (col - i - 1)];

        // Thread 0 publishes the untouched row-i entry of this column through smem[0].
        if (self == 0) {
            smem[0] = A[i + colOff];
        }
        asc_syncthreads();
        const float topOrig = smem[0];

        for (uint32_t r = i + 1 + self; r < m; r += stride) {
            A[r + colOff] -= scale * A[r + reflOff];
        }
        if (self == 0) {
            A[i + colOff] = topOrig - scale;
        }
        // All reads of smem[0] must finish before the next column overwrites it.
        asc_syncthreads();
        ++col;
    }
}

/**
 * Applies the reflector stored in column i (implicit leading 1 at row i, remaining
 * entries in rows (i, m)) to every trailing column c in (i, n):
 *   coeff      = tau * (A[i, c] + sum_{row > i} A[row, i] * A[row, c])
 *   A[i, c]   -= coeff
 *   A[row, c] -= coeff * A[row, i]
 *
 * The coefficient of a column depends only on column i and on that column itself, so
 * each column is reduced and updated in one pass. Only a single scratch slot (directly
 * after the per-warp staging area of BlockReduceSum) is needed for the coefficient,
 * which keeps scratch usage independent of n; storing one coefficient per trailing
 * column would index past the scratch buffer for sufficiently wide matrices.
 *
 * A is indexed as A[row + col * lda]. Must be called by all threads of the block.
 *
 * @param m     Exclusive upper bound of the row loop.
 * @param n     Exclusive upper bound of the column loop.
 * @param lda   Element stride between consecutive columns of A.
 * @param i     Index of the reflector column.
 * @param tau   Reflector scalar.
 * @param A     Matrix data in global memory.
 * @param smem  Scratch buffer; slots [0, numWarps] are overwritten.
 */
__simt_callee__ __aicore__ inline void ApplyRank1Update(
    uint32_t m, uint32_t n, uint64_t lda, uint32_t i, float tau, __gm__ float* A, __ubuf__ float* smem)
{
    uint32_t tid = threadIdx.x;
    // Slots [0, numWarps) belong to BlockReduceSum; the coefficient lives right after them.
    uint32_t coeffSlot = (blockDim.x + 31) / 32;
    uint64_t reflOff = i * lda;

    for (uint32_t c = i + 1; c < n; c++) {
        uint64_t colOff = c * lda;

        // Only thread 0 touches the row-i entry; it keeps the original value in a register.
        float head = (tid == 0) ? A[i + colOff] : 0.0f;
        float local_dot = head;
        for (uint32_t row = i + 1 + tid; row < m; row += blockDim.x) {
            local_dot += A[row + reflOff] * A[row + colOff];
        }

        float dot = BlockReduceSum(local_dot, smem);
        if (tid == 0) {
            smem[coeffSlot] = tau * dot;
        }
        // Publish the coefficient to every thread.
        asc_syncthreads();
        float coeff = smem[coeffSlot];

        for (uint32_t row = i + 1 + tid; row < m; row += blockDim.x) {
            A[row + colOff] -= coeff * A[row + reflOff];
        }
        if (tid == 0) {
            A[i + colOff] = head - coeff;
        }
        // Column writes complete (and coefficient reads done) before moving on.
        asc_syncthreads();
    }
}

/**
 * Computes the sum of squares of A[row, i] for rows in (i, m) and returns it to every
 * thread of the block.
 *
 * A is indexed as A[row + col * lda]. Thread 0 publishes the reduced value through
 * smem[0]. Must be called by all threads of the block.
 *
 * @param m     Exclusive upper bound of the row loop.
 * @param lda   Element stride between consecutive columns of A.
 * @param i     Column index; rows strictly below i are accumulated.
 * @param A     Matrix data in global memory.
 * @param smem  Scratch buffer (reduction staging; smem[0] carries the result).
 * @return      The block-wide sum of squares, identical on all threads.
 */
__simt_callee__ __aicore__ inline float ComputeSigma(
    uint32_t m, uint64_t lda, uint32_t i, __gm__ float* A, __ubuf__ float* smem)
{
    uint32_t tid = threadIdx.x;

    float local_sigma = 0.0f;
    for (uint32_t row = i + 1 + tid; row < m; row += blockDim.x) {
        float val = A[row + i * lda];
        local_sigma += val * val;
    }

    float sigma = BlockReduceSum(local_sigma, smem);
    if (tid == 0) {
        smem[0] = sigma;
    }
    asc_syncthreads();
    float total = smem[0];
    // The caller reuses smem[0] right after this function returns; without this barrier
    // a fast thread could overwrite the slot before slower threads have read sigma.
    asc_syncthreads();
    return total;
}

/**
 * Derives the reflector scalars for column i from its diagonal entry x1 = A[i, i] and
 * the sub-diagonal sum of squares `sigma`, then scales the sub-diagonal entries.
 *
 *   sigma == 0 : tau = 0, alpha = x1
 *   otherwise  : alpha = -sign(x1) * sqrt(sigma + x1^2), tau = (alpha - x1) / alpha
 *
 * alpha and tau are broadcast through smem[0] and smem[1]; smem[1] still holds tau when
 * the function returns. When tau != 0, rows in (i, m) of column i are divided by
 * (x1 - alpha). Thread 0 finally stores alpha into A[i, i] and tau into Tau[i].
 *
 * A is indexed as A[row + col * lda]. Must be called by all threads of the block.
 *
 * @param m      Exclusive upper bound of the row loop.
 * @param lda    Element stride between consecutive columns of A.
 * @param i      Column (and diagonal) index being processed.
 * @param sigma  Sum of squares of the sub-diagonal entries of column i.
 * @param A      Matrix data in global memory.
 * @param Tau    Output array; element i receives tau.
 * @param smem   Scratch buffer; slots 0 and 1 are overwritten.
 */
__simt_callee__ __aicore__ inline void ComputeTauAndNormalize(
    uint32_t m, uint64_t lda, uint32_t i, float sigma, __gm__ float* A, __gm__ float* Tau, __ubuf__ float* smem)
{
    const uint32_t self = threadIdx.x;
    const uint64_t diagIdx = i + i * lda;
    const uint64_t colOff = i * lda;

    const float x1 = A[diagIdx];
    // Defaults cover the degenerate case where nothing lies below the diagonal.
    float alpha = x1;
    float tau = 0.0f;
    if (sigma != 0.0f) {
        const float normX = sqrtf(sigma + x1 * x1);
        alpha = (x1 >= 0.0f) ? -normX : normX;
        tau = (alpha - x1) / alpha;
    }

    // Route both scalars through the scratch buffer so every thread uses thread 0's values.
    if (self == 0) {
        smem[0] = alpha;
        smem[1] = tau;
    }
    asc_syncthreads();
    alpha = smem[0];
    tau = smem[1];
    asc_syncthreads();

    // tau is uniform across the block here, so this branch does not diverge.
    if (tau != 0.0f) {
        const float vScale = x1 - alpha;
        for (uint32_t r = i + 1 + self; r < m; r += blockDim.x) {
            A[r + colOff] = A[r + colOff] / vScale;
        }
    }

    if (self == 0) {
        A[diagIdx] = alpha;
        Tau[i] = tau;
    }
    asc_syncthreads();
}

/**
 * SIMT entry: factorizes `numB` matrices, starting at batch index `startB`, one after
 * another using all threads of the block.
 *
 * `aarrayBase` and `tauarrayBase` are reinterpreted as arrays of addresses; entry
 * [startB + b] yields the matrix pointer and the tau pointer of batch b. For each
 * matrix, columns 0 .. min(m, n) - 1 are processed in order: sum of squares below the
 * diagonal, reflector scalars + normalization, then the trailing-column update when
 * tau is non-zero.
 *
 * @param m             Number of rows used by the row loops of the helpers.
 * @param n             Number of columns used by the column loops of the helpers.
 * @param lda           Element stride between consecutive columns.
 * @param numB          Number of batch entries handled by this block.
 * @param startB        First batch index handled by this block.
 * @param aarrayBase    Address table of the matrices (read as uintptr_t entries).
 * @param tauarrayBase  Address table of the tau outputs (read as uintptr_t entries).
 * @param smem          Scratch buffer shared by the helper routines.
 */
__simt_vf__ __aicore__ LAUNCH_BOUND(GEQRF_SIMT_THREADS) inline void GeqrfSimtFp32(
    uint32_t m, uint32_t n, uint64_t lda, uint32_t numB, uint32_t startB, __gm__ float* aarrayBase,
    __gm__ float* tauarrayBase, __ubuf__ float* smem)
{
    __gm__ uintptr_t* aPtrAddrs = reinterpret_cast<__gm__ uintptr_t*>(aarrayBase);
    __gm__ uintptr_t* tauPtrAddrs = reinterpret_cast<__gm__ uintptr_t*>(tauarrayBase);
    uint32_t k = (m < n) ? m : n;

    for (uint32_t b = 0; b < numB; b++) {
        __gm__ float* A = reinterpret_cast<__gm__ float*>(aPtrAddrs[startB + b]);
        __gm__ float* Tau = reinterpret_cast<__gm__ float*>(tauPtrAddrs[startB + b]);

        for (uint32_t i = 0; i < k; i++) {
            float sigma = ComputeSigma(m, lda, i, A, smem);
            ComputeTauAndNormalize(m, lda, i, sigma, A, Tau, smem);
            // ComputeTauAndNormalize leaves tau in smem[1].
            float tau = smem[1];
            // The next block reduction stages warp 1's partial sum in smem[1]; make sure
            // every thread has picked up tau before that slot can be overwritten.
            asc_syncthreads();
            if (tau != 0.0f) {
                ApplyRank1Update(m, n, lda, i, tau, A, smem);
            }
        }
    }
}

/**
 * Kernel entry for the batched single-precision QR factorization.
 *
 * Reads the tiling data, derives the batch range of this block
 * ([coreIdx * batchPerCore, + batchPerCore), or batchTail entries for the last used
 * core) and launches GeqrfSimtFp32 on it. The thread count is m rounded up to a
 * multiple of 32, clamped to [32, GEQRF_SIMT_THREADS].
 *
 * @param aarrayPtr    Address table of the input/output matrices.
 * @param tauarrayPtr  Address table of the tau outputs.
 * @param tilingGm     Pointer to a GeqrfBatchedTilingData instance.
 */
__global__ __aicore__ void sgeqrf_batched(GM_ADDR aarrayPtr, GM_ADDR tauarrayPtr, GM_ADDR tilingGm)
{
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIV_ONLY);

    const auto* td = reinterpret_cast<__gm__ GeqrfBatchedTilingData*>(tilingGm);
    uint32_t coreIdx = GetBlockIdx();
    // Blocks beyond the cores that own work must not touch the address tables: without
    // this guard they would be assigned batchPerCore entries past the distributed range.
    if (coreIdx >= td->usedCoreNum) {
        return;
    }
    uint32_t startB = coreIdx * td->batchPerCore;
    uint32_t numB = (coreIdx == td->usedCoreNum - 1) ? td->batchTail : td->batchPerCore;
    if (numB == 0) {
        return;
    }

    __gm__ float* aarrayBase = reinterpret_cast<__gm__ float*>(aarrayPtr);
    __gm__ float* tauarrayBase = reinterpret_cast<__gm__ float*>(tauarrayPtr);

    // One thread per row, rounded up to whole groups of 32 and clamped to the launch bound.
    uint32_t warpSize = 32;
    uint32_t numThreads = ((td->m + warpSize - 1) / warpSize) * warpSize;
    if (numThreads < warpSize) {
        numThreads = warpSize;
    }
    if (numThreads > GEQRF_SIMT_THREADS) {
        numThreads = GEQRF_SIMT_THREADS;
    }

    asc_vf_call<GeqrfSimtFp32>(
        dim3{numThreads, 1, 1}, td->m, td->n, static_cast<uint64_t>(td->lda), numB, startB, aarrayBase, tauarrayBase,
        g_smem);
}

/**
 * Host-side launcher for the sgeqrf_batched kernel.
 *
 * @param aarrayPtr    Forwarded to the kernel unchanged (address table of the matrices).
 * @param tauarrayPtr  Forwarded to the kernel unchanged (address table of the tau outputs).
 * @param tilingGm     Forwarded to the kernel unchanged (tiling data).
 * @param numBlocks    First launch-configuration argument (number of blocks).
 * @param stream       Third launch-configuration argument (stream to launch on).
 */
void sgeqrf_batched_kernel_do(
    GM_ADDR aarrayPtr, GM_ADDR tauarrayPtr, GM_ADDR tilingGm, uint32_t numBlocks, void* stream)
{
    // NOTE(review): the second launch argument is passed as literal 0; its meaning is
    // defined by the toolchain's <<<...>>> convention - confirm before changing it.
    sgeqrf_batched<<<numBlocks, 0, stream>>>(aarrayPtr, tauarrayPtr, tilingGm);
}