/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

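/* !
 * \file symv.asc
 * \brief Host-side tiling computation and aclblasSsymv entry point that stages data and launches the SYMV kernel.
 */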

#include <algorithm>
#include <cstdint>
#include "acl/acl.h"
#include "cann_ops_blas.h"
#include "common/helper/aclblas_handle_internal.h"

/**
 * Kernel launch entry point; only declared here, the definition is not part of this file.
 *
 * Parameter roles below are taken from the single call site in aclblasSsymv:
 *   a, x, y    device copies of the caller's matrix and vectors
 *   z          device buffer that is copied back into the caller's y after the launch
 *   workSpace  passed as nullptr by aclblasSsymv
 *   tilingGm   device copy of a SymvTilingData structure
 *   numBlocks  block count for the launch (aclblasSsymv passes 8)
 *   stream     stream taken from the aclblas handle, or nullptr when no handle is given
 */
extern void symv_kernel_do(
    uint8_t* a, uint8_t* x, uint8_t* y, uint8_t* z, uint8_t* workSpace, uint8_t* tilingGm,
    uint32_t numBlocks, void* stream);

// Capacity of the per-core arrays in SymvTilingData and upper bound applied to the core count in CalSymvTilingData.
constexpr uint32_t SYMV_MAX_CORE_NUM = 50;
// Value written to SymvTilingData::tileSize.
constexpr uint32_t SYMV_TILE_SIZE = 128;
// Capacity of the per-task arrays in SymvTilingData and upper bound applied to taskCount in CalSymvTilingData.
constexpr uint32_t SYMV_MAX_TILE_TASK = 4096;
// NOTE(review): not referenced in the visible code; aclblasSsymv declares its own local numBlocks with the same value.
constexpr uint32_t NUM_BLOCKS = 8;

/**
 * Tiling parameters produced by CalSymvTilingData.
 *
 * aclblasSsymv copies this structure byte-for-byte into device memory and hands that copy to
 * symv_kernel_do, so the field order and types presumably have to match the kernel-side
 * definition -- confirm before changing the layout.
 */
struct SymvTilingData {
    uint32_t n;           // Row count (totalRows passed to CalSymvTilingData).
    uint32_t lda;         // Leading dimension of the matrix, forwarded unchanged.
    uint32_t useCoreNum;  // min(taskCount, core count clamped to [1, SYMV_MAX_CORE_NUM]).
    float alpha;          // Scalar forwarded from the caller.
    float beta;           // Scalar forwarded from the caller.
    int64_t incx;         // Stride of x, forwarded unchanged.
    int64_t incy;         // Stride of y, forwarded unchanged.
    uint32_t tileSize;    // Always SYMV_TILE_SIZE.
    uint32_t tileRows;    // Not assigned by CalSymvTilingData; stays 0 from value-initialization.
    uint32_t taskCount;   // min(n, SYMV_MAX_TILE_TASK); number of valid entries in the task arrays.
    uint16_t taskBi[SYMV_MAX_TILE_TASK];   // Entry i holds i for i < taskCount.
    uint16_t taskBj[SYMV_MAX_TILE_TASK];   // Entries below taskCount are set to 0.
    uint8_t taskType[SYMV_MAX_TILE_TASK];  // Entries below taskCount are set to 0.
    uint32_t taskStart[SYMV_MAX_CORE_NUM]; // Entry i holds i for i < useCoreNum.
    uint32_t taskStep[SYMV_MAX_CORE_NUM];  // Entry i holds useCoreNum for i < useCoreNum.
};

/**
 * Builds the tiling description handed to the SYMV kernel.
 *
 * The scalar arguments are stored unchanged. One task is created per row up to SYMV_MAX_TILE_TASK;
 * task i records taskBi = i, taskBj = 0 and taskType = 0. The number of cores used is the task
 * count limited by vecCoreNum (0 is treated as 1) and by SYMV_MAX_CORE_NUM. Core i is given
 * taskStart = i and taskStep = useCoreNum. When there are no tasks the structure is returned with
 * only the scalar fields and tileSize filled in.
 */
static SymvTilingData CalSymvTilingData(
    uint32_t totalRows, uint32_t lda, uint32_t vecCoreNum, float alpha, float beta, int64_t incx, int64_t incy)
{
    SymvTilingData tiling{};
    tiling.tileSize = SYMV_TILE_SIZE;
    tiling.n = totalRows;
    tiling.lda = lda;
    tiling.incx = incx;
    tiling.incy = incy;
    tiling.alpha = alpha;
    tiling.beta = beta;

    const uint32_t numTasks = (totalRows < SYMV_MAX_TILE_TASK) ? totalRows : SYMV_MAX_TILE_TASK;
    tiling.taskCount = numTasks;

    // Clamp the core count into [1, SYMV_MAX_CORE_NUM] before comparing it with the task count.
    const uint32_t coreLimit = std::min<uint32_t>(std::max<uint32_t>(vecCoreNum, 1U), SYMV_MAX_CORE_NUM);
    const uint32_t activeCores = (numTasks < coreLimit) ? numTasks : coreLimit;
    tiling.useCoreNum = activeCores;
    if (activeCores == 0) {
        return tiling;
    }

    for (uint32_t row = 0; row != numTasks; ++row) {
        tiling.taskBi[row] = static_cast<uint16_t>(row);
    }
    std::fill_n(tiling.taskBj, numTasks, static_cast<uint16_t>(0));
    std::fill_n(tiling.taskType, numTasks, static_cast<uint8_t>(0));

    for (uint32_t core = 0; core != activeCores; ++core) {
        tiling.taskStep[core] = activeCores;
        tiling.taskStart[core] = core;
    }
    return tiling;
}

aclblasStatus_t aclblasSsymv(
    aclblasHandle_t handle, aclblasFillMode_t uplo, int n, const float* alpha, const float* a, int lda,
    const float* x, int incx, const float* beta, float* y, int incy)
{
    (void)uplo;
    aclrtStream useStream = nullptr;
    if (handle != nullptr) {
        auto* h = reinterpret_cast<_aclblas_handle*>(handle);
        useStream = h->stream;
    }
    constexpr uint32_t numBlocks = 8;
    const size_t vecElementCount = static_cast<size_t>(n);
    const size_t matrixElementCount = static_cast<size_t>(n) * static_cast<size_t>(lda);
    const size_t vecByteSize = vecElementCount * sizeof(float);
    const size_t matrixByteSize = matrixElementCount * sizeof(float);

    SymvTilingData tiling = CalSymvTilingData(
        static_cast<uint32_t>(n), static_cast<uint32_t>(lda), numBlocks, *alpha, *beta, incx, incy);

    uint8_t* aDevice = nullptr;
    uint8_t* xDevice = nullptr;
    uint8_t* yDevice = nullptr;
    uint8_t* zDevice = nullptr;
    uint8_t* tilingDevice = nullptr;

    aclrtMalloc(reinterpret_cast<void**>(&aDevice), matrixByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&xDevice), vecByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&yDevice), vecByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&zDevice), vecByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void**>(&tilingDevice), sizeof(SymvTilingData), ACL_MEM_MALLOC_HUGE_FIRST);

    aclrtMemcpy(aDevice, matrixByteSize, a, matrixByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(xDevice, vecByteSize, x, vecByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(yDevice, vecByteSize, y, vecByteSize, ACL_MEMCPY_HOST_TO_DEVICE);
    aclrtMemcpy(tilingDevice, sizeof(SymvTilingData), &tiling, sizeof(SymvTilingData), ACL_MEMCPY_HOST_TO_DEVICE);

    symv_kernel_do(aDevice, xDevice, yDevice, zDevice, nullptr, tilingDevice, numBlocks, useStream);
    aclrtSynchronizeStream(useStream);
    aclrtMemcpy(y, vecByteSize, zDevice, vecByteSize, ACL_MEMCPY_DEVICE_TO_HOST);

    aclrtFree(aDevice);
    aclrtFree(xDevice);
    aclrtFree(yDevice);
    aclrtFree(zDevice);
    aclrtFree(tilingDevice);
    return ACLBLAS_STATUS_SUCCESS;
}