/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under
 * the terms and conditions of CANN Open Software License Agreement Version 2.0
 * (the "License"). Please refer to the License for details. You may not use
 * this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
 * KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
 * NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the
 * License.
 */
/**
 * mhc_pre blockDim Sweep for Tuning
 */

#include "acl/acl.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>

/**
 * Host-side launcher for the mhc_pre fp32 kernel. It is only declared here;
 * the definition lives in another translation unit (presumably the compiled
 * kernel object — confirm against the build).
 *
 * As used by the callers in this file:
 *   blockDim     number of blocks requested for the launch (the value swept)
 *   stream       an aclrtStream; callers record timing events on it around
 *                the call and synchronize it before reading the timings
 *   input        device buffer of batch * num_streams * seq_len * dim floats
 *   h_pre        device buffer of num_streams floats ("weight" below)
 *   output       device buffer of batch * seq_len * dim floats
 *   batch, seq_len, dim, num_streams  problem shape
 * NOTE(review): returns void, so launch failures are not reported to the
 * caller — confirm whether the launcher reports errors some other way.
 */
extern "C" void mhc_pre_do_fp32(
    uint32_t blockDim, void* stream,
    uint8_t* input, uint8_t* h_pre, uint8_t* output,
    int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams
);

// Evaluate an ACL runtime call exactly once. If it does not return
// ACL_SUCCESS, print the numeric error code together with the file and line
// of the call site, then terminate the process with exit status 1.
// The do { } while (0) wrapper makes the macro behave as a single statement,
// so it is safe after an unbraced if/else.
#define CHECK_ACL(x) do { \
    aclError err = (x); \
    if (err != ACL_SUCCESS) { \
        printf("ACL Error %d at %s:%d\n", err, __FILE__, __LINE__); \
        exit(1); \
    } \
} while(0)

double benchmark_with_blockdim(uint32_t blockDim, int64_t batch, int64_t seq_len,
                                int64_t dim, int64_t num_streams,
                                void* d_input, void* d_weight, void* d_output,
                                aclrtStream stream, int iters = 50) {
    for (int i = 0; i < 5; ++i) {
        mhc_pre_do_fp32(blockDim, stream, (uint8_t*)d_input, (uint8_t*)d_weight,
                        (uint8_t*)d_output, batch, seq_len, dim, num_streams);
    }
    CHECK_ACL(aclrtSynchronizeStream(stream));

    std::vector<double> times(iters);
    for (int i = 0; i < iters; ++i) {
        aclrtEvent start, end;
        CHECK_ACL(aclrtCreateEvent(&start));
        CHECK_ACL(aclrtCreateEvent(&end));
        CHECK_ACL(aclrtRecordEvent(start, stream));
        mhc_pre_do_fp32(blockDim, stream, (uint8_t*)d_input, (uint8_t*)d_weight,
                        (uint8_t*)d_output, batch, seq_len, dim, num_streams);
        CHECK_ACL(aclrtRecordEvent(end, stream));
        CHECK_ACL(aclrtSynchronizeStream(stream));
        float ms;
        CHECK_ACL(aclrtEventElapsedTime(&ms, start, end));
        times[i] = ms * 1000.0;
        CHECK_ACL(aclrtDestroyEvent(start));
        CHECK_ACL(aclrtDestroyEvent(end));
    }
    std::sort(times.begin(), times.end());
    return times[iters / 2];
}

void sweep_blockdim(int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams) {
    printf("\nSweeping blockDim for [B=%ld, S=%ld, D=%ld, N=%ld]\n", batch, seq_len, dim, num_streams);
    printf("Max logical blocks = %ld\n", batch);

    int64_t input_size = batch * num_streams * seq_len * dim;
    int64_t weight_size = num_streams;
    int64_t output_size = batch * seq_len * dim;

    void *d_input, *d_weight, *d_output;
    CHECK_ACL(aclrtMalloc(&d_input, input_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&d_weight, weight_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&d_output, output_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));

    aclrtStream stream;
    CHECK_ACL(aclrtCreateStream(&stream));

    uint32_t max_blocks = static_cast<uint32_t>(batch);
    double best_time = 1e9;
    uint32_t best_blockdim = 0;

    printf("\n%10s %12s %12s\n", "blockDim", "median(us)", "rel");
    printf("%10s %12s %12s\n", "--------", "----------", "---");

    std::vector<uint32_t> test_dims;
    for (uint32_t bd = 1; bd <= max_blocks && bd <= 64; bd *= 2)
        test_dims.push_back(bd);
    if (max_blocks > 0 && std::find(test_dims.begin(), test_dims.end(), max_blocks) == test_dims.end())
        test_dims.push_back(max_blocks);
    std::sort(test_dims.begin(), test_dims.end());

    for (uint32_t bd : test_dims) {
        double t = benchmark_with_blockdim(bd, batch, seq_len, dim, num_streams,
                                           d_input, d_weight, d_output, stream);
        if (t < best_time) {
            best_time = t;
            best_blockdim = bd;
        }
        printf("%10u %12.1f %11.2fx\n", bd, t, t / best_time);
    }

    printf("\nBest: blockDim=%u (%.1f us)\n", best_blockdim, best_time);

    CHECK_ACL(aclrtFree(d_input));
    CHECK_ACL(aclrtFree(d_weight));
    CHECK_ACL(aclrtFree(d_output));
    CHECK_ACL(aclrtDestroyStream(stream));
}

int main() {
    // Device on which the whole sweep runs; it is set up and reset exactly once.
    const int32_t device_id = 0;

    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(device_id));

    printf("=== mhc_pre blockDim Sweep ===\n");

    // Problem shapes to sweep, swept in this order.
    struct SweepShape {
        int64_t batch;
        int64_t seq_len;
        int64_t dim;
        int64_t num_streams;
    };
    const SweepShape shapes[] = {
        {8, 128, 512, 4},
        {16, 256, 1024, 4},
        {32, 128, 512, 4},
    };
    for (const SweepShape& shape : shapes) {
        sweep_blockdim(shape.batch, shape.seq_len, shape.dim, shape.num_streams);
    }

    printf("\n=== Sweep Complete ===\n");

    // Tear down in reverse order of initialisation.
    CHECK_ACL(aclrtResetDevice(device_id));
    CHECK_ACL(aclFinalize());
    return 0;
}