* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <algorithm>
#include <string>
#include "acl/acl.h"
#include "cann_ops_blasLt.h"
#include "csv_loader.h"
/**
 * @brief Translate a textual data-type token into an aclDataType.
 *
 * Every supported type is recognised under several spellings: a short name
 * (e.g. "FP16"), the enumerator name (e.g. "ACL_FLOAT16") and a decimal
 * string (e.g. "1"). Matching is exact and case-sensitive.
 *
 * @param s Token to translate.
 * @return The matching aclDataType, or ACL_FLOAT when the token is not recognised.
 */
inline aclDataType parseTransformDtype(const std::string& s)
{
    struct DtypeAlias {
        const char* token;
        aclDataType value;
    };
    static const DtypeAlias kAliases[] = {
        {"FP32", ACL_FLOAT},
        {"ACL_FLOAT", ACL_FLOAT},
        {"0", ACL_FLOAT},
        {"FP16", ACL_FLOAT16},
        {"ACL_FLOAT16", ACL_FLOAT16},
        {"1", ACL_FLOAT16},
        {"INT8", ACL_INT8},
        {"ACL_INT8", ACL_INT8},
        {"2", ACL_INT8},
        {"INT32", ACL_INT32},
        {"ACL_INT32", ACL_INT32},
        {"3", ACL_INT32},
        {"BF16", ACL_BF16},
        {"ACL_BF16", ACL_BF16},
        {"27", ACL_BF16},
        {"FP8_E4M3FN", ACL_FLOAT8_E4M3FN},
        {"ACL_FLOAT8_E4M3FN", ACL_FLOAT8_E4M3FN},
        {"36", ACL_FLOAT8_E4M3FN},
        {"FP8_E5M2", ACL_FLOAT8_E5M2},
        {"ACL_FLOAT8_E5M2", ACL_FLOAT8_E5M2},
        {"35", ACL_FLOAT8_E5M2},
        {"FP4_E2M1", ACL_FLOAT4_E2M1},
        {"ACL_FLOAT4_E2M1", ACL_FLOAT4_E2M1},
        {"40", ACL_FLOAT4_E2M1},
        {"FP64", ACL_DOUBLE},
        {"ACL_DOUBLE", ACL_DOUBLE},
        {"11", ACL_DOUBLE},
        {"COMPLEX64", ACL_COMPLEX64},
        {"ACL_COMPLEX64", ACL_COMPLEX64},
        {"16", ACL_COMPLEX64},
        {"HIFLOAT8", ACL_HIFLOAT8},
        {"ACL_HIFLOAT8", ACL_HIFLOAT8},
        {"34", ACL_HIFLOAT8},
        {"FP4_E1M2", ACL_FLOAT4_E1M2},
        {"ACL_FLOAT4_E1M2", ACL_FLOAT4_E1M2},
        {"41", ACL_FLOAT4_E1M2},
        {"E8M0", ACL_FLOAT8_E8M0},
        {"FP8_E8M0", ACL_FLOAT8_E8M0},
        {"ACL_FLOAT8_E8M0", ACL_FLOAT8_E8M0},
        {"37", ACL_FLOAT8_E8M0},
        {"FP6", ACL_FLOAT6_E2M3},
        {"FP6_E2M3", ACL_FLOAT6_E2M3},
        {"ACL_FLOAT6_E2M3", ACL_FLOAT6_E2M3},
        {"39", ACL_FLOAT6_E2M3},
    };

    for (const DtypeAlias& alias : kAliases) {
        if (s == alias.token) {
            return alias.value;
        }
    }

    // Anything not listed above (including the empty string) maps to FP32.
    return ACL_FLOAT;
}
/**
 * @brief Translate a textual memory-order token into an aclblasLtOrder_t.
 *
 * Named orders are accepted as a short name (e.g. "COL32"), as the full
 * enumerator name (e.g. "ACLBLASLT_ORDER_COL32") or as a decimal string
 * ("0".."4"). Any other token is interpreted as a raw integer and cast to the
 * enum; presumably this lets test rows supply values outside the named set
 * (NOTE(review): confirm against the CSV test data).
 *
 * @param s Token to translate.
 * @return The matching order; the integer value cast to aclblasLtOrder_t when
 *         @p s is a plain integer; static_cast<aclblasLtOrder_t>(-1) when
 *         @p s is not an integer, is out of int range, or has trailing
 *         non-numeric characters.
 */
inline aclblasLtOrder_t parseTransformOrder(const std::string& s)
{
    if (s == "COL" || s == "ACLBLASLT_ORDER_COL" || s == "0") return ACLBLASLT_ORDER_COL;
    if (s == "ROW" || s == "ACLBLASLT_ORDER_ROW" || s == "1") return ACLBLASLT_ORDER_ROW;
    if (s == "COL32" || s == "ACLBLASLT_ORDER_COL32" || s == "2") return ACLBLASLT_ORDER_COL32;
    if (s == "COL4_4R2_8C" || s == "ACLBLASLT_ORDER_COL4_4R2_8C" || s == "3") return ACLBLASLT_ORDER_COL4_4R2_8C;
    if (s == "COL32_2R_4R4" || s == "ACLBLASLT_ORDER_COL32_2R_4R4" || s == "4") return ACLBLASLT_ORDER_COL32_2R_4R4;

    const aclblasLtOrder_t invalidOrder = static_cast<aclblasLtOrder_t>(-1);
    try {
        std::string::size_type consumed = 0;
        const int rawValue = std::stoi(s, &consumed);
        // std::stoi stops at the first non-numeric character, so "2junk" would
        // otherwise be accepted as 2. Require the whole token to be numeric.
        if (consumed != s.size()) {
            return invalidOrder;
        }
        return static_cast<aclblasLtOrder_t>(rawValue);
    } catch (...) {
        // std::stoi throws for non-numeric or out-of-range input.
        return invalidOrder;
    }
}
/**
 * @brief Translate a textual scale-type token into an aclDataType.
 *
 * Only INT32 and BF16 are recognised explicitly (by short name, enumerator
 * name or decimal string); matching is exact and case-sensitive.
 *
 * @param s Token to translate.
 * @return ACL_INT32 or ACL_BF16 for their respective tokens, ACL_FLOAT for
 *         every other input.
 */
inline aclDataType parseScaleType(const std::string& s)
{
    const bool isInt32Token = (s == "INT32") || (s == "ACL_INT32") || (s == "3");
    if (isInt32Token) {
        return ACL_INT32;
    }

    const bool isBf16Token = (s == "BF16") || (s == "ACL_BF16") || (s == "27");
    if (isBf16Token) {
        return ACL_BF16;
    }

    return ACL_FLOAT;
}
/**
 * @brief Report whether @p dt is one of the two FP8 formats checked here.
 *
 * @param dt Data type to classify.
 * @return true for ACL_FLOAT8_E4M3FN and ACL_FLOAT8_E5M2, false otherwise.
 */
inline bool isFp8TransformDtype(aclDataType dt)
{
    switch (dt) {
        case ACL_FLOAT8_E4M3FN:
        case ACL_FLOAT8_E5M2:
            return true;
        default:
            return false;
    }
}
/**
 * @brief Report whether @p dt is the FP4 E2M1 format.
 *
 * @param dt Data type to classify.
 * @return true only for ACL_FLOAT4_E2M1.
 */
inline bool isFp4TransformDtype(aclDataType dt)
{
    if (dt == ACL_FLOAT4_E2M1) {
        return true;
    }
    return false;
}
/**
 * @brief Return the per-element size, in bytes, used for a data type.
 *
 * FP4 E2M1 reports 1 here even though fp4PackedLd/fp4PackedBytes elsewhere in
 * this file treat two FP4 elements as sharing one byte.
 *
 * NOTE(review): types that parseTransformDtype accepts but that are not listed
 * below (e.g. ACL_DOUBLE, ACL_COMPLEX64, ACL_HIFLOAT8) report 4 through the
 * default branch - confirm that is intended.
 *
 * @param dt Data type to size.
 * @return 1, 2 or 4 depending on @p dt; 4 for any unlisted type.
 */
inline size_t transformDtypeSize(aclDataType dt)
{
    switch (dt) {
        // One-byte storage.
        case ACL_INT8:
        case ACL_FLOAT8_E4M3FN:
        case ACL_FLOAT8_E5M2:
        case ACL_FLOAT4_E2M1:
            return 1;

        // Two-byte storage.
        case ACL_FLOAT16:
        case ACL_BF16:
            return 2;

        // Four-byte storage, also used for every type not listed above.
        case ACL_FLOAT:
        case ACL_INT32:
        default:
            return 4;
    }
}
/**
 * @brief Convert a leading dimension counted in FP4 elements into one counted
 *        in bytes, with two elements per byte and odd counts rounded up.
 *
 * The addition is carried out in long long so that logicalLd == INT_MAX does
 * not trigger signed integer overflow; every other input yields the same
 * result as (logicalLd + 1) / 2.
 *
 * @param logicalLd Leading dimension in elements.
 * @return (logicalLd + 1) / 2, computed without overflow.
 */
inline int fp4PackedLd(int logicalLd)
{
    const long long widened = static_cast<long long>(logicalLd) + 1;
    return static_cast<int>(widened / 2);
}
/**
 * @brief Number of bytes needed to hold @p elemCount FP4 elements when two
 *        elements share one byte (odd counts round up).
 *
 * Positive counts use a form that cannot overflow, so elemCount == INT64_MAX
 * is handled; zero and negative counts keep the plain (elemCount + 1) / 2
 * expression, which cannot overflow for those inputs.
 *
 * @param elemCount Number of FP4 elements.
 * @return Same value as (elemCount + 1) / 2, computed without overflow.
 */
inline int64_t fp4PackedBytes(int64_t elemCount)
{
    if (elemCount <= 0) {
        return (elemCount + 1) / 2;
    }
    // For positive values this equals (elemCount + 1) / 2 without the "+ 1".
    return elemCount / 2 + elemCount % 2;
}
/**
 * @brief Report whether @p dt is one of the integer types checked here.
 *
 * @param dt Data type to classify.
 * @return true for ACL_INT8 and ACL_INT32, false otherwise.
 */
inline bool isIntTransformDtype(aclDataType dt)
{
    switch (dt) {
        case ACL_INT8:
        case ACL_INT32:
            return true;
        default:
            return false;
    }
}
struct LtMatrixTransformParam : public BlasTestParamBase {
aclDataType dtypeA = ACL_FLOAT;
aclDataType dtypeB = ACL_FLOAT;
aclDataType dtypeC = ACL_FLOAT;
aclDataType scaleType = ACL_FLOAT;
aclblasLtOrder_t orderA = ACLBLASLT_ORDER_COL;
aclblasLtOrder_t orderB = ACLBLASLT_ORDER_COL;
aclblasLtOrder_t orderC = ACLBLASLT_ORDER_COL;
aclblasOperation_t transA = ACLBLAS_OP_N;
aclblasOperation_t transB = ACLBLAS_OP_N;
int rowsA = 0, colsA = 0;
int rowsB = 0, colsB = 0;
int rowsC = 0, colsC = 0;
int lda = 0, ldb = 0, ldc = 0;
float alpha = 1.0f;
float beta = 0.0f;
int batchCount = 1;
bool BIsNull = false;
bool handleNull = false;
bool transformDescNull = false;
bool AdescNull = false;
bool BdescNull = false;
bool CdescNull = false;
bool alphaNull = false;
bool betaNull = false;
bool ANull = false;
bool CIsNull = false;
LtMatrixTransformParam(const csv_map& m) : BlasTestParamBase(m)
{
dtypeA = parseTransformDtype(ReadMap(m, "dtypeA", "FP32"));
dtypeB = parseTransformDtype(ReadMap(m, "dtypeB", "FP32"));
dtypeC = parseTransformDtype(ReadMap(m, "dtypeC", "FP32"));
scaleType = parseScaleType(ReadMap(m, "scale_type", "FP32"));
orderA = parseTransformOrder(ReadMap(m, "orderA", "COL"));
orderB = parseTransformOrder(ReadMap(m, "orderB", "COL"));
orderC = parseTransformOrder(ReadMap(m, "orderC", "COL"));
transA = parseOpTrans(ReadMap(m, "transA", "N"));
transB = parseOpTrans(ReadMap(m, "transB", "N"));
rowsA = parseInt(ReadMap(m, "rowsA", "0"));
colsA = parseInt(ReadMap(m, "colsA", "0"));
rowsB = parseInt(ReadMap(m, "rowsB", "0"));
colsB = parseInt(ReadMap(m, "colsB", "0"));
rowsC = parseInt(ReadMap(m, "rowsC", "0"));
colsC = parseInt(ReadMap(m, "colsC", "0"));
lda = parseInt(ReadMap(m, "lda", "0"));
ldb = parseInt(ReadMap(m, "ldb", "0"));
ldc = parseInt(ReadMap(m, "ldc", "0"));
alpha = parseFloat(ReadMap(m, "alpha", "1.0"));
beta = parseFloat(ReadMap(m, "beta", "0.0"));
batchCount = parseInt(ReadMap(m, "batch_count", "1"));
BIsNull = (ReadMap(m, "B_is_null", "false") == "true");
handleNull = (ReadMap(m, "handle_null", "false") == "true");
transformDescNull = (ReadMap(m, "transformDesc_null", "false") == "true");
AdescNull = (ReadMap(m, "Adesc_null", "false") == "true");
BdescNull = (ReadMap(m, "Bdesc_null", "false") == "true");
CdescNull = (ReadMap(m, "Cdesc_null", "false") == "true");
alphaNull = (ReadMap(m, "alpha_null", "false") == "true");
betaNull = (ReadMap(m, "beta_null", "false") == "true");
ANull = (ReadMap(m, "A_null", "false") == "true");
CIsNull = (ReadMap(m, "C_is_null", "false") == "true");
if (rowsB == 0 && colsB == 0 && !BIsNull) { rowsB = rowsA; colsB = colsA; }
if (rowsC == 0 && colsC == 0) { rowsC = rowsA; colsC = colsA; }
}
};