/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file main.asc
 * \brief Pure asc_vf_call Example — VF definition and invocation, no data, no DataCopy
 */

#include <iostream>
#include "acl/acl.h"
#include "kernel_operator.h"

// Number of float elements held by each UB buffer (16 Ki elements; at 4 bytes
// per float that is 64 KiB per buffer).
static constexpr int64_t kNumElements = 16 * 1024;
// Size in bytes of one 64 KiB UB tile.
static constexpr uint32_t kTileSize = 65536;

// CHECK_ACL(call): evaluates `call` exactly once and compares the result with
// ACL_SUCCESS. On failure it prints the error code together with the file and
// line of the invocation to std::cerr and makes the ENCLOSING function
// `return 1`, so it may only be used inside functions returning an int-like
// value (e.g. main). Note that the early return skips any cleanup that follows.
//
// The temporary has a deliberately unusual name: with a common name such as
// `err`, an argument like `CHECK_ACL(f(err))` would bind to the macro's own
// (still uninitialised) variable instead of the caller's.
#define CHECK_ACL(call)                                                        \
    do {                                                                       \
        const aclError checkAclStatus_ = (call);                               \
        if (checkAclStatus_ != ACL_SUCCESS) {                                  \
            std::cerr << "ACL error: " << checkAclStatus_ << " at "            \
                      << __FILE__ << ":" << __LINE__ << std::endl;             \
            return 1;                                                          \
        }                                                                      \
    } while (0)

// ---------------------------------------------------------------------------
// VF: element-wise add — z = x + y, written with the AscendC::Reg API
//
//   xAddr, yAddr : input buffers in UB
//   zAddr        : output buffer in UB
//   n            : number of valid elements; handed to UpdateMask on every
//                  iteration (presumably consumed there so that the last
//                  iteration gets a partial mask — confirm against the API)
//   loopNum      : number of register-wide iterations to run
//
// Loads carry no mask, so each iteration reads a full register from both
// inputs; only the Add and the store are predicated. The loop counter is
// 16-bit, so loopNum must stay <= 65535 or the loop would never terminate.
// ---------------------------------------------------------------------------
__simd_vf__ inline void add_vf(__ubuf__ float *xAddr, __ubuf__ float *yAddr, __ubuf__ float *zAddr, uint32_t n, uint32_t loopNum)
{
    // Number of float lanes in one vector register.
    constexpr static uint32_t kLanesPerReg = AscendC::VECTOR_REG_WIDTH / sizeof(float);

    AscendC::Reg::RegTensor<float> lhs;
    AscendC::Reg::RegTensor<float> rhs;
    AscendC::Reg::RegTensor<float> sum;
    AscendC::Reg::MaskReg activeLanes;

    for (uint16_t step = 0; step < loopNum; ++step) {
        activeLanes = AscendC::Reg::UpdateMask<float>(n);
        AscendC::Reg::LoadAlign(lhs, xAddr + step * kLanesPerReg);
        AscendC::Reg::LoadAlign(rhs, yAddr + step * kLanesPerReg);
        AscendC::Reg::Add(sum, lhs, rhs, activeLanes);
        AscendC::Reg::StoreAlign<float, AscendC::Reg::StoreDist::DIST_NORM_B32>(
            zAddr + step * kLanesPerReg, sum, activeLanes);
    }
}

// ---------------------------------------------------------------------------
// VF: multiply-add — z = x * y + z (z is both an input and the output)
//
//   xAddr, yAddr : multiplicand buffers in UB
//   zAddr        : accumulator buffer in UB; read, then overwritten in place
//   n            : number of valid elements; handed to UpdateMask on every
//                  iteration (presumably consumed there — confirm)
//   loopNum      : number of register-wide iterations to run
//
// The product goes through a separate register so that the accumulator value
// loaded from zAddr is still intact when it is added. Loads are unmasked;
// Mul, Add and the store are predicated. The 16-bit loop counter requires
// loopNum <= 65535.
// ---------------------------------------------------------------------------
__simd_vf__ inline void muladd_vf(__ubuf__ float *xAddr, __ubuf__ float *yAddr, __ubuf__ float *zAddr, uint32_t n, uint32_t loopNum)
{
    // Number of float lanes in one vector register.
    constexpr static uint32_t kLanesPerReg = AscendC::VECTOR_REG_WIDTH / sizeof(float);

    AscendC::Reg::RegTensor<float> factorA;
    AscendC::Reg::RegTensor<float> factorB;
    AscendC::Reg::RegTensor<float> accum;
    AscendC::Reg::RegTensor<float> product;
    AscendC::Reg::MaskReg activeLanes;

    for (uint16_t step = 0; step < loopNum; ++step) {
        activeLanes = AscendC::Reg::UpdateMask<float>(n);
        AscendC::Reg::LoadAlign(factorA, xAddr + step * kLanesPerReg);
        AscendC::Reg::LoadAlign(factorB, yAddr + step * kLanesPerReg);
        AscendC::Reg::LoadAlign(accum, zAddr + step * kLanesPerReg);
        // product = x * y, then accum = product + z
        AscendC::Reg::Mul(product, factorA, factorB, activeLanes);
        AscendC::Reg::Add(accum, product, accum, activeLanes);
        AscendC::Reg::StoreAlign<float, AscendC::Reg::StoreDist::DIST_NORM_B32>(
            zAddr + step * kLanesPerReg, accum, activeLanes);
    }
}

// ---------------------------------------------------------------------------
// VF: GELU activation — y = x / (1 + exp(-1.595769121 * (x + 0.044715 * x^3)))
//
//   xAddr   : input buffer in UB
//   yAddr   : output buffer in UB
//   n       : number of valid elements; handed to UpdateMask on every
//             iteration (presumably consumed there — confirm)
//   loopNum : number of register-wide iterations to run (must be <= 65535
//             because the loop counter is 16-bit)
//
// The exponent is evaluated in a factored form so that the cubic term needs
// no separate scaling:
//     -1.595769121 * (x + 0.044715 * x^3)
//   = (-1.595769121 * 0.044715) * (x^3 + x / 0.044715)
// ---------------------------------------------------------------------------
__simd_vf__ inline void gelu_vf(__ubuf__ float *xAddr, __ubuf__ float *yAddr, uint32_t n, uint32_t loopNum)
{
    // Number of float lanes in one vector register.
    constexpr static uint32_t kLanesPerReg = AscendC::VECTOR_REG_WIDTH / sizeof(float);
    // Outer factor of the factored exponent: -1.595769121 * 0.044715.
    const float outerScale = -1.595769121 * 0.044715;
    // Inner factor applied to the linear term: 1 / 0.044715.
    const float linearScale = 1 / 0.044715;

    AscendC::Reg::RegTensor<float> input;
    AscendC::Reg::RegTensor<float> result;
    AscendC::Reg::RegTensor<float> work;
    AscendC::Reg::RegTensor<float> scaledInput;
    AscendC::Reg::MaskReg activeLanes;

    for (uint16_t step = 0; step < loopNum; ++step) {
        activeLanes = AscendC::Reg::UpdateMask<float>(n);
        AscendC::Reg::LoadAlign(input, xAddr + step * kLanesPerReg);

        // work = x^3
        AscendC::Reg::Mul(work, input, input, activeLanes);
        AscendC::Reg::Mul(work, work, input, activeLanes);
        // work = x^3 + x / 0.044715
        AscendC::Reg::Muls(scaledInput, input, linearScale, activeLanes);
        AscendC::Reg::Add(work, work, scaledInput, activeLanes);
        // work = 1 + exp(outerScale * work)
        AscendC::Reg::Muls(work, work, outerScale, activeLanes);
        AscendC::Reg::Exp(work, work, activeLanes);
        AscendC::Reg::Adds(work, work, 1.0f, activeLanes);
        // y = x / work
        AscendC::Reg::Div(result, input, work, activeLanes);

        AscendC::Reg::StoreAlign<float, AscendC::Reg::StoreDist::DIST_NORM_B32>(
            yAddr + step * kLanesPerReg, result, activeLanes);
    }
}

// ---------------------------------------------------------------------------
// VF: row-wise ReduceSum — [256, 64] -> [256, 1]
//
//   xAddr   : input in UB, read as 256 consecutive rows of 64 floats
//   outAddr : output in UB, receives 256 consecutive floats (one per row)
//
// Each row is loaded with one aligned load, summed over its 64 columns, and
// the sum is written to outAddr[row] using the first-element store
// distribution. The shape is hard-coded, so no element count or loop count is
// passed in.
// ---------------------------------------------------------------------------
__simd_vf__ inline void reducesum_vf(__ubuf__ float *xAddr, __ubuf__ float *outAddr)
{
    constexpr static uint32_t kRows = 256;
    constexpr static uint32_t kCols = 64;

    AscendC::Reg::RegTensor<float> rowData;
    AscendC::Reg::RegTensor<float> rowTotal;

    // Predicate covering the kCols valid elements of a row. UpdateMask needs a
    // modifiable count, hence the local copy of the constant.
    uint32_t remainingCols = kCols;
    AscendC::Reg::MaskReg colMask = AscendC::Reg::UpdateMask<float>(remainingCols);
    // Predicate selecting a single element, used when storing the row sum.
    AscendC::Reg::MaskReg firstLaneMask = AscendC::Reg::CreateMask<float, AscendC::Reg::MaskPattern::VL1>();

    for (uint16_t row = 0; row < kRows; ++row) {
        AscendC::Reg::LoadAlign(rowData, xAddr + row * kCols);
        AscendC::Reg::ReduceSum(rowTotal, rowData, colMask);
        // Only the first element of rowTotal is written, to outAddr[row].
        AscendC::Reg::StoreAlign<float, AscendC::Reg::StoreDist::DIST_FIRST_ELEMENT_B32>(
            outAddr + row, rowTotal, firstLaneMask);
    }
}

// ---------------------------------------------------------------------------
// VF: ReduceSum with UNALIGNED output store — [256, 64] -> [256, 1]
// Same per-row reduction as reducesum_vf. The input is aligned so the row is
// read with the normal aligned load, but the per-row scalar is streamed out
// with the unaligned store protocol (StoreUnAlign + StoreUnAlignPost →
// vstus/vstas). The unaligned store needs an explicit UnalignReg scratch plus
// a Post flush stage and has no StoreDist, so dst is advanced by a post-update
// stride instead.
//
//   xAddr   : input in UB, read as 256 consecutive rows of 64 floats
//   outAddr : output in UB; the running store pointer starts here
//
// NOTE(review): the statement order matters — every StoreUnAlign in the loop
// shares the same storeUReg, and StoreUnAlignPost must run once after the
// loop. Do not reorder or hoist these calls without checking the API contract.
// ---------------------------------------------------------------------------
__simd_vf__ inline void reducesum_unalign_vf(__ubuf__ float *xAddr, __ubuf__ float *outAddr)
{
    constexpr static uint32_t kNumRow = 256;
    constexpr static uint32_t kNumCol = 64;

    AscendC::Reg::RegTensor<float> rowReg, rowSumReg;
    AscendC::Reg::UnalignReg storeUReg;  // scratch carry for the unaligned store

    // Mask selecting the kNumCol valid elements of each row.
    // UpdateMask is given a modifiable local rather than the constant itself.
    uint32_t colCount = kNumCol;
    AscendC::Reg::MaskReg rowMask = AscendC::Reg::UpdateMask<float>(colCount);

    // Running output pointer. It is passed to the store calls as an lvalue;
    // presumably they take it by reference and advance it — confirm against
    // the StoreUnAlign signature.
    __ubuf__ float *dst = outAddr;
    for (uint16_t i = 0; i < kNumRow; ++i) {
        // LOAD (aligned): row i from memory → register
        AscendC::Reg::LoadAlign(rowReg, (__ubuf__ float *)xAddr + i * kNumCol);
        AscendC::Reg::ReduceSum(rowSumReg, rowReg, rowMask);
        // STORE (unaligned): stream the per-row scalar; dst advances by 1
        // (the trailing argument is the post-update stride, here one element).
        AscendC::Reg::StoreUnAlign<float>(dst, rowSumReg, storeUReg, 1);
    }
    // STORE post: flush the trailing partial held in the store carry.
    // Stride 0: no further advance of dst is requested after the flush.
    AscendC::Reg::StoreUnAlignPost<float>(dst, storeUReg, 0);
}

// ---------------------------------------------------------------------------
// Warmup kernel: no-op to bring the NPU pipeline out of cold-start state
// Takes no arguments and touches no memory; main launches it once on the same
// stream immediately before pure_vf_kernel.
// ---------------------------------------------------------------------------
__global__ __aicore__ __vector__ void warmup_kernel()
{
    // NOTE(review): the template argument 8 is presumably the number of no-op
    // slots issued — confirm against the AscendC::Nop documentation.
    AscendC::Nop<8>();
}

// ---------------------------------------------------------------------------
// Minimal kernel: __ubuf__ arrays + asc_vf_call, no TPipe/TQue/DataCopy
//
// Declares three kNumElements-float buffers directly in UB and invokes every
// vector function defined above three times in a row. The buffers are never
// initialised and nothing is copied in or out: the kernel only exercises the
// VF definition / invocation path, so the numeric results are meaningless.
// ---------------------------------------------------------------------------
__global__ __aicore__ __vector__ void pure_vf_kernel()
{
    // Number of float lanes in one vector register.
    constexpr static uint32_t kLanesPerReg = AscendC::VECTOR_REG_WIDTH / sizeof(float);
    // How many times each vector function is invoked back to back.
    constexpr static int kRepeats = 3;

    // Register-wide iterations needed to cover kNumElements (rounded up).
    uint32_t loopNum = (kNumElements + kLanesPerReg - 1) / kLanesPerReg;

    __ubuf__ float bufX[kNumElements];
    __ubuf__ float bufY[kNumElements];
    __ubuf__ float bufZ[kNumElements];

    // z = x + y
    for (int rep = 0; rep < kRepeats; ++rep) {
        asc_vf_call<add_vf>(bufX, bufY, bufZ, static_cast<uint32_t>(kNumElements), loopNum);
    }
    // z = x * y + z
    for (int rep = 0; rep < kRepeats; ++rep) {
        asc_vf_call<muladd_vf>(bufX, bufY, bufZ, static_cast<uint32_t>(kNumElements), loopNum);
    }
    // z = gelu(x)
    for (int rep = 0; rep < kRepeats; ++rep) {
        asc_vf_call<gelu_vf>(bufX, bufZ, static_cast<uint32_t>(kNumElements), loopNum);
    }
    // Row-wise sum with aligned store: bufX viewed as [256, 64] -> bufZ as [256, 1]
    for (int rep = 0; rep < kRepeats; ++rep) {
        asc_vf_call<reducesum_vf>(bufX, bufZ);
    }
    // Row-wise sum with the unaligned store protocol: [256, 64] -> [256, 1]
    for (int rep = 0; rep < kRepeats; ++rep) {
        asc_vf_call<reducesum_unalign_vf>(bufX, bufZ);
    }
}

int main()
{
    CHECK_ACL(aclInit(nullptr));
    int32_t deviceId = 0;
    CHECK_ACL(aclrtSetDevice(deviceId));
    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    // Launch the warmup kernel to mitigate cold-start overhead for the subsequent pure VF kernel.
    warmup_kernel<<<1, nullptr, stream>>>();
    pure_vf_kernel<<<1, nullptr, stream>>>();
    CHECK_ACL(aclrtSynchronizeStream(stream));

    std::cout << "Kernel launched successfully!" << std::endl;

    CHECK_ACL(aclrtDestroyStream(stream));
    CHECK_ACL(aclrtResetDevice(deviceId));
    CHECK_ACL(aclFinalize());

    return 0;
}