/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file main.cpp
 * \brief Vector Addition Example (C API)
 */

#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include "acl/acl.h"
#include "c_api/asc_simd.h"
#include "platform/platform_ascendc.h"

// Evaluates an ACL runtime call exactly once. If the result is not ACL_SUCCESS,
// prints the error code, the failing expression and the source location to
// stderr, then makes the *enclosing function* return 1. It is therefore only
// usable inside functions whose return type accepts the value 1.
// The status variable deliberately has an unusual name: a plain name such as
// `err` would be in scope inside its own initializer, so CHECK_ACL(f(err)) would
// silently read the uninitialized macro-local instead of the caller's variable.
#define CHECK_ACL(call)                                                        \
    do {                                                                       \
        const aclError checkAclStatus_ = (call);                               \
        if (checkAclStatus_ != ACL_SUCCESS) {                                  \
            std::cerr << "ACL error: " << checkAclStatus_ << " (" << #call     \
                      << ") at " << __FILE__ << ":" << __LINE__ << std::endl;  \
            return 1;                                                          \
        }                                                                      \
    } while (0)

/// \brief std::unique_ptr deleter that hands a device pointer back to aclrtFree.
///
/// Null pointers are ignored. The status returned by aclrtFree is discarded,
/// because a deleter has no way to report it.
struct AclrtFreeDeleter {
    void operator()(void* devPtr) const
    {
        if (devPtr == nullptr) {
            return;
        }
        aclrtFree(devPtr);
    }
};

// Number of float elements handled per tile. add_kernel sizes each of its three
// __ubuf__ buffers with it, and run_vector_add uses it to cap the number of blocks.
constexpr uint32_t TILE_LENGTH{2048U};

/**
 * \brief Element-wise z = x + y over float vectors, split across blocks.
 *
 * The block with index b owns elements
 * [b * blockLength, min((b + 1) * blockLength, totalLength)).
 * Its slice is processed one tile at a time through three TILE_LENGTH-element
 * __ubuf__ buffers. For each tile, both inputs are copied in, added, and the sum
 * is copied back out, all with the `_sync` asc_* helpers. A block whose slice
 * starts at or past totalLength returns right after initialization.
 *
 * \param x            __gm__ pointer to the first input vector.
 * \param y            __gm__ pointer to the second input vector.
 * \param z            __gm__ pointer to the output vector.
 * \param totalLength  Number of elements in each vector.
 * \param blockLength  Elements assigned to each block; the last block may get fewer.
 */
__vector__ __global__ __aicore__ void add_kernel(
    __gm__ float* x, __gm__ float* y, __gm__ float* z,
    int64_t totalLength, int64_t blockLength)
{
    // asc_init puts the vector hardware of this AI Core into a known state:
    // - it clears the atomic mode (set_atomic_none);
    // - it selects normal vector-mask mode (set_mask_norm);
    // - it enables every mask bit (set_vector_mask(-1, -1)).
    // It must precede every other asc_* call. Otherwise the mask and atomic mode
    // are undefined, and vector operations can produce incorrect results.
    asc_init();

    __ubuf__ float xLocal[TILE_LENGTH];
    __ubuf__ float yLocal[TILE_LENGTH];
    __ubuf__ float zLocal[TILE_LENGTH];

    const int64_t blockIdx = asc_get_block_idx();
    const int64_t blockStart = blockIdx * blockLength;
    const int64_t available = totalLength - blockStart;
    if (available <= 0) {
        return;  // This block's slice lies entirely past the end of the data.
    }
    const int64_t sliceLength = (available < blockLength) ? available : blockLength;

    // Advance through the slice in TILE_LENGTH steps; the last tile may be partial.
    for (int64_t done = 0; done < sliceLength; done += TILE_LENGTH) {
        const int64_t left = sliceLength - done;
        const int64_t tileLength = (left < TILE_LENGTH) ? left : TILE_LENGTH;
        const auto tileBytes = tileLength * sizeof(float);
        const int64_t gmOffset = blockStart + done;

        asc_copy_gm2ub_sync((__ubuf__ void*)xLocal, (__gm__ void*)(x + gmOffset), tileBytes);
        asc_copy_gm2ub_sync((__ubuf__ void*)yLocal, (__gm__ void*)(y + gmOffset), tileBytes);
        asc_add_sync(zLocal, xLocal, yLocal, tileLength);
        asc_copy_ub2gm_sync((__gm__ void*)(z + gmOffset), (__ubuf__ void*)zLocal, tileBytes);
    }
}

/**
 * \brief Adds two random vectors on the device with add_kernel and checks the result.
 *
 * Steps:
 * - fill two host vectors with floats drawn uniformly from [0, 10);
 * - copy them to freshly allocated device buffers;
 * - launch add_kernel on `stream`;
 * - copy the sum back and compare it element-wise against a host-side reference.
 *
 * Device buffers are released automatically on every return path.
 *
 * \param stream       ACL stream the kernel is launched on (created by the caller).
 * \param numElements  Number of float elements per vector; must be positive.
 * \return 0 if every element matches the reference. Returns 1 on a mismatch, a
 *         non-positive element count, an unusable core count, or any ACL failure.
 */
int run_vector_add(aclrtStream stream, int64_t numElements)
{
    // A non-positive count would make numBlocks 0 below (division by zero).
    // It would also turn into a huge size_t when sizing the host vectors.
    if (numElements <= 0) {
        std::cerr << "Invalid element count: " << numElements << std::endl;
        return 1;
    }

    // Launch one block per TILE_LENGTH-sized chunk of work, capped at the AIV core count.
    const int64_t coreNum = platform_ascendc::PlatformAscendCManager::GetInstance()->GetCoreNumAiv();
    const int64_t numTiles = (numElements + TILE_LENGTH - 1) / TILE_LENGTH;
    const int64_t numBlocks = std::min(coreNum, numTiles);
    if (numBlocks <= 0) {
        std::cerr << "No AIV cores available (GetCoreNumAiv returned " << coreNum << ")" << std::endl;
        return 1;
    }
    // Ceiling division, so numBlocks * blockLength >= numElements. add_kernel clamps
    // the trailing blocks, which may receive fewer elements or none.
    // NOTE(review): blockLength is not rounded to any DMA/vector alignment. Confirm that
    // the asc_copy_* helpers accept arbitrary element offsets and lengths.
    const int64_t blockLength = (numElements + numBlocks - 1) / numBlocks;

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_real_distribution<float> dist(0.0f, 10.0f);

    const size_t count = static_cast<size_t>(numElements);
    const size_t size = count * sizeof(float);

    std::vector<float> h_A(count);
    std::vector<float> h_B(count);
    std::vector<float> h_C(count, 0.0f);
    for (size_t i = 0; i < count; ++i) {
        h_A[i] = dist(gen);
        h_B[i] = dist(gen);
    }

    // Give each device buffer to its guard immediately after allocation.
    // An early return from any later CHECK_ACL then frees everything allocated so far.
    void* d_A = nullptr;
    CHECK_ACL(aclrtMalloc(&d_A, size, ACL_MEM_MALLOC_HUGE_FIRST));
    std::unique_ptr<void, AclrtFreeDeleter> d_A_guard(d_A);
    void* d_B = nullptr;
    CHECK_ACL(aclrtMalloc(&d_B, size, ACL_MEM_MALLOC_HUGE_FIRST));
    std::unique_ptr<void, AclrtFreeDeleter> d_B_guard(d_B);
    void* d_C = nullptr;
    CHECK_ACL(aclrtMalloc(&d_C, size, ACL_MEM_MALLOC_HUGE_FIRST));
    std::unique_ptr<void, AclrtFreeDeleter> d_C_guard(d_C);

    CHECK_ACL(aclrtMemcpy(d_A, size, h_A.data(), size, ACL_MEMCPY_HOST_TO_DEVICE));
    CHECK_ACL(aclrtMemcpy(d_B, size, h_B.data(), size, ACL_MEMCPY_HOST_TO_DEVICE));

    CHECK_ACL(aclrtSynchronizeStream(stream));
    add_kernel<<<numBlocks, nullptr, stream>>>(
        static_cast<float*>(d_A), static_cast<float*>(d_B), static_cast<float*>(d_C),
        numElements, blockLength);
    // Wait for the kernel to finish before reading d_C back.
    CHECK_ACL(aclrtSynchronizeStream(stream));

    CHECK_ACL(aclrtMemcpy(h_C.data(), size, d_C, size, ACL_MEMCPY_DEVICE_TO_HOST));
    CHECK_ACL(aclrtSynchronizeStream(stream));

    // Host and device each perform one fp32 addition per element, so the results
    // should agree to within float rounding. A tight tolerance, instead of `==`,
    // avoids a false failure if the device rounds the last bit differently.
    // NaN results never satisfy the `<=` and are reported as mismatches.
    constexpr float kAbsTol = 1e-6f;
    constexpr float kRelTol = 1e-6f;
    size_t firstMismatch = count;
    for (size_t i = 0; i < count; ++i) {
        const float expected = h_A[i] + h_B[i];
        if (!(std::fabs(h_C[i] - expected) <= kAbsTol + kRelTol * std::fabs(expected))) {
            firstMismatch = i;
            break;
        }
    }

    if (firstMismatch == count) {
        std::cout << "Vector add completed successfully!" << std::endl;
        return 0;
    }

    std::cout << "Vector add failed!" << std::endl;
    std::cerr << "First mismatch at index " << firstMismatch << ": got " << h_C[firstMismatch]
              << ", expected " << (h_A[firstMismatch] + h_B[firstMismatch]) << std::endl;
    return 1;
}

/**
 * \brief Program entry point. Sets up ACL, runs the vector-add check on device 0,
 *        then tears ACL down again.
 *
 * \return The result of run_vector_add, or 1 if any ACL setup/teardown call fails.
 */
int main()
{
    constexpr int32_t kDeviceId = 0;
    constexpr int64_t kNumElements = 409600;

    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(kDeviceId));

    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    const int status = run_vector_add(stream, kNumElements);

    // Release resources in the reverse order of their creation.
    CHECK_ACL(aclrtDestroyStream(stream));
    CHECK_ACL(aclrtResetDevice(kDeviceId));
    CHECK_ACL(aclFinalize());

    return status;
}