* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
#include <gtest/gtest.h>
#include "fill.h"
#include "acl/acl.h"
#include "cann_ops_blas.h"
// One parsed CSV row: maps a header column name to that row's cell text.
// GetCasesFromCsv leaves out columns whose cell is empty, so readers should
// look values up through ReadMap and supply a default.
using csv_map = std::unordered_map<std::string, std::string>;
/**
 * @brief Fetches the value stored under @p key, or a fallback when the key is missing.
 *
 * @tparam Map          Associative container exposing find(), end(), key_type and mapped_type.
 * @param m             Container to search.
 * @param key           Key to look up.
 * @param defaultValue  Returned when @p key is not present; a value-initialized
 *                      mapped_type if the caller does not supply one.
 * @return A copy of the mapped value, or a copy of @p defaultValue.
 */
template <typename Map>
inline typename Map::mapped_type ReadMap(
    const Map& m, const typename Map::key_type& key,
    const typename Map::mapped_type& defaultValue = typename Map::mapped_type{})
{
    const auto found = m.find(key);
    if (found == m.end()) {
        return defaultValue;
    }
    return found->second;
}
/**
 * @brief Splits one CSV line into its fields.
 *
 * Rules applied:
 *  - A comma separates fields unless it appears between double quotes.
 *  - Double quotes delimit a quoted section and are not part of the field.
 *  - Inside a quoted section, a doubled quote ("") yields one literal quote
 *    character (RFC 4180 escaping).
 *  - Every field is trimmed of leading/trailing spaces, tabs, CR and LF,
 *    including whitespace that was inside quotes.
 *
 * @param line  A single line of CSV text (a trailing CR/LF is tolerated).
 * @return The fields in order; always at least one element (an empty line
 *         produces a single empty field).
 */
inline std::vector<std::string> CsvSplitLine(const std::string& line)
{
    // Shared by the in-loop and the end-of-line field flush.
    const auto trimField = [](const std::string& raw) -> std::string {
        static const char* const kBlanks = " \t\r\n";
        const size_t first = raw.find_first_not_of(kBlanks);
        if (first == std::string::npos) {
            return std::string();
        }
        const size_t last = raw.find_last_not_of(kBlanks);
        return raw.substr(first, last - first + 1);
    };

    std::vector<std::string> fields;
    std::string field;
    bool inQuotes = false;
    for (size_t i = 0; i < line.size(); ++i) {
        const char c = line[i];
        if (c == '"') {
            if (inQuotes && i + 1 < line.size() && line[i + 1] == '"') {
                // Escaped quote inside a quoted section: keep one '"' and
                // consume the second so it does not toggle the quote state.
                field += '"';
                ++i;
            } else {
                inQuotes = !inQuotes;
            }
        } else if (c == ',' && !inQuotes) {
            fields.push_back(trimField(field));
            field.clear();
        } else {
            field += c;
        }
    }
    // The last field has no terminating comma.
    fields.push_back(trimField(field));
    return fields;
}
/**
 * @brief Loads test cases from a CSV file, one case per data row.
 *
 * File format:
 *  - Blank lines and lines whose first non-whitespace character is '#' are ignored.
 *  - The first remaining line is the header; its fields become the column keys.
 *  - Every following data line is split with CsvSplitLine and turned into a
 *    csv_map (key -> cell text). Empty cells are not inserted, so consumers
 *    see them as "absent".
 *  - A UTF-8 byte-order mark at the very start of the file is discarded so it
 *    does not become part of the first column name.
 *  - A data row with fewer fields than the header is skipped; a warning naming
 *    the file and line is written to std::cerr so the dropped case is visible.
 *    Extra trailing fields are ignored.
 *
 * @tparam T       Case type; must be constructible from a csv_map.
 * @param csvPath  Path of the CSV file to read.
 * @return The constructed cases in file order.
 * @throws std::runtime_error if the file cannot be opened or contains no header line.
 */
template <typename T>
inline std::vector<T> GetCasesFromCsv(const std::string& csvPath)
{
    std::ifstream ifs(csvPath);
    if (!ifs.is_open()) {
        throw std::runtime_error("Cannot open CSV file: " + csvPath);
    }

    // True when the line is neither blank nor a '#' comment.
    const auto carriesData = [](const std::string& text) {
        const size_t first = text.find_first_not_of(" \t\r\n");
        return first != std::string::npos && text[first] != '#';
    };

    // Locate the header: the first line that carries data.
    std::string headerLine;
    size_t lineNo = 0;
    bool hasHeader = false;
    while (std::getline(ifs, headerLine)) {
        ++lineNo;
        if (lineNo == 1 && headerLine.compare(0, 3, "\xEF\xBB\xBF") == 0) {
            headerLine.erase(0, 3);  // strip UTF-8 BOM
        }
        if (carriesData(headerLine)) {
            hasHeader = true;
            break;
        }
    }
    if (!hasHeader) {
        throw std::runtime_error("CSV file has no header: " + csvPath);
    }
    const auto keys = CsvSplitLine(headerLine);

    std::vector<T> cases;
    std::string line;
    while (std::getline(ifs, line)) {
        ++lineNo;
        if (!carriesData(line)) {
            continue;
        }
        const auto values = CsvSplitLine(line);
        if (values.size() < keys.size()) {
            // Best effort: keep loading, but do not drop a test case silently.
            std::cerr << "[GetCasesFromCsv] " << csvPath << ":" << lineNo << ": skipped row with "
                      << values.size() << " field(s), header has " << keys.size() << '\n';
            continue;
        }
        csv_map row;
        for (size_t i = 0; i < keys.size(); ++i) {
            if (!values[i].empty()) {
                row[keys[i]] = values[i];
            }
        }
        cases.emplace_back(row);
    }
    return cases;
}
/**
 * @brief Returns @p filePath with its file extension replaced by ".csv".
 *
 * Only a dot located in the final path component (after the last '/' or '\\')
 * counts as the start of an extension, so dots in directory names such as
 * "./dir/test" or "/build.d/test" are left alone. If the final component has
 * no dot, ".csv" is appended.
 *
 * @param filePath  Source path, e.g. the value of __FILE__.
 * @return The path with a ".csv" extension.
 */
inline std::string ReplaceFileExtension2Csv(const std::string& filePath)
{
    const size_t sep = filePath.find_last_of("/\\");
    const size_t dot = filePath.rfind('.');
    const bool hasExtension = dot != std::string::npos && (sep == std::string::npos || dot > sep);
    return (hasExtension ? filePath.substr(0, dot) : filePath) + ".csv";
}
/**
 * @brief Converts text to int with std::stoi, never throwing.
 *
 * std::stoi semantics apply: leading whitespace is skipped and parsing stops
 * at the first character that is not part of the number.
 *
 * @param s    Text to convert.
 * @param def  Value returned when std::stoi throws (no digits, out of range).
 * @return The parsed value, or @p def on failure.
 */
inline int parseInt(const std::string& s, int def = 0)
{
    int parsed = def;
    try {
        parsed = std::stoi(s);
    } catch (...) {
        // Conversion failed: keep the fallback already stored in `parsed`.
    }
    return parsed;
}
/**
 * @brief Converts text to a 64-bit signed integer with std::stoll, never throwing.
 *
 * @param s    Text to convert (std::stoll semantics: leading whitespace skipped,
 *             parsing stops at the first non-numeric character).
 * @param def  Value returned when std::stoll throws.
 * @return The parsed value, or @p def on failure.
 */
inline int64_t parseInt64(const std::string& s, int64_t def = 0)
{
    int64_t parsed = def;
    try {
        parsed = std::stoll(s);
    } catch (...) {
        // Conversion failed: keep the fallback already stored in `parsed`.
    }
    return parsed;
}
/**
 * @brief Converts text to float with std::stof, never throwing.
 *
 * @param s    Text to convert.
 * @param def  Value returned when std::stof throws.
 * @return The parsed value, or @p def on failure.
 */
inline float parseFloat(const std::string& s, float def = 0.0f)
{
    float parsed = def;
    try {
        parsed = std::stof(s);
    } catch (...) {
        // Conversion failed: keep the fallback already stored in `parsed`.
    }
    return parsed;
}
/**
 * @brief Converts text to double with std::stod, never throwing.
 *
 * @param s    Text to convert.
 * @param def  Value returned when std::stod throws.
 * @return The parsed value, or @p def on failure.
 */
inline double parseDouble(const std::string& s, double def = 0.0)
{
    double parsed = def;
    try {
        parsed = std::stod(s);
    } catch (...) {
        // Conversion failed: keep the fallback already stored in `parsed`.
    }
    return parsed;
}
/**
 * @brief Converts text to a 32-bit unsigned integer, never throwing.
 *
 * Negative input is rejected explicitly: the std::stoul family accepts a
 * leading '-' and returns the negated value wrapped into the unsigned type,
 * which would otherwise turn e.g. "-18446744073709551615" into 1. Parsing is
 * done in unsigned long long so the range check works even where
 * unsigned long is only 32 bits wide.
 *
 * @param s    Text to convert.
 * @param def  Value returned for empty, non-numeric, negative or
 *             out-of-range (> UINT32_MAX) input.
 * @return The parsed value, or @p def on failure.
 */
inline uint32_t parseUint(const std::string& s, uint32_t def = 0)
{
    // Same whitespace set the conversion itself skips before the sign.
    const size_t first = s.find_first_not_of(" \t\n\v\f\r");
    if (first == std::string::npos || s[first] == '-') {
        return def;
    }
    try {
        const unsigned long long val = std::stoull(s);
        if (val > UINT32_MAX) {
            return def;
        }
        return static_cast<uint32_t>(val);
    } catch (...) {
        return def;
    }
}
/**
 * @brief Resolves an enum value from its textual name or its numeric value.
 *
 * Resolution order:
 *  1. Empty text -> @p defaultVal.
 *  2. Exact (case-sensitive) match in @p table -> the mapped value.
 *  3. Otherwise the text is parsed with std::stoi and cast to EnumT; no check
 *     is made that the number corresponds to a declared enumerator.
 *  4. If std::stoi throws -> @p defaultVal.
 *
 * @tparam EnumT       Target enum type.
 * @param s            Text to resolve.
 * @param table        Name-to-value lookup table.
 * @param defaultVal   Fallback value.
 * @return The resolved value.
 */
template <typename EnumT>
inline EnumT parseEnum(const std::string& s, const std::unordered_map<std::string, EnumT>& table, EnumT defaultVal)
{
    if (!s.empty()) {
        const auto hit = table.find(s);
        if (hit != table.end()) {
            return hit->second;
        }
        try {
            return static_cast<EnumT>(std::stoi(s));
        } catch (...) {
            // Not a number either: fall through to the default below.
        }
    }
    return defaultVal;
}
/**
 * @brief Resolves an aclblasStatus_t from text.
 *
 * Each status is accepted both by its full enumerator name
 * ("ACLBLAS_STATUS_SUCCESS") and by the name without the "ACLBLAS_STATUS_"
 * prefix ("SUCCESS"). Resolution is delegated to parseEnum, so numeric text is
 * cast to the enum, and empty or unrecognized text yields
 * ACLBLAS_STATUS_INTERNAL_ERROR.
 *
 * @param s  Status text.
 * @return The resolved status.
 */
inline aclblasStatus_t parseStatus(const std::string& s)
{
    static const std::unordered_map<std::string, aclblasStatus_t> kStatusByName = [] {
        std::unordered_map<std::string, aclblasStatus_t> names;
        // Registers both the short spelling and the fully prefixed spelling.
        const auto addBothSpellings = [&names](const std::string& shortName, aclblasStatus_t status) {
            names.emplace("ACLBLAS_STATUS_" + shortName, status);
            names.emplace(shortName, status);
        };
        addBothSpellings("SUCCESS", ACLBLAS_STATUS_SUCCESS);
        addBothSpellings("NOT_INITIALIZED", ACLBLAS_STATUS_NOT_INITIALIZED);
        addBothSpellings("ALLOC_FAILED", ACLBLAS_STATUS_ALLOC_FAILED);
        addBothSpellings("INVALID_VALUE", ACLBLAS_STATUS_INVALID_VALUE);
        addBothSpellings("MAPPING_ERROR", ACLBLAS_STATUS_MAPPING_ERROR);
        addBothSpellings("EXECUTION_FAILED", ACLBLAS_STATUS_EXECUTION_FAILED);
        addBothSpellings("INTERNAL_ERROR", ACLBLAS_STATUS_INTERNAL_ERROR);
        addBothSpellings("NOT_SUPPORTED", ACLBLAS_STATUS_NOT_SUPPORTED);
        addBothSpellings("ARCH_MISMATCH", ACLBLAS_STATUS_ARCH_MISMATCH);
        addBothSpellings("HANDLE_IS_NULLPTR", ACLBLAS_STATUS_HANDLE_IS_NULLPTR);
        addBothSpellings("INVALID_ENUM", ACLBLAS_STATUS_INVALID_ENUM);
        addBothSpellings("UNKNOWN", ACLBLAS_STATUS_UNKNOWN);
        return names;
    }();
    return parseEnum(s, kStatusByName, ACLBLAS_STATUS_INTERNAL_ERROR);
}
struct BlasTestParamBase {
std::string caseName;
std::string description;
aclblasStatus_t expectResult = ACLBLAS_STATUS_SUCCESS;
double mereThreshold = 0.0;
double mareMultiplier = 0.0;
uint32_t randomSeed = 0;
BlasTestParamBase(const csv_map& m)
{
caseName = ReadMap(m, "case_name");
description = ReadMap(m, "description");
expectResult = parseStatus(ReadMap(m, "expect_result", "SUCCESS"));
mereThreshold = parseDouble(ReadMap(m, "mere_threshold", "0"));
mareMultiplier = parseDouble(ReadMap(m, "mare_multiplier", "0"));
randomSeed = parseUint(ReadMap(m, "random_seed", "0"));
}
};
/**
 * @brief Streams a test case as its case name.
 * @param os  Destination stream.
 * @param p   Test case to print.
 * @return @p os, to allow chaining.
 */
inline std::ostream& operator<<(std::ostream& os, const BlasTestParamBase& p)
{
    os << p.caseName;
    return os;
}
/**
 * @brief Name generator for parameterized tests: derives the test-name suffix
 *        from the case's caseName.
 *
 * GoogleTest requires generated names to be non-empty and to consist solely
 * of ASCII alphanumerics and '_'; an invalid name aborts test registration.
 * To keep one malformed CSV row from taking down the whole binary:
 *  - every character outside [0-9A-Za-z_] is replaced with '_';
 *  - an empty caseName becomes "case_<index>", using the parameter's index.
 * Names that are already valid are returned unchanged.
 *
 * @tparam T    Parameter type exposing a std::string `caseName` member.
 * @param info  GoogleTest parameter info (parameter value and index).
 * @return A name acceptable to GoogleTest.
 */
template <typename T>
inline std::string PrintCaseInfoString(const ::testing::TestParamInfo<T>& info)
{
    const std::string& raw = info.param.caseName;
    if (raw.empty()) {
        return "case_" + std::to_string(info.index);
    }
    std::string name;
    name.reserve(raw.size());
    for (const char c : raw) {
        const bool allowed =
            (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_';
        name += allowed ? c : '_';
    }
    return name;
}
/**
 * @brief Resolves an aclblasFillMode_t from text.
 *
 * Accepts "ACLBLAS_UPPER"/"UPPER" and "ACLBLAS_LOWER"/"LOWER". The keyword
 * "INVALID" maps to the raw value 0xFF. Resolution is delegated to parseEnum,
 * so numeric text is cast to the enum, and empty or unrecognized text also
 * yields the raw value 0xFF.
 *
 * @param s  Fill-mode text.
 * @return The resolved fill mode.
 */
inline aclblasFillMode_t parseFillMode(const std::string& s)
{
    using NameTable = std::unordered_map<std::string, aclblasFillMode_t>;
    static const NameTable kFillModeByName = [] {
        NameTable names;
        names.emplace("UPPER", ACLBLAS_UPPER);
        names.emplace("ACLBLAS_UPPER", ACLBLAS_UPPER);
        names.emplace("LOWER", ACLBLAS_LOWER);
        names.emplace("ACLBLAS_LOWER", ACLBLAS_LOWER);
        names.emplace("INVALID", static_cast<aclblasFillMode_t>(0xFF));
        return names;
    }();
    return parseEnum(s, kFillModeByName, static_cast<aclblasFillMode_t>(0xFF));
}
/**
 * @brief Resolves an aclblasDiagType_t from text.
 *
 * Accepts "ACLBLAS_NON_UNIT"/"NON_UNIT" and "ACLBLAS_UNIT"/"UNIT". The keyword
 * "INVALID" maps to the raw value 0xFF. Resolution is delegated to parseEnum,
 * so numeric text is cast to the enum, and empty or unrecognized text yields
 * ACLBLAS_NON_UNIT.
 *
 * @param s  Diagonal-type text.
 * @return The resolved diagonal type.
 */
inline aclblasDiagType_t parseDiagType(const std::string& s)
{
    using NameTable = std::unordered_map<std::string, aclblasDiagType_t>;
    static const NameTable kDiagTypeByName = [] {
        NameTable names;
        names.emplace("NON_UNIT", ACLBLAS_NON_UNIT);
        names.emplace("ACLBLAS_NON_UNIT", ACLBLAS_NON_UNIT);
        names.emplace("UNIT", ACLBLAS_UNIT);
        names.emplace("ACLBLAS_UNIT", ACLBLAS_UNIT);
        names.emplace("INVALID", static_cast<aclblasDiagType_t>(0xFF));
        return names;
    }();
    return parseEnum(s, kDiagTypeByName, ACLBLAS_NON_UNIT);
}
/**
 * @brief Resolves an aclblasSideMode_t from text.
 *
 * Accepts "ACLBLAS_SIDE_LEFT"/"LEFT" and "ACLBLAS_SIDE_RIGHT"/"RIGHT". The
 * keyword "INVALID" maps to the raw value 0xFF, matching the other enum
 * parsers in this file, so negative test cases can request an out-of-range
 * side mode by name instead of silently getting ACLBLAS_SIDE_LEFT.
 * Resolution is delegated to parseEnum, so numeric text is cast to the enum,
 * and empty or unrecognized text yields ACLBLAS_SIDE_LEFT.
 *
 * @param s  Side-mode text.
 * @return The resolved side mode.
 */
inline aclblasSideMode_t parseSideMode(const std::string& s)
{
    static const std::unordered_map<std::string, aclblasSideMode_t> t = {
        {"ACLBLAS_SIDE_LEFT", ACLBLAS_SIDE_LEFT},
        {"LEFT", ACLBLAS_SIDE_LEFT},
        {"ACLBLAS_SIDE_RIGHT", ACLBLAS_SIDE_RIGHT},
        {"RIGHT", ACLBLAS_SIDE_RIGHT},
        {"INVALID", static_cast<aclblasSideMode_t>(0xFF)}};
    return parseEnum(s, t, ACLBLAS_SIDE_LEFT);
}
/**
 * @brief Resolves an aclblasOperation_t from text.
 *
 * Accepts "N", "T", "C" and "H", each with or without the "ACLBLAS_OP_"
 * prefix; "H" is an alias that maps to ACLBLAS_OP_C. The keyword "INVALID"
 * maps to the raw value 0xFF. Resolution is delegated to parseEnum, so numeric
 * text is cast to the enum, and empty or unrecognized text yields ACLBLAS_OP_N.
 *
 * @param s  Operation text.
 * @return The resolved operation.
 */
inline aclblasOperation_t parseOpTrans(const std::string& s)
{
    static const std::unordered_map<std::string, aclblasOperation_t> kOperationByName = [] {
        std::unordered_map<std::string, aclblasOperation_t> names;
        // Registers both the one-letter spelling and the fully prefixed spelling.
        const auto addBothSpellings = [&names](const std::string& letter, aclblasOperation_t op) {
            names.emplace("ACLBLAS_OP_" + letter, op);
            names.emplace(letter, op);
        };
        addBothSpellings("N", ACLBLAS_OP_N);
        addBothSpellings("T", ACLBLAS_OP_T);
        addBothSpellings("C", ACLBLAS_OP_C);
        addBothSpellings("H", ACLBLAS_OP_C);
        names.emplace("INVALID", static_cast<aclblasOperation_t>(0xFF));
        return names;
    }();
    return parseEnum(s, kOperationByName, ACLBLAS_OP_N);
}
/**
 * @brief Resolves an aclblasComputeType_t from text.
 *
 * Each compute type is accepted both by its full enumerator name
 * ("ACLBLAS_COMPUTE_32F") and by the name without the "ACLBLAS_" prefix
 * ("COMPUTE_32F"). Resolution is delegated to parseEnum, so numeric text is
 * cast to the enum, and empty or unrecognized text yields ACLBLAS_COMPUTE_32F.
 *
 * @param s  Compute-type text.
 * @return The resolved compute type.
 */
inline aclblasComputeType_t parseComputeType(const std::string& s)
{
    static const std::unordered_map<std::string, aclblasComputeType_t> kComputeTypeByName = [] {
        std::unordered_map<std::string, aclblasComputeType_t> names;
        // Registers both the unprefixed spelling and the "ACLBLAS_"-prefixed spelling.
        const auto addBothSpellings = [&names](const std::string& shortName, aclblasComputeType_t type) {
            names.emplace("ACLBLAS_" + shortName, type);
            names.emplace(shortName, type);
        };
        addBothSpellings("COMPUTE_16F", ACLBLAS_COMPUTE_16F);
        addBothSpellings("COMPUTE_16F_PEDANTIC", ACLBLAS_COMPUTE_16F_PEDANTIC);
        addBothSpellings("COMPUTE_32F", ACLBLAS_COMPUTE_32F);
        addBothSpellings("COMPUTE_32F_PEDANTIC", ACLBLAS_COMPUTE_32F_PEDANTIC);
        addBothSpellings("COMPUTE_32F_FAST_16F", ACLBLAS_COMPUTE_32F_FAST_16F);
        addBothSpellings("COMPUTE_32F_FAST_16BF", ACLBLAS_COMPUTE_32F_FAST_16BF);
        addBothSpellings("COMPUTE_32F_FAST_TF32", ACLBLAS_COMPUTE_32F_FAST_TF32);
        addBothSpellings("COMPUTE_64F", ACLBLAS_COMPUTE_64F);
        addBothSpellings("COMPUTE_64F_PEDANTIC", ACLBLAS_COMPUTE_64F_PEDANTIC);
        addBothSpellings("COMPUTE_32I", ACLBLAS_COMPUTE_32I);
        addBothSpellings("COMPUTE_32I_PEDANTIC", ACLBLAS_COMPUTE_32I_PEDANTIC);
        return names;
    }();
    return parseEnum(s, kComputeTypeByName, ACLBLAS_COMPUTE_32F);
}
/**
 * @brief Splits a ';'-separated list into its elements.
 *
 * Each element is trimmed of surrounding spaces, tabs, CR and LF so that
 * "N; T" yields {"N", "T"}; without trimming, " T" would miss exact-match
 * name tables downstream. Elements that are empty after trimming are dropped,
 * so "a;;b" and "a; ;b" both yield {"a", "b"}.
 *
 * @param s  Text such as "1;2;3".
 * @return The non-empty, trimmed elements in order; empty for empty input.
 */
inline std::vector<std::string> splitBySemicolon(const std::string& s)
{
    static const char* const kBlanks = " \t\r\n";
    std::vector<std::string> parts;
    std::istringstream iss(s);
    std::string token;
    while (std::getline(iss, token, ';')) {
        const size_t first = token.find_first_not_of(kBlanks);
        if (first == std::string::npos) {
            continue;  // empty or whitespace-only element
        }
        const size_t last = token.find_last_not_of(kBlanks);
        parts.push_back(token.substr(first, last - first + 1));
    }
    return parts;
}
/**
 * @brief Parses a ';'-separated list of integers.
 *
 * The text is tokenized with splitBySemicolon and each token is converted
 * with parseInt using its default fallback, so an unparsable token becomes 0.
 *
 * @param s  Text such as "1;2;3".
 * @return One int per token, in order.
 */
inline std::vector<int> parseIntArray(const std::string& s)
{
    const std::vector<std::string> tokens = splitBySemicolon(s);
    std::vector<int> values;
    values.reserve(tokens.size());
    for (size_t i = 0; i < tokens.size(); ++i) {
        values.push_back(parseInt(tokens[i]));
    }
    return values;
}
/**
 * @brief Parses a ';'-separated list of floats.
 *
 * The text is tokenized with splitBySemicolon and each token is converted
 * with parseFloat using its default fallback, so an unparsable token becomes 0.
 *
 * @param s  Text such as "1.0;2.5".
 * @return One float per token, in order.
 */
inline std::vector<float> parseFloatArray(const std::string& s)
{
    const std::vector<std::string> tokens = splitBySemicolon(s);
    std::vector<float> values;
    values.reserve(tokens.size());
    for (size_t i = 0; i < tokens.size(); ++i) {
        values.push_back(parseFloat(tokens[i]));
    }
    return values;
}
inline std::vector<aclblasOperation_t> parseTransArray(const std::string& s)
{
auto parts = splitBySemicolon(s);
std::vector<aclblasOperation_t> result;
for (const auto& p : parts)
result.push_back(parseOpTrans(p));
return result;
}