/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/**
* NOTE: Portions of this code were AI-generated and have been
* technically reviewed for functional accuracy and security.
*/
/**
* \file test_aclnn_ndtri.cpp
* \brief Comprehensive aclnn test for Ndtri operator (Ascend950 / arch35).
*
* Ndtri is the inverse of the standard normal CDF (a.k.a. probit / quantile fn):
* y = ndtri(p), p in (0, 1)
*
* Verification: roundtrip Phi(ndtri(p)) ≈ p, where
* Phi(z) = 0.5 * (1 + erf(z / sqrt(2)))
*
* Test coverage:
* - Shapes: small (7), aligned (1024), 2D (32x32), large (100000), non-aligned (3x5)
* - Dtypes: float32, float16
* - Probabilities: near-boundary (0.001/0.999), small/medium/large region
*
* Usage:
* export ASCEND_RT_VISIBLE_DEVICES=7
* source opp/vendors/custom_math/bin/set_env.bash
* ./test_aclnn_ndtri
*/
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <random>
#include <string>
#include <vector>
#include "acl/acl.h"
#include "aclnn_ndtri.h"
// Evaluates `expr` exactly once. If the result differs from ACL_SUCCESS, prints
// a "[FAIL]" line carrying `msg` and the numeric status, then makes the
// enclosing function return -1. Intended for functions that return int.
#define CHECK_ACL(expr, msg) \
    do { \
        const auto aclStatus_ = (expr); \
        if (aclStatus_ == ACL_SUCCESS) { \
            break; \
        } \
        printf("[FAIL] %s — ACL error %d\n", (msg), static_cast<int>(aclStatus_)); \
        return -1; \
    } while (0)
// Returns the number of elements described by shape `s`, i.e. the product of
// all of its dimensions. An empty shape yields 1.
static int64_t ShapeSize(const std::vector<int64_t>& s)
{
    int64_t elementCount = 1;
    for (size_t axis = 0; axis < s.size(); ++axis) {
        elementCount *= s[axis];
    }
    return elementCount;
}
// Converts an IEEE-754 binary32 value to binary16 bits.
//
// Rounds to nearest, ties to even (the previous version truncated, biasing
// every input toward zero by up to one half-ULP more than necessary).
// Special cases:
//   - NaN stays NaN (quiet bit forced so the payload can never become zero,
//     which would otherwise read back as infinity).
//   - +/-Inf and finite values too large for binary16 become +/-Inf.
//   - Values below the smallest normal half are encoded as subnormals
//     instead of being flushed to zero; magnitudes <= 2^-25 round to +/-0.
static uint16_t FloatToHalf(float v)
{
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));

    const uint16_t sign = static_cast<uint16_t>((bits >> 16) & 0x8000u);
    const uint32_t magnitude = bits & 0x7FFFFFFFu;

    // Inf / NaN: float exponent field is all ones.
    if (magnitude >= 0x7F800000u) {
        if (magnitude > 0x7F800000u) {
            const uint32_t payload = (magnitude >> 13) & 0x03FFu;
            return static_cast<uint16_t>(sign | 0x7E00u | payload);
        }
        return static_cast<uint16_t>(sign | 0x7C00u);
    }

    // 0x477FF000 is 65520.0f, the midpoint between the largest finite half
    // (65504) and the next power of two; it and everything above round to Inf.
    if (magnitude >= 0x477FF000u) {
        return static_cast<uint16_t>(sign | 0x7C00u);
    }

    // 0x38800000 is 2^-14, the smallest normal half.
    if (magnitude >= 0x38800000u) {
        // Re-bias the exponent (127 -> 15) while leaving the mantissa in place.
        const uint32_t rebased = magnitude - 0x38000000u;
        // Add just under half an ULP plus the current LSB: this rounds to
        // nearest-even, and a mantissa carry bumps the exponent naturally.
        const uint32_t rounded = rebased + 0x0FFFu + ((rebased >> 13) & 1u);
        return static_cast<uint16_t>(sign | (rounded >> 13));
    }

    // 0x33000000 is 2^-25: anything strictly below rounds to zero.
    if (magnitude < 0x33000000u) {
        return sign;
    }

    // Subnormal half: result = round(|v| / 2^-24). With a 24-bit significand
    // and float exponent field e in [102, 112], that is significand >> (126 - e).
    const uint32_t expField = magnitude >> 23;
    const uint32_t significand = (magnitude & 0x007FFFFFu) | 0x00800000u;
    const uint32_t shift = 126u - expField;  // in [14, 24]
    uint32_t halfMant = significand >> shift;
    const uint32_t remainder = significand & ((1u << shift) - 1u);
    const uint32_t halfway = 1u << (shift - 1u);
    if (remainder > halfway || (remainder == halfway && (halfMant & 1u) != 0u)) {
        ++halfMant;  // may reach 0x0400, which is exactly the smallest normal
    }
    return static_cast<uint16_t>(sign | halfMant);
}
// Expands IEEE-754 binary16 bits into the float with the same value.
// Zeros keep their sign, subnormals are normalised, and Inf/NaN map to the
// float exponent of all ones with the mantissa shifted into place.
static float HalfToFloat(uint16_t h)
{
    const uint32_t signBit = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t fraction = h & 0x03FFu;
    int32_t exponent = (h >> 10) & 0x1F;

    uint32_t word;
    if (exponent == 31) {
        // Infinity or NaN.
        word = signBit | 0x7F800000u | (fraction << 13);
    } else if (exponent == 0 && fraction == 0) {
        // Signed zero.
        word = signBit;
    } else {
        if (exponent == 0) {
            // Subnormal: shift the leading one up to the implicit-bit position,
            // lowering the effective exponent once per shift, then drop it.
            exponent = 1;
            while ((fraction & 0x0400u) == 0) {
                fraction <<= 1;
                --exponent;
            }
            fraction &= 0x03FFu;
        }
        // Re-bias from 15 to 127.
        word = signBit | (static_cast<uint32_t>(exponent + 112) << 23) | (fraction << 13);
    }

    float result;
    std::memcpy(&result, &word, sizeof(result));
    return result;
}
// Standard normal CDF in single precision: Phi(z) = 0.5 * (1 + erf(z / sqrt(2))).
static float NormalCdf(float z)
{
    const float scaled = z / std::sqrt(2.0f);
    const float erfTerm = std::erf(scaled);
    return 0.5f * (1.0f + erfTerm);
}
// Describes one Ndtri scenario handed to RunTestCase. Instances are built by
// aggregate initialisation, so the field order below is part of the interface.
struct TestCase {
// Label printed in the [PASS]/[FAIL] report line.
std::string name;
// Tensor dimensions; the element count is the product of the entries.
std::vector<int64_t> shape;
// Element type. RunTestCase treats ACL_FLOAT16 as 2-byte half values and
// every other value as 4-byte float.
aclDataType dtype;
// Maximum allowed |Phi(ndtri(p)) - p| for each element.
float atol;
};
// Runs one Ndtri scenario end to end and verifies it by round trip.
//
// Steps: fill the input tensor by cycling through `values`, copy it to the
// device, run aclnnNdtri on `stream`, copy the result back, and check that
// Phi(ndtri(p)) is within tc.atol of p for every element.
//
// Parameters:
//   stream - stream the operator is launched on and synchronised against.
//   tc     - name, shape, dtype and tolerance of the scenario.
//   values - probabilities to feed; reused cyclically if shorter than the tensor.
//
// Returns 0 when every element is within tolerance, -1 on any ACL error,
// invalid input, or tolerance violation. A summary line is printed either way.
static int RunTestCase(aclrtStream stream, const TestCase& tc, const std::vector<float>& values)
{
    // Owns every device-side handle created below. CHECK_ACL returns from the
    // middle of the function, so release has to happen in a destructor or the
    // early-exit paths would leak device memory and tensor descriptors.
    struct DeviceResources {
        void* xDev = nullptr;
        void* outDev = nullptr;
        void* wsAddr = nullptr;
        aclTensor* xT = nullptr;
        aclTensor* outT = nullptr;
        ~DeviceResources()
        {
            if (xT != nullptr) aclDestroyTensor(xT);
            if (outT != nullptr) aclDestroyTensor(outT);
            if (xDev != nullptr) aclrtFree(xDev);
            if (outDev != nullptr) aclrtFree(outDev);
            if (wsAddr != nullptr) aclrtFree(wsAddr);
        }
    } res;

    const int64_t n = ShapeSize(tc.shape);
    // An empty `values` would make the modulo below divide by zero, and a
    // non-positive element count leaves nothing to allocate or verify.
    if (n <= 0 || values.empty()) {
        printf("  [FAIL] %s — empty shape or empty value set\n", tc.name.c_str());
        return -1;
    }

    const bool isFp16 = (tc.dtype == ACL_FLOAT16);
    const size_t elemSize = isFp16 ? sizeof(uint16_t) : sizeof(float);
    const size_t totalBytes = static_cast<size_t>(n) * elemSize;

    // Host input in device layout, plus the float reference for each element.
    std::vector<uint8_t> xHostBytes(totalBytes);
    std::vector<float> xFloat(static_cast<size_t>(n));
    for (int64_t i = 0; i < n; i++) {
        const float v = values[static_cast<size_t>(i) % values.size()];
        if (isFp16) {
            const uint16_t h = FloatToHalf(v);
            std::memcpy(xHostBytes.data() + i * elemSize, &h, sizeof(h));
            // Compare against the value the device actually receives, so host
            // quantisation error is not charged to the operator.
            xFloat[i] = HalfToFloat(h);
        } else {
            std::memcpy(xHostBytes.data() + i * elemSize, &v, sizeof(v));
            xFloat[i] = v;
        }
    }
    std::vector<uint8_t> outHostBytes(totalBytes, 0);

    CHECK_ACL(aclrtMalloc(&res.xDev, totalBytes, ACL_MEM_MALLOC_HUGE_FIRST), "malloc x");
    CHECK_ACL(aclrtMalloc(&res.outDev, totalBytes, ACL_MEM_MALLOC_HUGE_FIRST), "malloc out");
    CHECK_ACL(aclrtMemcpy(res.xDev, totalBytes, xHostBytes.data(), totalBytes, ACL_MEMCPY_HOST_TO_DEVICE), "H2D x");
    // Zero the output buffer so stale device memory cannot pass as a result.
    CHECK_ACL(aclrtMemcpy(res.outDev, totalBytes, outHostBytes.data(), totalBytes, ACL_MEMCPY_HOST_TO_DEVICE),
              "H2D out");

    // Contiguous row-major strides, expressed in elements.
    std::vector<int64_t> strides(tc.shape.size(), 1);
    for (int64_t i = static_cast<int64_t>(tc.shape.size()) - 2; i >= 0; i--) {
        strides[i] = tc.shape[i + 1] * strides[i + 1];
    }

    res.xT = aclCreateTensor(tc.shape.data(), tc.shape.size(), tc.dtype,
                             strides.data(), 0, ACL_FORMAT_ND,
                             tc.shape.data(), tc.shape.size(), res.xDev);
    res.outT = aclCreateTensor(tc.shape.data(), tc.shape.size(), tc.dtype,
                               strides.data(), 0, ACL_FORMAT_ND,
                               tc.shape.data(), tc.shape.size(), res.outDev);
    // NOTE(review): assumes aclCreateTensor signals failure with nullptr — confirm against the aclnn docs.
    if (res.xT == nullptr || res.outT == nullptr) {
        printf("  [FAIL] %s — aclCreateTensor returned nullptr\n", tc.name.c_str());
        return -1;
    }

    uint64_t wsSize = 0;
    aclOpExecutor* executor = nullptr;
    CHECK_ACL(aclnnNdtriGetWorkspaceSize(res.xT, res.outT, &wsSize, &executor), "GetWorkspaceSize");
    if (wsSize > 0) {
        CHECK_ACL(aclrtMalloc(&res.wsAddr, wsSize, ACL_MEM_MALLOC_HUGE_FIRST), "malloc ws");
    }
    CHECK_ACL(aclnnNdtri(res.wsAddr, wsSize, executor, stream), "Ndtri execute");
    CHECK_ACL(aclrtSynchronizeStream(stream), "sync");
    CHECK_ACL(aclrtMemcpy(outHostBytes.data(), totalBytes, res.outDev, totalBytes, ACL_MEMCPY_DEVICE_TO_HOST), "D2H");

    int failCount = 0;
    float maxDiff = 0.0f;
    for (int64_t i = 0; i < n; i++) {
        float npuVal;
        if (isFp16) {
            uint16_t h;
            std::memcpy(&h, outHostBytes.data() + i * elemSize, sizeof(h));
            npuVal = HalfToFloat(h);
        } else {
            std::memcpy(&npuVal, outHostBytes.data() + i * elemSize, sizeof(npuVal));
        }
        const float roundtrip = NormalCdf(npuVal);
        const float diff = std::fabs(roundtrip - xFloat[i]);
        // Keep a NaN once seen so the summary line exposes it.
        if (std::isnan(diff) || diff > maxDiff) maxDiff = diff;
        // Written as !(diff <= atol) so that a NaN result counts as a failure;
        // `diff > atol` is false for NaN and would let it pass silently.
        if (!(diff <= tc.atol)) {
            if (failCount < 5) {
                printf("    p=%.6f ndtri=%.6f Phi(ndtri)=%.6f diff=%.2e > atol=%.2e\n",
                       xFloat[i], npuVal, roundtrip, diff, tc.atol);
            }
            failCount++;
        }
    }

    if (failCount > 0) {
        printf("  [FAIL] %s — %d/%ld failures, max_diff=%.2e\n",
               tc.name.c_str(), failCount, static_cast<long>(n), maxDiff);
        return -1;
    }
    printf("  [PASS] %s — %ld elems, max_diff=%.2e\n",
           tc.name.c_str(), static_cast<long>(n), maxDiff);
    return 0;
}
// Returns the 15 hand-picked probabilities used by the fixed-value cases:
// both tails (down to 0.001 / up to 0.999), the midpoint 0.5, and evenly
// spread values in between.
static std::vector<float> FixedValues()
{
    static const float kProbabilities[] = {
        0.001f, 0.01f, 0.05f, 0.1f, 0.2f, 0.3f, 0.4f,
        0.5f,
        0.6f, 0.7f, 0.8f, 0.9f, 0.95f, 0.99f, 0.999f,
    };
    const size_t count = sizeof(kProbabilities) / sizeof(kProbabilities[0]);
    return std::vector<float>(kProbabilities, kProbabilities + count);
}
// Produces `n` pseudo-random probabilities drawn from a uniform distribution
// constructed over 0.001 .. 0.999. The sequence is fully determined by `seed`,
// so a given (n, seed) pair always yields the same data.
static std::vector<float> RandomValues(int64_t n, uint32_t seed = 42)
{
    std::mt19937 engine(seed);
    std::uniform_real_distribution<float> uniform(0.001f, 0.999f);
    std::vector<float> samples(n);
    for (size_t idx = 0; idx < samples.size(); ++idx) {
        samples[idx] = uniform(engine);
    }
    return samples;
}
int main()
{
aclrtStream stream;
CHECK_ACL(aclInit(nullptr), "aclInit");
CHECK_ACL(aclrtSetDevice(0), "setDevice");
CHECK_ACL(aclrtCreateStream(&stream), "createStream");
printf("\n===== Ndtri Comprehensive Test (Ascend950) =====\n\n");
int totalTests = 0, passedTests = 0;
printf("[float32]\n");
{
TestCase tc{"fp32_fixed_15vals", {15}, ACL_FLOAT, 1e-4f};
auto vals = FixedValues();
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp32_shape_3x5", {3, 5}, ACL_FLOAT, 1e-4f};
auto vals = FixedValues();
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp32_shape_1024", {1024}, ACL_FLOAT, 1e-4f};
auto vals = RandomValues(1024);
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp32_shape_32x32", {32, 32}, ACL_FLOAT, 1e-4f};
auto vals = RandomValues(32 * 32);
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp32_shape_100000", {100000}, ACL_FLOAT, 1e-4f};
auto vals = RandomValues(100000);
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
printf("\n[float16]\n");
{
TestCase tc{"fp16_fixed_15vals", {15}, ACL_FLOAT16, 5e-2f};
auto vals = FixedValues();
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp16_shape_7", {7}, ACL_FLOAT16, 5e-2f};
auto vals = FixedValues();
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp16_shape_1024", {1024}, ACL_FLOAT16, 5e-2f};
auto vals = RandomValues(1024, 123);
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
{
TestCase tc{"fp16_shape_100000", {100000}, ACL_FLOAT16, 5e-2f};
auto vals = RandomValues(100000, 456);
totalTests++;
if (RunTestCase(stream, tc, vals) == 0) passedTests++;
}
printf("\n===== Summary: %d/%d passed =====\n", passedTests, totalTests);
aclrtDestroyStream(stream);
aclrtResetDevice(0);
aclFinalize();
return (passedTests == totalTests) ? 0 : 1;
}