/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under
 * the terms and conditions of CANN Open Software License Agreement Version 2.0
 * (the "License"). Please refer to the License for details. You may not use
 * this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
 * KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
 * NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the
 * License.
 */
/**
 * mhc_pre Performance Benchmark
 */
#include "acl/acl.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <algorithm>
// Host-side entry point for the fp32 mhc_pre kernel. Declared here with C
// linkage; the definition lives outside this file.
//
// As called by benchmark_fp32 in this file:
//   - stream  : an aclrtStream the call is issued on
//   - input   : device buffer of batch * num_streams * seq_len * dim floats
//   - h_pre   : device buffer of num_streams floats
//   - output  : device buffer of batch * seq_len * dim floats
//   - blockDim: passed as 0 by the benchmark
// NOTE(review): presumably the launcher chooses the real block count itself
// when blockDim is 0 — confirm against the kernel wrapper.
extern "C" void mhc_pre_do_fp32(
uint32_t blockDim, void* stream,
uint8_t* input, uint8_t* h_pre, uint8_t* output,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams
);
// Host-side entry point with the same parameter list as mhc_pre_do_fp32.
// Declared with C linkage and defined outside this file; it is not called
// anywhere in the visible part of this benchmark.
// NOTE(review): presumably the half-precision (fp16) variant, with buffers
// holding fp16 elements — confirm element types against the kernel wrapper.
extern "C" void mhc_pre_do_fp16(
uint32_t blockDim, void* stream,
uint8_t* input, uint8_t* h_pre, uint8_t* output,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams
);
// Host-side entry point declared with C linkage and defined outside this
// file; it is not called anywhere in the visible part of this benchmark.
// NOTE(review): presumably the bfloat16 variant. The weight parameter is
// named h_pre_fp32, which suggests the weights stay in fp32 while input and
// output are bf16 — confirm against the kernel wrapper.
extern "C" void mhc_pre_do_bf16(
uint32_t blockDim, void* stream,
uint8_t* input, uint8_t* h_pre_fp32, uint8_t* output,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams
);
/*
 * CHECK_ACL(call): evaluates `call` exactly once. If the returned status is
 * anything other than ACL_SUCCESS, prints the numeric status together with the
 * source file and line of the call site, then terminates the process with
 * exit(1). Wrapped in do { ... } while (0) so it behaves as a single statement.
 */
#define CHECK_ACL(x)                                                    \
    do {                                                                \
        const aclError acl_status_ = (x);                               \
        if (acl_status_ != ACL_SUCCESS) {                               \
            printf("ACL Error %d at %s:%d\n",                           \
                   acl_status_, __FILE__, __LINE__);                    \
            exit(1);                                                    \
        }                                                               \
    } while (0)
// Summary statistics for one benchmark configuration, as produced by
// benchmark_fp32 from its sorted per-iteration timings.
struct BenchResult {
// Middle element (index iters / 2) of the sorted timings, in microseconds.
double median_us;
// Smallest measured iteration time, in microseconds.
double min_us;
// Largest measured iteration time, in microseconds.
double max_us;
// Bytes of input + output + weight buffers divided by the median time,
// expressed in units of 1e9 bytes per second.
double gbps;
};
/**
 * Reduces raw per-iteration timings to the reported statistics.
 *
 * Sorts `times_us` in place, then takes the element at index size()/2 as the
 * median and the first/last elements as min/max. Throughput is
 * `bytes_moved` divided by the median time, in units of 1e9 bytes per second.
 *
 * @param times_us     per-iteration timings in microseconds; must be non-empty
 * @param bytes_moved  number of bytes attributed to a single iteration
 */
static BenchResult summarize_times(std::vector<double>& times_us, double bytes_moved) {
    std::sort(times_us.begin(), times_us.end());
    const double median = times_us[times_us.size() / 2];
    const double gbps = bytes_moved / (median * 1e-6) / 1e9;
    return {median, times_us.front(), times_us.back(), gbps};
}

/**
 * Times mhc_pre_do_fp32 for one problem size using ACL events.
 *
 * Allocates device buffers sized for fp32 elements (input:
 * batch * num_streams * seq_len * dim, weight: num_streams, output:
 * batch * seq_len * dim), runs `warmup` untimed launches, then `iters`
 * launches that are each bracketed by a start/end event pair. The buffers are
 * never written by the host, so the kernel runs on whatever the allocation
 * contains.
 *
 * @param batch, seq_len, dim, num_streams  problem dimensions forwarded to the kernel
 * @param warmup  number of untimed launches before measuring (default 10)
 * @param iters   number of timed launches (default 100)
 * @return median/min/max time in microseconds and the derived GB/s figure.
 *         If `iters` is not positive there is nothing to measure: an all-zero
 *         result is returned and no ACL calls are made.
 *
 * Any failing ACL call terminates the process via CHECK_ACL.
 */
BenchResult benchmark_fp32(int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams,
                           int warmup = 10, int iters = 100) {
    // Without at least one sample the statistics below would index an empty
    // vector, so bail out before touching the device.
    if (iters <= 0) {
        return {0.0, 0.0, 0.0, 0.0};
    }

    const int64_t input_size = batch * num_streams * seq_len * dim;
    const int64_t weight_size = num_streams;
    const int64_t output_size = batch * seq_len * dim;

    void* d_input = nullptr;
    void* d_weight = nullptr;
    void* d_output = nullptr;
    CHECK_ACL(aclrtMalloc(&d_input, input_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&d_weight, weight_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc(&d_output, output_size * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));

    aclrtStream stream;
    CHECK_ACL(aclrtCreateStream(&stream));

    uint8_t* const input_bytes = static_cast<uint8_t*>(d_input);
    uint8_t* const weight_bytes = static_cast<uint8_t*>(d_weight);
    uint8_t* const output_bytes = static_cast<uint8_t*>(d_output);

    // NOTE(review): blockDim is passed as 0; presumably the launcher picks the
    // real block count itself — confirm against the kernel wrapper.
    for (int i = 0; i < warmup; ++i) {
        mhc_pre_do_fp32(0, stream, input_bytes, weight_bytes, output_bytes,
                        batch, seq_len, dim, num_streams);
    }
    // Drain the warm-up launches so they cannot bleed into the first sample.
    CHECK_ACL(aclrtSynchronizeStream(stream));

    std::vector<double> times(static_cast<size_t>(iters));
    for (int i = 0; i < iters; ++i) {
        aclrtEvent start;
        aclrtEvent end;
        CHECK_ACL(aclrtCreateEvent(&start));
        CHECK_ACL(aclrtCreateEvent(&end));

        CHECK_ACL(aclrtRecordEvent(start, stream));
        mhc_pre_do_fp32(0, stream, input_bytes, weight_bytes, output_bytes,
                        batch, seq_len, dim, num_streams);
        CHECK_ACL(aclrtRecordEvent(end, stream));
        CHECK_ACL(aclrtSynchronizeStream(stream));

        float ms = 0.0f;
        CHECK_ACL(aclrtEventElapsedTime(&ms, start, end));
        times[static_cast<size_t>(i)] = ms * 1000.0;  // milliseconds -> microseconds

        CHECK_ACL(aclrtDestroyEvent(start));
        CHECK_ACL(aclrtDestroyEvent(end));
    }

    CHECK_ACL(aclrtFree(d_input));
    CHECK_ACL(aclrtFree(d_weight));
    CHECK_ACL(aclrtFree(d_output));
    CHECK_ACL(aclrtDestroyStream(stream));

    // Traffic model: every input, weight and output element counted once.
    const double bytes = (input_size + output_size) * sizeof(float) + weight_size * sizeof(float);
    return summarize_times(times, bytes);
}
void run_benchmark(const char* name, int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams) {
auto r = benchmark_fp32(batch, seq_len, dim, num_streams);
printf("%-20s B=%2ld S=%4ld D=%4ld N=%ld median=%.1f us [%.1f-%.1f] %.1f GB/s\n",
name, batch, seq_len, dim, num_streams, r.median_us, r.min_us, r.max_us, r.gbps);
}
/**
 * Benchmark driver: initialises ACL on device 0, runs the fp32 benchmark for
 * a fixed list of problem sizes, then resets the device and finalises ACL.
 * Any failing ACL call terminates the process via CHECK_ACL.
 */
int main() {
    // One row per benchmarked configuration, executed in the listed order.
    struct Case {
        const char* label;
        int64_t batch;
        int64_t seq_len;
        int64_t dim;
        int64_t num_streams;
    };
    static const Case kCases[] = {
        {"small",      4,  64,  256, 4},
        {"medium",     8, 128,  512, 4},
        {"large",     16, 256, 1024, 4},
        {"xlarge",    32, 512, 2048, 4},
        {"streams=8",  8, 128,  512, 8},
        {"streams=2",  8, 128,  512, 2},
    };

    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(0));

    printf("=== mhc_pre Performance Benchmark (fp32) ===\n\n");
    for (const Case& c : kCases) {
        run_benchmark(c.label, c.batch, c.seq_len, c.dim, c.num_streams);
    }
    printf("\n=== Benchmark Complete ===\n");

    CHECK_ACL(aclrtResetDevice(0));
    CHECK_ACL(aclFinalize());
    return 0;
}