/**
 * This program is free software, you can redistribute it and/or modify.
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
 * BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
 * the software repository for the full text of the License.
 */
#ifndef OPTEST_CATLASS_KERNEL_JIT_H
#define OPTEST_CATLASS_KERNEL_JIT_H
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>
#include <acl/acl.h>
#include "catlass/gemm_coord.hpp"
namespace CatlassKernel {
// Fixed-shape part of the JIT template parameters: three GemmCoord values that
// TParams inherits alongside its map-based settings.
struct TParamsBase {
    // NOTE(review): presumably the tile shape used at the L1 buffer level; confirm against the kernels.
    Catlass::GemmCoord l1TileShape;
    // NOTE(review): presumably the tile shape used at the L0 buffer level; confirm against the kernels.
    Catlass::GemmCoord l0TileShape;
    // Swizzle configuration packed into a GemmCoord.
    // NOTE(review): the meaning of each coordinate is not defined in this header; confirm.
    Catlass::GemmCoord swizzle;
};
* @brief Compile-time JIT parameters shared by numbered matmul-family examples.
* Uses map-based storage for extensibility.
*/
struct TParams : TParamsBase {
std::unordered_map<std::string, aclDataType> element;
std::unordered_map<std::string, bool> transpose;
std::unordered_map<std::string, bool> useNz;
aclDataType elem(const std::string& k, aclDataType def = ACL_FLOAT16) const
{
auto it = element.find(k);
return it != element.end() ? it->second : def;
}
bool trans(const std::string& k, bool def = false) const
{
auto it = transpose.find(k);
return it != transpose.end() ? it->second : def;
}
bool nz(const std::string& k, bool def = false) const
{
auto it = useNz.find(k);
return it != useNz.end() ? it->second : def;
}
};
/**
 * @brief Runtime matrix parameters shared by numbered matmul-family examples.
 *
 * Plain aggregate: every dimension defaults to 1 and both address lists start empty.
 * The struct stores raw pointers only; it declares no destructor and never frees them.
 */
struct MatmulParams {
    // Problem dimensions. NOTE(review): presumably A is m x k and B is k x n; confirm per example.
    uint32_t m = 1;
    uint32_t n = 1;
    uint32_t k = 1;
    // Batch count; 1 by default.
    uint32_t batch = 1;
    // Byte pointers to the input buffers.
    // NOTE(review): presumably device addresses in an example-specific order; confirm with callers.
    std::vector<uint8_t*> inputAddr;
    // Byte pointers to the output buffers (same caveats as inputAddr).
    std::vector<uint8_t*> outputAddr;
};
// Runtime parameters for the unified EVG matmul entry point (MatmulEvg):
// the common matmul parameters plus the EVG variant selector.
struct MatmulEvgParams : public MatmulParams {
    // Name of the EVG variant; MatmulEvg picks its JIT template from this string
    // (e.g. add, add_ub, bias, leaky_relu).
    std::string evgType;
    // NOTE(review): presumably the slope used by the leaky_relu variant; confirm.
    // NOTE(review): the default of 1 would make a leaky ReLU the identity; confirm it is intended.
    float negativeSlope = 1.0f;
};
* @brief Runtime parameters for grouped matmul examples.
*/
struct GroupedMatmulParams : public MatmulParams {
enum class SliceMode : uint32_t
{
M = 0,
K = 1,
N = 2
};
SliceMode sliceMode = SliceMode::M;
};
* @brief Runtime parameters for strided batched matmul examples.
*/
struct StridedBatchedMatmulParams : public MatmulParams {
int64_t strideA = -1;
int64_t strideB = -1;
int64_t strideC = -1;
int64_t lda = -1;
int64_t ldb = -1;
int64_t ldc = -1;
};
* @brief Runtime parameters for GEMM examples with alpha and beta scaling.
*/
struct GemmParams : public MatmulParams {
float alpha = 1.0f;
float beta = 0.0f;
};
extern "C" {
* @brief JIT interface for example 00_basic_matmul.
*/
void BasicMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 01_batched_matmul.
*/
void BatchedMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 02_grouped_matmul_slice_m.
*/
void GroupedMatmulSliceM(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief JIT interface for example 03_matmul_add.
*/
void MatmulAdd(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 04_padding_matmul.
*/
void PaddingMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 05_grouped_matmul_slice_k.
*/
void GroupedMatmulSliceK(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 06_optimized_matmul.
*/
void OptimizedMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 07_grouped_matmul_slice_m_per_token_dequant_moe.
*/
void GroupedMatmulSliceMPerTokenDequantMoe(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 08_grouped_matmul.
*/
void GroupedMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief JIT interface for example 09_splitk_matmul.
*/
void SplitkMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 10_grouped_matmul_slice_m_per_token_dequant.
*/
void GroupedMatmulSliceMPerTokenDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief JIT interface for example 11_grouped_matmul_slice_k_per_token_dequant.
*/
void GroupedMatmulSliceKPerTokenDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief JIT interface for example 12_quant_matmul.
*/
void QuantMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 13_basic_matmul_tla.
*/
void BasicMatmulTLA(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 14_optimized_matmul_tla.
*/
void OptimizedMatmulTLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 15_gemm.
*/
void Gemm(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GemmParams& params);
* @brief Reserved JIT interface for example 16_group_gemm.
*/
void GroupGemm(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 17_gemv_aiv.
*/
void GemvAIV(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 18_gemv_aic.
*/
void GemvAIC(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 20_matmul_bias.
*/
void MatmulBias(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 21_basic_matmul_preload_zN.
*/
void BasicMatmulPreloadZN(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 22_padding_splitk_matmul.
*/
void PaddingSplitkMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 25_matmul_full_loadA.
*/
void MatmulFullLoadA(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 26_matmul_relu.
*/
void MatmulRelu(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 27_matmul_gelu.
*/
void MatmulGelu(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 28_matmul_silu.
*/
void MatmulSilu(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 30_w8a16_matmul.
*/
void W8A16Matmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 31_small_matmul.
*/
void SmallMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 32_w4a8_matmul.
*/
void W4A8Matmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 34_single_core_splitk_matmul.
*/
void SingleCoreSplitkMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 35_w4a8_grouped_matmul_msd.
*/
void W4A8GroupedMatmulMSD(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 36_w4a8_matmul_msd.
*/
void W4A8MatmulMSD(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 37_streamk_matmul.
*/
void StreamkMatmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 38_w4a4_matmul_per_token_per_channel_dequant.
*/
void W4A4MatmulPerTokenPerChannelDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 39_big_matmul_tla.
*/
void BigMatmulTLA(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 41_sparse_matmul_tla.
*/
void SparseMatmulTLA(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 42_quant_optimized_matmul_tla.
*/
void QuantOptimizedMatmulTLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 43_ascend950_basic_matmul.
*/
void Ascend950BasicMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 44_quant_matmul_full_loadA_tla.
*/
void QuantMatmulFullLoadATLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 45_strided_batched_matmul_tla.
*/
void StridedBatchedMatmulTLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const StridedBatchedMatmulParams& params);
* @brief Reserved JIT interface for example 46_ascend950_matmul_fixpipe_opti.
*/
void Ascend950MatmulFixpipeOpti(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 47_ascend950_grouped_matmul_slice_m_per_token_dequant.
*/
void Ascend950GroupedMatmulSliceMPerTokenDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 48_ascend950_grouped_matmul_slice_m_per_tensor_per_channel_dequant.
*/
void Ascend950GroupedMatmulSliceMPerTensorPerChannelDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief Reserved JIT interface for example 50_ascend950_basic_matmul_gemv.
*/
void Ascend950BasicMatmulGemv(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 51_ascend950_quant_matmul_per_group_per_block_tla.
*/
void Ascend950QuantMatmulPerGroupPerBlockTLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 52_quant_multi_core_splitk_matmul_tla.
*/
void QuantMultiCoreSplitkMatmulTLA(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 53_ascend950_fp8_mx_matmul.
*/
void Ascend950Fp8MxMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 53_ascend950_fp8_mx_matmul_aswt.
*/
void Ascend950Fp8MxMatmulAswt(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 54_ascend950_fp4_mx_matmul.
*/
void Ascend950Fp4MxMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 54_ascend950_fp4_mx_matmul_aswt.
*/
void Ascend950Fp4MxMatmulAswt(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 57_ascend950_matmul_full_dequant.
*/
void Ascend950MatmulFullDequant(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 58_ascend950_fp8_mx_batch_matmul.
*/
void Ascend950Fp8MxBatchMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief JIT interface for example 63_ascend950_dual_level_quant_mx_batch_matmul.
*/
void Ascend950DualLevelQuantMxBatchMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 59_ascend950_a8w4_mx_matmul.
*/
void Ascend950A8W4MxMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 60_ascend950_grouped_matmul_slice_m.
*/
void Ascend950GroupedMatmulSliceM(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const GroupedMatmulParams& params);
* @brief JIT interface for example 64_ascend950_matmul_evg (unified EVG matmul entry).
*
* Selects the JIT template via ``params.evgType`` (e.g. add, add_ub, bias, leaky_relu, ...).
*/
void MatmulEvg(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulEvgParams& params);
* @brief Reserved JIT interface for example 102_dynamic_optimized_matmul.
*/
void DynamicOptimizedMatmul(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved JIT interface for example 103_dynamic_optimized_quant_matmul_per_token_basic.
*/
void DynamicOptimizedQuantMatmulPerTokenBasic(
const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
}
}
#endif