/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */


/* !
 * \file clock.asc
 * \brief SIMT gather sample that also reports kernel execution cycles measured with clock().
 */

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>
#include "acl/acl.h"
#include "asc_simt.h"
#include "asc_printf.h"

/**
 * \brief Gather kernel: output[i] = input[index[i]] for every i < index_total_length.
 *
 * Each thread handles exactly one output element, addressed by its flat global id
 * blockIdx.x * blockDim.x + threadIdx.x; threads whose id is at or past index_total_length
 * return immediately. The launch must therefore provide at least index_total_length threads.
 * Thread 0 of block 0 (when it is in range) also prints the clock() cycles it spent.
 *
 * \param input              Source tensor; read at the positions stored in index.
 * \param index              Gather positions, index_total_length entries. Entries are NOT
 *                           range-checked here; the caller must ensure each is a valid
 *                           position in input.
 * \param output             Destination tensor, index_total_length entries.
 * \param index_total_length Number of elements to gather.
 */
__global__ void simt_gather(float* input, int32_t* index, float* output, uint64_t index_total_length)
{
    uint64_t start = clock();

    // Flat global thread id. Widened to 64 bits before the multiply so it cannot wrap, and so
    // it is compared against the uint64_t length without a signed/unsigned mismatch.
    uint64_t idx = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // Maps to the index of output tensor; surplus threads in the last block do nothing.
    if (idx >= index_total_length) {
        return;
    }
    output[idx] = input[index[idx]];

    uint64_t end = clock();
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        printf("%s execute cycle : %lu\n", "simt_gather", end - start);
    }
}


/**
 * \brief Prints the first elements of output and golden, then checks that they match.
 *
 * Floats are compared with exact equality. That suits a gather, which copies values
 * unchanged, so the device result should equal the host reference bit for bit.
 *
 * \param output Result produced by the device.
 * \param golden Host-computed reference.
 * \return 0 when both vectors have the same length and identical elements, 1 otherwise.
 */
uint32_t verify_result(std::vector<float>& output, std::vector<float>& golden)
{
    // Prints at most max_print_size leading elements, followed by "..." if truncated.
    auto print_tensor = [](const std::vector<float>& tensor, const char* name) {
        constexpr size_t max_print_size = 20;
        std::cout << name << ": ";
        std::copy(tensor.begin(), tensor.begin() + std::min(tensor.size(), max_print_size),
            std::ostream_iterator<float>(std::cout, " "));
        if (tensor.size() > max_print_size) {
            std::cout << "...";
        }
        std::cout << std::endl;
    };
    print_tensor(output, "Output");
    print_tensor(golden, "Golden");

    // Compare sizes first: the 3-iterator std::equal would read past the end of a shorter
    // golden, and would accept a longer golden whose prefix happens to match.
    if (output.size() == golden.size() && std::equal(output.begin(), output.end(), golden.begin())) {
        std::cout << "[Success] Case accuracy is verification passed." << std::endl;
        return 0;
    }
    std::cout << "[Failed] Case accuracy is verification failed!" << std::endl;
    return 1;
}

/**
 * \brief Runs simt_gather on device 0 and returns output[i] = input[index[i]].
 *
 * Performs the whole ACL lifecycle for one call: aclInit, device and stream setup, buffer
 * allocation, host-to-device copies, kernel launch, stream synchronization, device-to-host
 * copy, and teardown (ending with aclFinalize). After the first failing ACL call, the
 * remaining work is skipped and only what was actually set up is released.
 *
 * \param input Source tensor.
 * \param index Gather positions; every entry must lie in [0, input.size()).
 * \return The gathered tensor (index.size() elements). Returns an empty vector if an input
 *         is empty, an index is out of range, or any ACL call reports an error.
 */
std::vector<float> run_gather(std::vector<float>& input, std::vector<int32_t>& index)
{
    if (input.empty() || index.empty()) {
        std::cout << "[ERROR] Empty input tensors." << std::endl;
        return {};
    }

    const size_t input_total_length = input.size();
    const size_t index_total_length = index.size();

    // The kernel does not receive input's length, so reject any index that would make it
    // read outside the input buffer.
    auto bad_index = std::find_if(index.begin(), index.end(), [input_total_length](int32_t pos) {
        return pos < 0 || static_cast<size_t>(pos) >= input_total_length;
    });
    if (bad_index != index.end()) {
        std::cout << "[ERROR] Index " << *bad_index << " at position " << (bad_index - index.begin())
                  << " is out of range [0, " << input_total_length << ")." << std::endl;
        return {};
    }

    const size_t input_total_byte_size = input_total_length * sizeof(float);
    const size_t index_total_byte_size = index_total_length * sizeof(int32_t);
    const size_t output_total_byte_size = index_total_length * sizeof(float);

    // Configure kernel launch parameters. simt_gather maps one thread to one output element,
    // so the grid is rounded up to cover every element.
    const uint32_t threads_per_block = 256; // Number of threads per block (Block size)
    const uint32_t blocks_per_grid =       // Number of thread blocks (Grid size)
        static_cast<uint32_t>((index_total_length + threads_per_block - 1) / threads_per_block);
    const uint32_t dyn_ubuf_size = 0;       // No dynamic memory required in this sample

    // Reports a failed ACL call (0 is ACL_SUCCESS) and returns whether it succeeded.
    auto acl_ok = [](int ret, const char* what) {
        if (ret != 0) {
            std::cout << "[ERROR] " << what << " failed, error code " << ret << "." << std::endl;
            return false;
        }
        return true;
    };

    int32_t device_id = 0;
    aclrtStream stream = nullptr;
    uint8_t* output_host = nullptr;
    float* input_device = nullptr;
    int32_t* index_device = nullptr;
    float* output_device = nullptr;

    // Init. Every step below runs only if all previous steps succeeded (short-circuit &&);
    // the flags record what teardown has to undo.
    const bool acl_initialized = acl_ok(aclInit(nullptr), "aclInit");
    const bool device_set = acl_initialized && acl_ok(aclrtSetDevice(device_id), "aclrtSetDevice");
    bool ok = device_set && acl_ok(aclrtCreateStream(&stream), "aclrtCreateStream");

    // Allocate host and device memory, and copy input data from host to device
    ok = ok && acl_ok(aclrtMallocHost((void **)(&output_host), output_total_byte_size), "aclrtMallocHost");
    ok = ok && acl_ok(aclrtMalloc((void **)&input_device, input_total_byte_size, ACL_MEM_MALLOC_HUGE_FIRST),
        "aclrtMalloc(input)");
    ok = ok && acl_ok(aclrtMalloc((void **)&index_device, index_total_byte_size, ACL_MEM_MALLOC_HUGE_FIRST),
        "aclrtMalloc(index)");
    ok = ok && acl_ok(aclrtMalloc((void **)&output_device, output_total_byte_size, ACL_MEM_MALLOC_HUGE_FIRST),
        "aclrtMalloc(output)");
    ok = ok && acl_ok(aclrtMemcpy(input_device, input_total_byte_size, input.data(), input_total_byte_size,
        ACL_MEMCPY_HOST_TO_DEVICE), "aclrtMemcpy(input)");
    ok = ok && acl_ok(aclrtMemcpy(index_device, index_total_byte_size, index.data(), index_total_byte_size,
        ACL_MEMCPY_HOST_TO_DEVICE), "aclrtMemcpy(index)");

    if (ok) {
        // Call kernel function with <<<...>>>
        simt_gather<<<blocks_per_grid, threads_per_block, dyn_ubuf_size, stream>>>(
            input_device, index_device, output_device, index_total_length);

        // Wait for the simt_gather kernel to complete, and check the result of the wait.
        ok = acl_ok(aclrtSynchronizeStream(stream), "aclrtSynchronizeStream");
    }

    // Copy the result from device memory to host memory
    ok = ok && acl_ok(aclrtMemcpy(output_host, output_total_byte_size, output_device, output_total_byte_size,
        ACL_MEMCPY_DEVICE_TO_HOST), "aclrtMemcpy(output)");

    std::vector<float> output;
    if (ok) {
        const float* result = reinterpret_cast<const float*>(output_host);
        output.assign(result, result + index_total_length);
    }

    // free memory. A pointer is still nullptr if its allocation was skipped (this assumes a
    // failed allocation also leaves it nullptr).
    if (input_device != nullptr) {
        aclrtFree(input_device);
    }
    if (index_device != nullptr) {
        aclrtFree(index_device);
    }
    if (output_device != nullptr) {
        aclrtFree(output_device);
    }
    if (output_host != nullptr) {
        aclrtFreeHost(output_host);
    }

    const char* err = aclGetRecentErrMsg();
    if (err != nullptr) {
        fprintf(stderr, "%s\n", err);
    }

    // DeInit, undoing only the steps that succeeded.
    if (stream != nullptr) {
        aclrtDestroyStream(stream);
    }
    if (device_set) {
        aclrtResetDevice(device_id);
    }
    if (acl_initialized) {
        aclFinalize();
    }
    return output;
}

/**
 * \brief Sample entry point.
 *
 * Builds a deterministic input tensor (element k = k * 1.2f) and index tensor
 * (entry k = (k * 97 + 13) % input length). It computes the expected gather on the host,
 * runs run_gather, and compares the two results.
 *
 * \return 0 when the device result matches the host reference; 1 when verification fails
 *         or run_gather returns an empty vector.
 */
int32_t main(int32_t argc, char* argv[])
{
    constexpr uint32_t input_total_length = 100000;
    constexpr uint32_t index_total_length = 48 * 256;

    // Source tensor: element k holds k * 1.2f.
    std::vector<float> input(input_total_length);
    uint32_t value_counter = 0;
    std::generate(input.begin(), input.end(), [&]() {
        return static_cast<float>(value_counter++) * 1.2f;
    });

    // Gather positions: a fixed pseudo-scattered pattern that always stays inside input.
    std::vector<int32_t> index(index_total_length);
    uint32_t slot_counter = 0;
    std::generate(index.begin(), index.end(), [&]() {
        const uint32_t slot = slot_counter++;
        return static_cast<int32_t>((slot * 97 + 13) % input_total_length);
    });

    // Host reference result.
    std::vector<float> golden(index_total_length);
    std::transform(index.begin(), index.end(), golden.begin(), [&](int32_t pos) {
        return input[pos];
    });

    std::vector<float> output = run_gather(input, index);
    const bool device_run_failed = output.empty();
    return device_run_failed ? 1 : static_cast<int32_t>(verify_result(output, golden));
}