/**
* This program is free software, you can redistribute it and/or modify.
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. See LICENSE in the root of
* the software repository for the full text of the License.
*/
#ifndef OPTEST_CATLASS_KERNEL_PREBUILT_H
#define OPTEST_CATLASS_KERNEL_PREBUILT_H
#include <string>
#include <cstdint>
#include <vector>
#include <acl/acl.h>
#include "catlass_kernel_jit.h"
namespace CatlassKernel {
/**
 * @brief Runtime parameters shared by prebuilt numbered examples.
 *
 * Base of the example-specific parameter structs in this header; it only carries the raw buffer
 * address lists. The order of the entries in each list is defined by the individual example and is
 * not visible here.
 */
struct PrebuiltParams {
    std::vector<uint8_t*> inputAddr;   ///< Addresses of the input buffers.
    std::vector<uint8_t*> outputAddr;  ///< Addresses of the output buffers.
};
* @brief Runtime parameters for convolution examples.
*/
struct ConvParams : public PrebuiltParams {
aclDataType inputDataType = aclDataType::ACL_FLOAT16;
aclDataType biasDataType = aclDataType::ACL_FLOAT;
aclDataType outputDataType = aclDataType::ACL_FLOAT16;
std::vector<uint32_t> fmapRelated;
std::vector<uint32_t> filterRelated;
std::vector<uint32_t> strideList;
std::vector<uint32_t> padList;
std::vector<uint32_t> dilationList;
};
* @brief Runtime parameters for flash-attention style examples.
*/
struct FlashAttentionParams : public PrebuiltParams {
uint32_t qNtokens = 0;
uint32_t batch = 0;
uint32_t qSeqlen = 0;
uint32_t kvSeqlen = 0;
uint32_t numHeads = 0;
uint32_t kvHeads = 0;
uint32_t embeddingSize = 0;
uint32_t isVariedLen = 0;
uint32_t maskType = 0;
uint32_t blockSize = 128;
aclDataType dataType = ACL_FLOAT16;
};
* @brief Runtime parameters for MLA examples.
*/
struct MlaParams : public FlashAttentionParams {
uint32_t qRopeHeadDim = 0;
uint32_t kvRopeHeadDim = 0;
uint32_t numBlocks = 0;
std::vector<int32_t> qSeqHost;
std::vector<int32_t> kvSeqHost;
mutable std::vector<uint8_t> outputHost;
};
* @brief Runtime parameters for flash-attention style examples.
*/
struct FlashAttentionChunkPrefillParams : public PrebuiltParams {
uint32_t qNtokens = 0;
uint32_t batch = 0;
uint32_t qSeqlen = 0;
uint32_t kvSeqlen = 0;
uint32_t numHeads = 0;
uint32_t kvHeads = 0;
uint32_t qkembeddingSize = 0;
uint32_t vembeddingSize = 0;
uint32_t isVariedLen = 0;
uint32_t maskType = 0;
uint32_t blockSize = 128;
uint32_t numBlocks = 2048;
std::string cacheLayout = "nd";
aclDataType dataType = ACL_FLOAT16;
};
* @brief Reserved prebuilt interface for example 19_mla.
*/
__attribute__((weak)) void Mla(const uint32_t blockNum, aclrtStream stream, const MlaParams& params);
* @brief Reserved prebuilt interface for example 23_flash_attention_infer.
*/
__attribute__((weak)) void FlashAttentionInfer(const uint32_t blockNum, aclrtStream stream, const FlashAttentionParams& params);
* @brief Reserved prebuilt interface for example 24_conv_bias.
*/
__attribute__((weak)) void ConvBias(const uint32_t blockNum, aclrtStream stream, const ConvParams& params);
* @brief Reserved prebuilt interface for example 33_basic_conv2d.
*/
__attribute__((weak)) void BasicConv2d(const uint32_t blockNum, aclrtStream stream, const ConvParams& params);
* @brief Reserved prebuilt interface for example 40_flash_attention_infer_tla.
*/
__attribute__((weak)) void FlashAttentionInferTLA(const uint32_t blockNum, aclrtStream stream, const FlashAttentionParams& params);
* @brief Reserved prebuilt interface for example 49_ascend950_flash_attention_infer.
*/
__attribute__((weak)) void Ascend950FlashAttentionInfer(const uint32_t blockNum, aclrtStream stream, const FlashAttentionParams& params);
* @brief Reserved prebuilt interface for example 56_ascend950_basic_conv2d_tla.
*/
__attribute__((weak)) void Ascend950BasicConv2dTLA(const uint32_t blockNum, aclrtStream stream, const ConvParams& params);
* @brief Runtime parameters for Ascend950 MXFP8 flash attention examples.
*/
struct Ascend950MxFp8FlashAttentionParams : public FlashAttentionParams {
uint32_t usePscale = 0;
};
* @brief Prebuilt interface for example 72_ascend950_fp8_mx_flash_attention_infer.
*/
__attribute__((weak)) void Ascend950MxFp8FlashAttentionInfer(const uint32_t blockNum, aclrtStream stream,
const Ascend950MxFp8FlashAttentionParams& params);
* @brief Prebuilt interface for example 29_a2_fp8_e4m3_matmul.
*/
extern "C" __attribute__((weak)) void A2Fp8E4M3Matmul(const uint32_t blockNum, aclrtStream stream, const TParams& tParams, const MatmulParams& params);
* @brief Reserved prebuilt interface for example 70_ascend950_flash_attention_chunk_prefill.
*/
__attribute__((weak)) void FlashAttentionChunkPrefill(const uint32_t blockNum, aclrtStream stream, const FlashAttentionChunkPrefillParams& params);
* @brief Broadcast MatMul with Per-Block Quantization(Ascend 950 TLA)。
* @param blockNum 启用的 AI Core 数量。
* @param stream ACL 计算流。
* @param params 运行期参数(M/N/K/batch、地址)。
*/
__attribute__((weak))
void BroadcastMatmulPerblockQuant(const uint32_t blockNum, aclrtStream stream, const MatmulParams& params);
}
#endif