/*
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under
* the terms and conditions of CANN Open Software License Agreement Version 2.0
* (the "License"). Please refer to the License for details. You may not use
* this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY
* KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
* NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the
* License.
*/
/*
* mhc_post Edge Cases Test (fp32/fp16/bf16)
* Tests unaligned dim, extreme batch/seq/streams, minimal cases
*
* Precision criteria:
* fp32: bit-exact (ULP = 0)
* fp16: allclose(atol=1e-4, rtol=1e-3)
* bf16: allclose(atol=1e-3, rtol=4e-3)
*/
#include "acl/acl.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>
#include <cstring>
#include <random>
// Reinterprets the object representation of `src` as a value of type `To`.
// The two types must have the same size (checked at compile time). The bytes
// are moved with std::memcpy, which sidesteps strict-aliasing problems that a
// pointer cast would introduce.
template <typename To, typename From>
inline To bit_copy(const From& src) {
    static_assert(sizeof(To) == sizeof(From), "size mismatch");
    To reinterpreted;
    std::memcpy(&reinterpreted, &src, sizeof(reinterpreted));
    return reinterpreted;
}
// Kernel launch entry points for the mhc_post operator, one per element type.
// They are only declared here (C linkage); the definitions live outside this
// file.
// NOTE(review): the parameter meanings below are inferred from how the tests
// in this file call these functions — confirm against the kernel source.
//   blockDim            launch width; the tests pass batch * num_streams
//   stream              the aclrtStream the tests later synchronize on
//   in / h / out        device buffers: input, per-stream weights, output
//   batch, seq_len, dim, num_streams   problem dimensions
extern "C" void mhc_post_do_fp32(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams);
extern "C" void mhc_post_do_fp16(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h, uint8_t* out,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams);
// bf16 variant: the weight buffer is named h_fp32, and the bf16 test in this
// file uploads the weights as 4-byte floats rather than as bf16.
extern "C" void mhc_post_do_bf16(
uint32_t blockDim, void* stream, uint8_t* in, uint8_t* h_fp32, uint8_t* out,
int64_t batch, int64_t seq_len, int64_t dim, int64_t num_streams);
// Evaluates the ACL call `x` exactly once. If the result is not ACL_SUCCESS,
// prints the numeric error code together with the file and line of the call
// site and terminates the process with exit status 1. The do { ... } while(0)
// wrapper makes the macro usable as a single statement (e.g. inside an
// un-braced if/else). Note that the macro declares a local named `err`, so the
// argument expression must not itself refer to a variable called `err`.
#define CHECK_ACL(x) do { \
aclError err = (x); \
if (err != ACL_SUCCESS) { \
printf("ACL Error %d at %s:%d\n", err, __FILE__, __LINE__); \
exit(1); \
} \
} while(0)
// Converts an IEEE-754 binary32 value to the bit pattern of the nearest
// IEEE-754 binary16 (half) value.
//
//  - Rounding is round-to-nearest, ties-to-even (the previous implementation
//    truncated the mantissa, biasing every value toward zero).
//  - NaN stays NaN: the quiet bit is forced so a payload that lives only in
//    the low mantissa bits cannot collapse into the Inf encoding.
//  - Magnitudes too large for half (including values that round up past
//    65504) become +/-Inf.
//  - Magnitudes below the smallest normal half are encoded as half
//    subnormals; anything that rounds below the smallest subnormal becomes a
//    signed zero.
//
// Returns the raw 16-bit encoding (sign | exponent | mantissa).
uint16_t float_to_half(float f) {
    const uint32_t bits = bit_copy<uint32_t>(f);
    const uint32_t sign = (bits >> 16) & 0x8000u;
    const uint32_t magnitude = bits & 0x7FFFFFFFu;

    // NaN: exponent all ones and a non-zero mantissa.
    if (magnitude > 0x7F800000u) {
        return static_cast<uint16_t>(sign | 0x7E00u | ((magnitude >> 13) & 0x03FFu));
    }

    // Re-bias the exponent from binary32 (127) to binary16 (15).
    const int32_t exp = static_cast<int32_t>(magnitude >> 23) - 127 + 15;

    // Infinity and finite values beyond the half range.
    if (exp >= 31) {
        return static_cast<uint16_t>(sign | 0x7C00u);
    }

    if (exp <= 0) {
        // Below 2^-25 the value is closer to zero than to the smallest
        // subnormal (2^-24), so it rounds to a signed zero.
        if (exp < -10) {
            return static_cast<uint16_t>(sign);
        }
        // Subnormal half: shift the 24-bit significand (implicit 1 restored)
        // down to units of 2^-24 and round the discarded bits.
        const uint32_t significand = (magnitude & 0x007FFFFFu) | 0x00800000u;
        const uint32_t shift = static_cast<uint32_t>(14 - exp);  // 14..24
        uint32_t half_mant = significand >> shift;
        const uint32_t dropped = significand & ((1u << shift) - 1u);
        const uint32_t tie = 1u << (shift - 1u);
        if (dropped > tie || (dropped == tie && (half_mant & 1u) != 0u)) {
            ++half_mant;  // a carry to 0x400 is exactly the smallest normal half
        }
        return static_cast<uint16_t>(sign | half_mant);
    }

    // Normal half: keep the top 10 mantissa bits and round on the 13 dropped.
    uint32_t packed = (static_cast<uint32_t>(exp) << 10) | ((magnitude & 0x007FFFFFu) >> 13);
    const uint32_t dropped = magnitude & 0x1FFFu;
    if (dropped > 0x1000u || (dropped == 0x1000u && (packed & 1u) != 0u)) {
        ++packed;  // a mantissa carry bumps the exponent; 0x7BFF + 1 == Inf
    }
    return static_cast<uint16_t>(sign | packed);
}
// Decodes an IEEE-754 binary16 (half) bit pattern into the binary32 value it
// represents. Every half value is exactly representable as a float, so the
// conversion is lossless.
//
//  - exp == 0, mant == 0 : signed zero.
//  - exp == 0, mant != 0 : subnormal half (mant * 2^-24). These were previously
//    decoded as zero; they are now normalised into a regular float.
//  - exp == 31           : Inf (mant == 0) or NaN (payload shifted into place).
//  - otherwise           : normal number, exponent re-biased from 15 to 127.
float half_to_float(uint16_t h) {
    // Widen before shifting so the sign bit is never shifted into a signed int.
    const uint32_t sign = (static_cast<uint32_t>(h) & 0x8000u) << 16;
    const uint32_t exp = (static_cast<uint32_t>(h) >> 10) & 0x1Fu;
    uint32_t mant = static_cast<uint32_t>(h) & 0x3FFu;

    uint32_t result;
    if (exp == 0u) {
        if (mant == 0u) {
            result = sign;
        } else {
            // Shift the mantissa left until its leading 1 reaches the implicit
            // bit position, lowering the exponent once per shift.
            uint32_t float_exp = 127u - 15u + 1u;
            while ((mant & 0x400u) == 0u) {
                mant <<= 1;
                --float_exp;
            }
            mant &= 0x3FFu;  // drop the now-implicit leading 1
            result = sign | (float_exp << 23) | (mant << 13);
        }
    } else if (exp == 31u) {
        result = sign | 0x7F800000u | (mant << 13);
    } else {
        result = sign | ((exp + (127u - 15u)) << 23) | (mant << 13);
    }
    return bit_copy<float>(result);
}
// Converts an IEEE-754 binary32 value to the bit pattern of the nearest
// bfloat16 value (the upper 16 bits of the float layout).
//
// Rounding is round-to-nearest, ties-to-even, instead of plain truncation.
// NaN inputs are handled separately: truncating a NaN whose payload sits only
// in the low 16 bits would produce the Inf encoding, so the quiet bit is
// forced to keep the result a NaN. Finite values that round past the largest
// bfloat16 become +/-Inf.
uint16_t float_to_bf16(float f) {
    const uint32_t bits = bit_copy<uint32_t>(f);

    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Adding 0x7FFF plus the current LSB of the kept part rounds to nearest,
    // with exact ties going to the even neighbour. The sum cannot overflow
    // 32 bits because NaNs were filtered out above.
    const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
// Expands a bfloat16 bit pattern to binary32: the 16 bits become the upper
// half of the float's representation and the lower 16 bits are zero.
float bf16_to_float(uint16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    return bit_copy<float>(widened);
}
// Host reference for mhc_post, computed in fp32.
//
// `in` holds batch * seq * dim values and `h` holds one weight per stream.
// For every batch entry and every stream, the batch entry's seq * dim slice is
// scaled by that stream's weight and written to `out`, which is laid out as
// [batch][streams][seq * dim].
void cpu_ref(const float* in, const float* h, float* out,
             int64_t batch, int64_t seq, int64_t dim, int64_t streams) {
    const int64_t slice = seq * dim;
    for (int64_t bi = 0; bi < batch; ++bi) {
        const int64_t src_base = bi * slice;
        for (int64_t si = 0; si < streams; ++si) {
            const int64_t dst_base = (bi * streams + si) * slice;
            for (int64_t k = 0; k < slice; ++k) {
                out[dst_base + k] = in[src_base + k] * h[si];
            }
        }
    }
}
// Element-wise tolerance check: returns true when, for every i in [0, n),
//     |a[i] - b[i]| <= atol + rtol * |b[i]|
// with `b` acting as the reference.
//
// Non-finite values are handled explicitly so a broken kernel cannot slip
// through: a NaN on either side fails the check, and an infinity only matches
// an infinity of the same sign. (With the plain `err > tol` test, NaN
// comparisons are always false, so NaN output used to be reported as a pass,
// and +Inf vs -Inf passed because `inf > inf` is false.)
// An empty range (n <= 0) is trivially close.
bool allclose(const float* a, const float* b, int64_t n, float atol, float rtol) {
    for (int64_t i = 0; i < n; ++i) {
        const float actual = a[i];
        const float expected = b[i];

        // Exact equality covers identical infinities, whose difference is NaN.
        if (actual == expected) continue;

        if (!std::isfinite(actual) || !std::isfinite(expected)) return false;

        const float err = std::abs(actual - expected);
        const float tol = atol + rtol * std::abs(expected);
        if (!(err <= tol)) return false;
    }
    return true;
}
// Returns true when the first `n` floats of `a` and `b` have identical bit
// patterns. This is stricter than operator==: +0.0 and -0.0 differ, and two
// NaNs match only if their encodings are the same.
bool bit_exact(const float* a, const float* b, int64_t n) {
    int64_t idx = 0;
    while (idx < n) {
        const uint32_t lhs_bits = bit_copy<uint32_t>(a[idx]);
        const uint32_t rhs_bits = bit_copy<uint32_t>(b[idx]);
        if (lhs_bits != rhs_bits) {
            return false;
        }
        ++idx;
    }
    return true;
}
bool test_fp32(const char* name, int64_t batch, int64_t seq, int64_t dim, int64_t streams) {
int64_t in_sz = batch * seq * dim;
int64_t out_sz = batch * streams * seq * dim;
std::vector<float> h_in(in_sz), h_w(streams), h_out(out_sz), h_ref(out_sz);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
std::uniform_real_distribution<float> dist_pos(0.1f, 1.1f);
for (int64_t i = 0; i < in_sz; ++i) h_in[i] = dist(rng);
float sum = 0;
for (int64_t i = 0; i < streams; ++i) { h_w[i] = dist_pos(rng); sum += h_w[i]; }
for (int64_t i = 0; i < streams; ++i) h_w[i] /= sum;
cpu_ref(h_in.data(), h_w.data(), h_ref.data(), batch, seq, dim, streams);
void *d_in, *d_w, *d_out;
CHECK_ACL(aclrtMalloc(&d_in, in_sz * 4, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_w, streams * 4, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_out, out_sz * 4, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(d_in, in_sz * 4, h_in.data(), in_sz * 4, ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_ACL(aclrtMemcpy(d_w, streams * 4, h_w.data(), streams * 4, ACL_MEMCPY_HOST_TO_DEVICE));
aclrtStream stream; CHECK_ACL(aclrtCreateStream(&stream));
mhc_post_do_fp32(batch * streams, stream, (uint8_t*)d_in, (uint8_t*)d_w, (uint8_t*)d_out, batch, seq, dim, streams);
CHECK_ACL(aclrtSynchronizeStream(stream));
CHECK_ACL(aclrtMemcpy(h_out.data(), out_sz * 4, d_out, out_sz * 4, ACL_MEMCPY_DEVICE_TO_HOST));
bool pass = bit_exact(h_out.data(), h_ref.data(), out_sz);
printf(" [fp32] %s: %s\n", name, pass ? "PASS" : "FAIL");
CHECK_ACL(aclrtFree(d_in)); CHECK_ACL(aclrtFree(d_w)); CHECK_ACL(aclrtFree(d_out));
CHECK_ACL(aclrtDestroyStream(stream));
return pass;
}
bool test_fp16(const char* name, int64_t batch, int64_t seq, int64_t dim, int64_t streams) {
int64_t in_sz = batch * seq * dim;
int64_t out_sz = batch * streams * seq * dim;
std::vector<float> h_in_f(in_sz), h_w_f(streams), h_ref(out_sz), h_out_f(out_sz);
std::vector<uint16_t> h_in(in_sz), h_w(streams), h_out(out_sz);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
std::uniform_real_distribution<float> dist_pos(0.1f, 1.1f);
for (int64_t i = 0; i < in_sz; ++i) {
h_in_f[i] = dist(rng);
h_in[i] = float_to_half(h_in_f[i]);
h_in_f[i] = half_to_float(h_in[i]);
}
float sum = 0;
for (int64_t i = 0; i < streams; ++i) { h_w_f[i] = dist_pos(rng); sum += h_w_f[i]; }
for (int64_t i = 0; i < streams; ++i) {
h_w_f[i] /= sum;
h_w[i] = float_to_half(h_w_f[i]);
h_w_f[i] = half_to_float(h_w[i]);
}
cpu_ref(h_in_f.data(), h_w_f.data(), h_ref.data(), batch, seq, dim, streams);
void *d_in, *d_w, *d_out;
CHECK_ACL(aclrtMalloc(&d_in, in_sz * 2, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_w, streams * 2, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_out, out_sz * 2, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(d_in, in_sz * 2, h_in.data(), in_sz * 2, ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_ACL(aclrtMemcpy(d_w, streams * 2, h_w.data(), streams * 2, ACL_MEMCPY_HOST_TO_DEVICE));
aclrtStream stream; CHECK_ACL(aclrtCreateStream(&stream));
mhc_post_do_fp16(batch * streams, stream, (uint8_t*)d_in, (uint8_t*)d_w, (uint8_t*)d_out, batch, seq, dim, streams);
CHECK_ACL(aclrtSynchronizeStream(stream));
CHECK_ACL(aclrtMemcpy(h_out.data(), out_sz * 2, d_out, out_sz * 2, ACL_MEMCPY_DEVICE_TO_HOST));
for (int64_t i = 0; i < out_sz; ++i) h_out_f[i] = half_to_float(h_out[i]);
bool pass = allclose(h_out_f.data(), h_ref.data(), out_sz, 1e-4f, 1e-3f);
printf(" [fp16] %s: %s\n", name, pass ? "PASS" : "FAIL");
CHECK_ACL(aclrtFree(d_in)); CHECK_ACL(aclrtFree(d_w)); CHECK_ACL(aclrtFree(d_out));
CHECK_ACL(aclrtDestroyStream(stream));
return pass;
}
bool test_bf16(const char* name, int64_t batch, int64_t seq, int64_t dim, int64_t streams) {
int64_t in_sz = batch * seq * dim;
int64_t out_sz = batch * streams * seq * dim;
std::vector<float> h_in_f(in_sz), h_w_f(streams), h_ref(out_sz), h_out_f(out_sz);
std::vector<uint16_t> h_in(in_sz), h_out(out_sz);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
std::uniform_real_distribution<float> dist_pos(0.1f, 1.1f);
for (int64_t i = 0; i < in_sz; ++i) {
h_in_f[i] = dist(rng);
h_in[i] = float_to_bf16(h_in_f[i]);
h_in_f[i] = bf16_to_float(h_in[i]);
}
float sum = 0;
for (int64_t i = 0; i < streams; ++i) { h_w_f[i] = dist_pos(rng); sum += h_w_f[i]; }
for (int64_t i = 0; i < streams; ++i) h_w_f[i] /= sum;
cpu_ref(h_in_f.data(), h_w_f.data(), h_ref.data(), batch, seq, dim, streams);
void *d_in, *d_w, *d_out;
CHECK_ACL(aclrtMalloc(&d_in, in_sz * 2, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_w, streams * 4, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMalloc(&d_out, out_sz * 2, ACL_MEM_MALLOC_HUGE_FIRST));
CHECK_ACL(aclrtMemcpy(d_in, in_sz * 2, h_in.data(), in_sz * 2, ACL_MEMCPY_HOST_TO_DEVICE));
CHECK_ACL(aclrtMemcpy(d_w, streams * 4, h_w_f.data(), streams * 4, ACL_MEMCPY_HOST_TO_DEVICE));
aclrtStream stream; CHECK_ACL(aclrtCreateStream(&stream));
mhc_post_do_bf16(batch * streams, stream, (uint8_t*)d_in, (uint8_t*)d_w, (uint8_t*)d_out, batch, seq, dim, streams);
CHECK_ACL(aclrtSynchronizeStream(stream));
CHECK_ACL(aclrtMemcpy(h_out.data(), out_sz * 2, d_out, out_sz * 2, ACL_MEMCPY_DEVICE_TO_HOST));
for (int64_t i = 0; i < out_sz; ++i) h_out_f[i] = bf16_to_float(h_out[i]);
bool pass = allclose(h_out_f.data(), h_ref.data(), out_sz, 1e-3f, 4e-3f);
printf(" [bf16] %s: %s\n", name, pass ? "PASS" : "FAIL");
CHECK_ACL(aclrtFree(d_in)); CHECK_ACL(aclrtFree(d_w)); CHECK_ACL(aclrtFree(d_out));
CHECK_ACL(aclrtDestroyStream(stream));
return pass;
}
// Runs one shape through all three precision variants (fp32, fp16, bf16).
//
//   name     label printed in the case header and in each PASS/FAIL line
//   b, s, d  batch, sequence length and hidden dimension
//   n        number of streams
//
// All three variants always run, even after an earlier one fails, so a single
// invocation reports every failing precision. Returns true only if all pass.
bool test_case(const char* name, int64_t b, int64_t s, int64_t d, int64_t n) {
    // int64_t is `long` on some ABIs and `long long` on others, so "%ld" is a
    // format/argument mismatch on the latter. Casting to long long and using
    // "%lld" prints the same text everywhere.
    printf("\n[%s] batch=%lld seq=%lld dim=%lld streams=%lld\n", name,
           static_cast<long long>(b), static_cast<long long>(s),
           static_cast<long long>(d), static_cast<long long>(n));
    bool pass = true;
    pass &= test_fp32(name, b, s, d, n);
    pass &= test_fp16(name, b, s, d, n);
    pass &= test_bf16(name, b, s, d, n);
    return pass;
}
// Entry point: initialises ACL on device 0, runs every edge-case shape through
// test_case, prints an overall verdict, and tears the runtime down again.
// Exit status is 0 when every case passed and 1 otherwise (ACL failures exit
// with 1 directly from CHECK_ACL).
int main() {
    // One row per edge case: label, batch, seq, dim, streams.
    struct EdgeCase {
        const char* label;
        int64_t batch;
        int64_t seq;
        int64_t dim;
        int64_t streams;
    };
    static const EdgeCase kCases[] = {
        {"dim=7", 2, 16, 7, 4},
        {"dim=13", 2, 16, 13, 4},
        {"dim=63", 2, 16, 63, 4},
        {"dim=127", 2, 16, 127, 4},
        {"dim=1", 2, 16, 1, 4},
        {"batch=64", 64, 8, 64, 4},
        {"seq=1", 4, 1, 128, 4},
        {"streams=1", 4, 32, 64, 1},
        {"minimal", 1, 1, 1, 1},
        {"combo", 1, 1, 7, 1},
    };

    CHECK_ACL(aclInit(nullptr));
    CHECK_ACL(aclrtSetDevice(0));
    printf("=== mhc_post Edge Cases Test ===\n");
    printf("Precision: fp32=bit-exact, fp16=allclose(1e-4,1e-3), bf16=allclose(1e-3,4e-3)\n");

    bool all_pass = true;
    for (const EdgeCase& c : kCases) {
        // Every case runs regardless of earlier failures.
        if (!test_case(c.label, c.batch, c.seq, c.dim, c.streams)) {
            all_pass = false;
        }
    }

    printf("\n=== Final: %s ===\n", all_pass ? "ALL PASS" : "SOME FAILED");
    CHECK_ACL(aclrtResetDevice(0));
    CHECK_ACL(aclFinalize());
    return all_pass ? 0 : 1;
}