* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <string>
#include <algorithm>
#include "acl/acl.h"
#include "cann_ops_blasLt.h"
#include "csv_loader.h"
// Maps a textual data-type name to its aclDataType value.
//
// Each supported type is accepted under three spellings: a short name
// (e.g. "BF16"), the enum identifier (e.g. "ACL_BF16") and a decimal
// string (e.g. "27"). Matching is exact and case-sensitive.
// Any string that matches none of the spellings yields ACL_FLOAT.
inline aclDataType parseAclDataType(const std::string& s)
{
    struct Alias {
        const char* text;
        aclDataType dtype;
    };
    static const Alias kAliases[] = {
        {"FP32", ACL_FLOAT},
        {"ACL_FLOAT", ACL_FLOAT},
        {"0", ACL_FLOAT},
        {"BF16", ACL_BF16},
        {"ACL_BF16", ACL_BF16},
        {"27", ACL_BF16},
        {"MXFP8_E4M3FN", ACL_FLOAT8_E4M3FN},
        {"ACL_FLOAT8_E4M3FN", ACL_FLOAT8_E4M3FN},
        {"36", ACL_FLOAT8_E4M3FN},
        {"MXFP4_E2M1", ACL_FLOAT4_E2M1},
        {"ACL_FLOAT4_E2M1", ACL_FLOAT4_E2M1},
        {"40", ACL_FLOAT4_E2M1},
    };
    for (const Alias& alias : kAliases) {
        if (s == alias.text) {
            return alias.dtype;
        }
    }
    // Unrecognized input silently falls back to FP32.
    return ACL_FLOAT;
}
// Returns true when dt is ACL_FLOAT8_E4M3FN or ACL_FLOAT4_E2M1, false for
// every other data type.
inline bool isMxfpType(aclDataType dt)
{
    const bool isFp8E4M3 = (dt == ACL_FLOAT8_E4M3FN);
    const bool isFp4E2M1 = (dt == ACL_FLOAT4_E2M1);
    return isFp8E4M3 || isFp4E2M1;
}
// Returns the per-element storage size, in bytes, used for data type dt.
//
// ACL_FLOAT4_E2M1 is reported as one byte per element here; sub-byte packing
// is not reflected in this value. Any data type not listed falls back to 4.
inline size_t dtypeElementSize(aclDataType dt)
{
    switch (dt) {
        case ACL_FLOAT:
            return sizeof(float);
        case ACL_BF16:
            return 2;
        case ACL_FLOAT8_E4M3FN:
            return 1;
        case ACL_FLOAT4_E2M1:
            return 1;
        default:
            return 4;
    }
}
// Returns the compute type to use for a matmul with the given A and D data
// types. The result is always ACLBLAS_COMPUTE_32F; neither argument
// influences it.
inline aclblasComputeType_t getComputeType(aclDataType dtypeA, aclDataType dtypeD)
{
    // Explicitly discard the inputs: they are accepted but not consulted.
    (void)dtypeA;
    (void)dtypeD;
    return ACLBLAS_COMPUTE_32F;
}
// Row count of A as stored: M when transA is ACLBLAS_OP_N, otherwise K.
inline int getPhysicalRowsA(int M, int K, aclblasOperation_t transA) {
    if (transA != ACLBLAS_OP_N) {
        return K;
    }
    return M;
}
// Column count of A as stored: K when transA is ACLBLAS_OP_N, otherwise M.
inline int getPhysicalColsA(int M, int K, aclblasOperation_t transA) {
    if (transA != ACLBLAS_OP_N) {
        return M;
    }
    return K;
}
// Row count of B as stored: K when transB is ACLBLAS_OP_N, otherwise N.
inline int getPhysicalRowsB(int K, int N, aclblasOperation_t transB) {
    if (transB != ACLBLAS_OP_N) {
        return N;
    }
    return K;
}
// Column count of B as stored: N when transB is ACLBLAS_OP_N, otherwise K.
inline int getPhysicalColsB(int K, int N, aclblasOperation_t transB) {
    if (transB != ACLBLAS_OP_N) {
        return K;
    }
    return N;
}
// Converts a leading dimension counted in logical elements into one counted
// in packed units of two elements each, rounding an odd count up
// (e.g. 7 -> 4, 8 -> 4).
// Presumably the packed unit is one byte holding two 4-bit values — confirm
// against the MXFP4 storage layout used by the library.
inline int64_t mxfp4PackedLd(int64_t logicalLd)
{
    // Adding one before halving makes the integer division round up.
    const int64_t bumped = logicalLd + 1;
    return bumped / 2;
}
// Returns the number of scale entries along a dimension of length kDim:
// kDim is rounded up to whole groups of 64 and each group contributes two
// entries (e.g. 1 -> 2, 64 -> 2, 65 -> 4).
// Presumably this is one scale per 32 elements, padded to an even count —
// confirm against the library's scale layout.
inline int mxScaleStrideAlongK(int kDim)
{
    const int groupsOf64 = (kDim + 63) / 64;
    return groupsOf64 * 2;
}
// Returns the size of the scale buffer for A.
//
// With transA == ACLBLAS_OP_N the result is M * mxScaleStrideAlongK(K);
// for any other transA it is K * mxScaleStrideAlongK(M).
// The value is a count of scale entries; the "Bytes" in the name implies one
// byte per entry — TODO confirm.
// NOTE(review): in the transposed case the stride is derived from M rather
// than K, whereas mxScaleBufferBytesB uses K for both settings — confirm
// this asymmetry is intended.
inline size_t mxScaleBufferBytesA(int M, int K, aclblasOperation_t transA)
{
    const bool notTransposed = (transA == ACLBLAS_OP_N);
    const int outerExtent = notTransposed ? M : K;
    const int stridedExtent = notTransposed ? K : M;
    // Widen each factor before multiplying so the product is computed in size_t.
    return static_cast<size_t>(outerExtent) * static_cast<size_t>(mxScaleStrideAlongK(stridedExtent));
}
// Returns the size of the scale buffer for B: N * mxScaleStrideAlongK(K).
//
// The result is the same for every transB value; the parameter is accepted
// but does not influence the computation.
// The value is a count of scale entries; the "Bytes" in the name implies one
// byte per entry — TODO confirm.
inline size_t mxScaleBufferBytesB(int N, int K, aclblasOperation_t transB)
{
    (void)transB;
    // Widen each factor before multiplying so the product is computed in size_t.
    const size_t scalesAlongK = static_cast<size_t>(mxScaleStrideAlongK(K));
    return static_cast<size_t>(N) * scalesAlongK;
}
// Parameters of a single LtMatmul test case.
//
// Extends BlasTestParamBase and fills its own fields from the csv_map passed
// to the constructor. Every field has an in-class default; the constructor
// then overwrites each one with the value read through ReadMap, whose third
// argument appears to be the fallback for a missing column — confirm in
// csv_loader.h.
struct LtMatmulParam : public BlasTestParamBase {
    // Element types of the A, B, C and D operands.
    aclDataType dtypeA = ACL_FLOAT;
    aclDataType dtypeB = ACL_FLOAT;
    aclDataType dtypeC = ACL_FLOAT;
    aclDataType dtypeD = ACL_FLOAT;

    // Problem dimensions.
    int M = 0;
    int N = 0;
    int K = 0;

    // Transpose settings for A and B.
    aclblasOperation_t transA = ACLBLAS_OP_N;
    aclblasOperation_t transB = ACLBLAS_OP_N;

    // Leading dimensions. A parsed value of 0 is replaced by a derived
    // default at the end of the constructor.
    int lda = 0;
    int ldb = 0;
    int ldc = 0;
    int ldd = 0;

    // Scalar multipliers.
    float alpha = 1.0f;
    float beta = 0.0f;

    // Raw text of the "algo_mode" column.
    std::string algoMode = "default";

    // Boolean switches; each is true only when its column is exactly "true".
    // Presumably they drive null-argument negative tests — verify at the use sites.
    bool CIsNull = false;
    bool handleNull = false;
    bool computeDescNull = false;
    bool alphaNull = false;
    bool Anull = false;

    // Fill modes for the A and B scale buffers.
    BlasFillMode scaleAFill = parseFill("RANDOM");
    BlasFillMode scaleBFill = parseFill("RANDOM");

    // Builds the parameter set from one csv_map; the base class consumes the
    // same map first.
    LtMatmulParam(const csv_map& m) : BlasTestParamBase(m)
    {
        // Data types.
        dtypeA = parseAclDataType(ReadMap(m, "dtypeA", "FP32"));
        dtypeB = parseAclDataType(ReadMap(m, "dtypeB", "FP32"));
        dtypeC = parseAclDataType(ReadMap(m, "dtypeC", "FP32"));
        dtypeD = parseAclDataType(ReadMap(m, "dtypeD", "FP32"));

        // Shape and transpose settings.
        M = parseInt(ReadMap(m, "M", "0"));
        N = parseInt(ReadMap(m, "N", "0"));
        K = parseInt(ReadMap(m, "K", "0"));
        transA = parseOpTrans(ReadMap(m, "transA", "N"));
        transB = parseOpTrans(ReadMap(m, "transB", "N"));

        // Leading dimensions as given; zeros are resolved further below.
        lda = parseInt(ReadMap(m, "lda", "0"));
        ldb = parseInt(ReadMap(m, "ldb", "0"));
        ldc = parseInt(ReadMap(m, "ldc", "0"));
        ldd = parseInt(ReadMap(m, "ldd", "0"));

        alpha = parseFloat(ReadMap(m, "alpha", "1.0"));
        beta = parseFloat(ReadMap(m, "beta", "0.0"));
        algoMode = ReadMap(m, "algo_mode", "default");

        // A switch is on only for the exact text "true"; anything else is off.
        const auto readFlag = [&](const char* column) {
            return ReadMap(m, column, "false") == "true";
        };
        CIsNull = readFlag("C_is_null");
        handleNull = readFlag("handle_null");
        computeDescNull = readFlag("computeDesc_null");
        alphaNull = readFlag("alpha_null");
        Anull = readFlag("A_null");

        // The text "none" selects the VALUE_NORM_1 fill; every other value is
        // handed to parseFill unchanged.
        const auto readScaleFill = [&](const char* column) {
            const std::string mode = ReadMap(m, column, "RANDOM");
            if (mode == "none") {
                return parseFill("VALUE_NORM_1");
            }
            return parseFill(mode);
        };
        scaleAFill = readScaleFill("scaleA_fill");
        scaleBFill = readScaleFill("scaleB_fill");

        // Replace an unspecified (zero) leading dimension by the given extent,
        // never going below 1.
        const auto resolveLd = [](int given, int extent) {
            return (given != 0) ? given : std::max(1, extent);
        };
        // A and B default to their stored column counts.
        lda = resolveLd(lda, getPhysicalColsA(M, K, transA));
        ldb = resolveLd(ldb, getPhysicalColsB(K, N, transB));
        // NOTE(review): C and D default to M, while A and B default to a
        // column count — confirm this is the intended convention for C/D.
        ldc = resolveLd(ldc, M);
        ldd = resolveLd(ldd, M);
    }
};