/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file sscal_host.cpp
 * \brief Single-precision vector scaling: x = alpha * x
 *        Arch35 (ascend950) host-side implementation.
 *        Dual-path: AIV SIMD for incx==1, SIMT for incx!=1.
 */

#include <cstdint>
#include <algorithm>
#include "acl/acl.h"
#include "log/log.h"
#include "cann_ops_blas.h"
#include "common/helper/aclblas_handle_internal.h"
#include "common/helper/host_utils.h"
#include "common/helper/kernel_constant.h"
#include "sscal_tiling_data.h"

/*!
 * \brief Kernel launch entry point. Only declared in this file; the definition lives elsewhere.
 * \param x          Vector buffer passed as raw bytes (aclblasSscal casts its float* to uint8_t*).
 * \param workSpace  Workspace pointer; the only caller in this file passes nullptr.
 * \param tiling     Host-computed tiling parameters.
 * \param numBlocks  Block count handed to the launch.
 *                   NOTE(review): presumably the launch grid size — confirm against the kernel side.
 * \param stream     Stream handle; aclblasSscal passes the stream stored in the aclblas handle.
 */
void sscal_kernel_do(uint8_t* x, uint8_t* workSpace, const SscalTilingData& tiling,
                     uint32_t numBlocks, void *stream);

/*!
 * \brief Check the handle and vector pointer supplied to aclblasSscal.
 * \param handle  Library handle; must not be nullptr.
 * \param x       Vector pointer; must not be nullptr.
 * \param incx    Stride of x. Accepted for signature symmetry but not inspected here.
 * \return ACLBLAS_STATUS_HANDLE_IS_NULLPTR if handle is nullptr (checked first),
 *         ACLBLAS_STATUS_INVALID_VALUE if x is nullptr,
 *         ACLBLAS_STATUS_SUCCESS otherwise.
 */
static aclblasStatus_t ValidateSscalParams(aclblasHandle_t handle, const float* x, int incx)
{
    aclblasStatus_t verdict = ACLBLAS_STATUS_SUCCESS;

    // The handle is examined before x, so a null handle wins when both are null.
    if (handle == nullptr) {
        OP_LOGE("aclblasSscal", "handle is nullptr");
        verdict = ACLBLAS_STATUS_HANDLE_IS_NULLPTR;
    } else if (x == nullptr) {
        OP_LOGE("aclblasSscal", "x must not be nullptr");
        verdict = ACLBLAS_STATUS_INVALID_VALUE;
    }

    return verdict;
}

/*!
 * \brief Build the tiling descriptor for the contiguous (incx == 1) path.
 *
 * Each core receives an equal share rounded down to a multiple of ELEMENTS_PER_BLOCK;
 * everything that does not fit into those aligned shares is reported in `remainder`.
 * The tile size is the largest ELEMENTS_PER_BLOCK multiple that fits when UB_SIZE bytes
 * are split across two float queues.
 *
 * \param totalFloatNum  Number of float elements to process.
 * \param aivCoreNum     Number of cores to split the work across.
 * \param alpha          Scale factor stored into the descriptor.
 * \return The populated descriptor. If aivCoreNum is 0, an error is logged and only
 *         totalN and incx are set; every other field keeps its value-initialized state.
 */
static SscalTilingData CalSscalTilingDataContiguous(uint32_t totalFloatNum, uint32_t aivCoreNum, float alpha)
{
    constexpr uint32_t blockElems = ELEMENTS_PER_BLOCK;
    constexpr uint32_t queueNum = 2;

    SscalTilingData result{};
    result.totalN = totalFloatNum;
    result.incx = 1;

    // Guard the division below; the caller gets the partially filled descriptor back.
    if (aivCoreNum == 0) {
        OP_LOGE("aclblasSscal", "aivCoreNum is 0, skip tiling calculation");
        return result;
    }

    // Even share per core, rounded down to a whole number of blocks.
    const uint32_t evenShare = totalFloatNum / aivCoreNum;
    result.perCoreN = evenShare - (evenShare % blockElems);
    result.remainder = totalFloatNum - result.perCoreN * aivCoreNum;

    // Per-queue capacity in floats, rounded down to a whole number of blocks.
    const uint32_t queueCapacity = UB_SIZE / (queueNum * sizeof(float));
    result.tileSize = queueCapacity - (queueCapacity % blockElems);

    // The strided-path fields are not used on this path.
    result.useCoreNum = 0;
    result.nthreads = 0;
    result.alpha = alpha;

    return result;
}

/*!
 * \brief Build the tiling descriptor for the strided (incx != 1) path.
 *
 * The n elements are split across at most SSCAL_MAX_CORE_NUM cores. Each active core gets
 * n / cores elements, and the first n % cores cores get one extra. Slots for inactive
 * cores are zeroed. The contiguous-path fields (perCoreN, remainder, tileSize) are set to 0.
 *
 * \param n           Number of elements; narrowed to uint32_t for all computations here.
 * \param incx        Stride stored into the descriptor.
 * \param aivCoreNum  Upper bound on the number of cores to use.
 * \param alpha       Scale factor stored into the descriptor.
 * \return The populated descriptor.
 */
static SscalTilingData CalSscalTilingDataStrided(int64_t n, int64_t incx, uint32_t aivCoreNum, float alpha)
{
    const uint32_t elemCount = static_cast<uint32_t>(n);

    // Never use more cores than elements, never zero, never more than the table can hold.
    uint32_t activeCores = (aivCoreNum < elemCount) ? aivCoreNum : elemCount;
    if (activeCores == 0) {
        activeCores = 1;
    }
    if (activeCores > SSCAL_MAX_CORE_NUM) {
        activeCores = SSCAL_MAX_CORE_NUM;
    }

    SscalTilingData result{};
    result.totalN = elemCount;
    result.incx = incx;
    result.alpha = alpha;
    result.useCoreNum = activeCores;

    // Fill the whole per-core table in one pass: real ranges first, zeros for unused slots.
    const uint32_t evenShare = elemCount / activeCores;
    const uint32_t extra = elemCount % activeCores;
    uint32_t cursor = 0;
    for (uint32_t core = 0; core < SSCAL_MAX_CORE_NUM; ++core) {
        if (core < activeCores) {
            uint32_t share = evenShare;
            if (core < extra) {
                ++share;
            }
            result.startOffset[core] = cursor;
            result.calCount[core] = share;
            cursor += result.calCount[core];
        } else {
            result.startOffset[core] = 0;
            result.calCount[core] = 0;
        }
    }

    // Thread count: the largest per-core share aligned up to SIMT_MIN_THREAD_NUM,
    // capped at SIMT_MAX_THREAD_NUM.
    const auto largestShare = CeilDiv<uint32_t>(elemCount, activeCores);
    result.nthreads = std::min(CeilAlign<uint32_t>(largestShare, SIMT_MIN_THREAD_NUM), SIMT_MAX_THREAD_NUM);

    // The contiguous-path fields are not used on this path.
    result.tileSize = 0;
    result.remainder = 0;
    result.perCoreN = 0;

    return result;
}

/*!
 * \brief Scale a single-precision vector in place: x = alpha * x.
 *
 * Two paths are used: the contiguous tiling for incx == 1 and the strided tiling for any
 * other positive incx. The kernel is handed the stream stored in the handle.
 *
 * Check order (kept as-is for compatibility):
 *   1. n <= 0      -> success, before any pointer or handle check.
 *   2. alpha       -> must not be nullptr.
 *   3. handle, x   -> validated by ValidateSscalParams.
 *   4. incx <= 0   -> success, nothing is launched.
 *
 * \param handle  Library handle; its stream is passed to the kernel launch.
 * \param n       Number of elements to scale.
 * \param alpha   Pointer to the scale factor; dereferenced on the host.
 * \param x       Vector to scale in place.
 * \param incx    Stride between consecutive elements of x.
 * \return ACLBLAS_STATUS_SUCCESS on a no-op or once sscal_kernel_do has been called,
 *         ACLBLAS_STATUS_INVALID_VALUE if alpha or x is nullptr,
 *         ACLBLAS_STATUS_HANDLE_IS_NULLPTR if handle is nullptr,
 *         ACLBLAS_STATUS_EXECUTION_FAILED if GetAivCoreCount() reports 0 cores.
 */
aclblasStatus_t aclblasSscal(aclblasHandle_t handle, int n, const float* alpha, float* x, int incx)
{
    // Nothing to scale: return before touching any pointer.
    if (n <= 0) {
        return ACLBLAS_STATUS_SUCCESS;
    }

    if (alpha == nullptr) {
        OP_LOGE("aclblasSscal", "alpha must not be nullptr");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }

    const aclblasStatus_t status = ValidateSscalParams(handle, x, incx);
    if (status != ACLBLAS_STATUS_SUCCESS) {
        return status;
    }

    // Non-positive strides are treated as a no-op.
    if (incx <= 0) {
        return ACLBLAS_STATUS_SUCCESS;
    }

    const uint32_t aivCoreNum = GetAivCoreCount();
    if (aivCoreNum == 0) {
        OP_LOGE("aclblasSscal", "GetAivCoreCount failed");
        return ACLBLAS_STATUS_EXECUTION_FAILED;
    }

    auto* h = reinterpret_cast<_aclblas_handle*>(handle);
    const uint32_t totalN = static_cast<uint32_t>(n);

    SscalTilingData tiling{};
    uint32_t numBlocks = 0;

    if (incx == 1) {
        // Never launch more blocks than there are elements.
        numBlocks = std::min(totalN, aivCoreNum);
        tiling = CalSscalTilingDataContiguous(totalN, numBlocks, *alpha);
    } else {
        const uint32_t requestedBlocks = std::min(CeilDiv<uint32_t>(totalN, SIMT_MIN_THREAD_NUM), aivCoreNum);
        tiling = CalSscalTilingDataStrided(n, incx, requestedBlocks, *alpha);
        // The strided tiling may clamp the core count (e.g. to SSCAL_MAX_CORE_NUM) and only
        // populates per-core ranges for tiling.useCoreNum cores. Launch exactly that many
        // blocks so no block runs without an assigned range.
        numBlocks = static_cast<uint32_t>(tiling.useCoreNum);
    }

    OP_LOGD(
        "aclblasSscal", "tiling: n=%d incx=%d numBlocks=%u alpha=%.6f", n, incx, numBlocks,
        static_cast<double>(*alpha));

    // No workspace is used by this operator, hence the nullptr.
    sscal_kernel_do(reinterpret_cast<uint8_t*>(x), nullptr, tiling, numBlocks, h->stream);

    return ACLBLAS_STATUS_SUCCESS;
}