* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include "DevContext.h"
#include <vector>
#include <cassert>
#include "tensor_list.h"
/**
 * Declaration of the network memory-allocation routine (defined elsewhere).
 *
 * NOTE(review): the name contains a double underscore; identifiers of that
 * form are reserved for the implementation by the C++ standard. A rename
 * would touch every caller, so the name is left unchanged here.
 *
 * @param devObj Device context, passed by non-const reference.
 * @return A vector of int pairs. It has the same type as the
 *         `layer_params_idx_size` argument of launch_network(), so each pair
 *         is presumably an (index, size) entry per layer -- TODO confirm.
 */
std::vector<std::pair<int, int>> network__memory_malloc(DevContext& devObj);
/**
 * Declaration of the network launch routine (defined elsewhere).
 *
 * @param devObj                Device context, passed by non-const reference.
 * @param layer_params_idx_size Read-only vector of int pairs. Its type matches
 *                              the return type of network__memory_malloc(),
 *                              so it is presumably that result -- TODO confirm.
 * @param repeat_cnt            Integer count; presumably the number of times
 *                              the network is executed -- TODO confirm.
 */
void launch_network(
    DevContext& devObj,
    const std::vector<std::pair<int, int>>& layer_params_idx_size,
    int repeat_cnt);
/**
 * Declaration of gen_rms_kernel_func (defined elsewhere).
 *
 * NOTE(review): judging by the name, this presumably sets up the RMS-norm
 * kernel step for the given context -- confirm against the definition.
 *
 * @param devObj Device context, passed by non-const reference.
 */
void gen_rms_kernel_func(DevContext& devObj);
/**
 * Declaration of gen_grouped_matmul_func (defined elsewhere).
 *
 * NOTE(review): judging by the name, this presumably sets up the grouped
 * matmul step for the given context -- confirm against the definition.
 *
 * @param devObj Device context, passed by non-const reference.
 */
void gen_grouped_matmul_func(DevContext& devObj);
/**
 * Declaration of gen_weight_quant_batch_matmul_v2_func (defined elsewhere).
 *
 * NOTE(review): judging by the name, this presumably sets up the
 * weight-quant batch matmul (v2) step for the given context -- confirm
 * against the definition.
 *
 * @param devObj Device context, passed by non-const reference.
 */
void gen_weight_quant_batch_matmul_v2_func(DevContext& devObj);
/**
 * Declaration of gen_matmul_add_func (defined elsewhere).
 *
 * NOTE(review): judging by the name, this presumably sets up the matmul-add
 * step for the given context -- confirm against the definition.
 *
 * @param devObj Device context, passed by non-const reference.
 */
void gen_matmul_add_func(DevContext& devObj);
/**
 * Declaration of gen_dequant_swiglu_quant_func (defined elsewhere).
 *
 * NOTE(review): judging by the name, this presumably sets up the
 * dequant + SwiGLU + quant step for the given context -- confirm against the
 * definition.
 *
 * @param devObj Device context, passed by non-const reference.
 */
void gen_dequant_swiglu_quant_func(DevContext& devObj);
/**
 * Declaration of ClearOpsKernelLaunch (defined elsewhere).
 *
 * @param block_dim Unsigned 32-bit launch dimension; presumably the number of
 *                  blocks the kernel is launched on -- TODO confirm.
 * @param stream    Opaque stream handle, passed as an untyped pointer.
 */
void ClearOpsKernelLaunch(
    uint32_t block_dim,
    void* stream);
/**
 * Declaration of the RmsNormKernel launcher (defined elsewhere).
 *
 * All buffers are passed as raw byte pointers; their element types, shapes
 * and memory location are not visible from this header.
 *
 * @param block_dim Unsigned 32-bit launch dimension -- presumably the block
 *                  count; TODO confirm.
 * @param stream    Opaque stream handle, passed as an untyped pointer.
 * @param x         Byte pointer; presumably the input -- TODO confirm.
 * @param gamma     Byte pointer; presumably the scale weights -- TODO confirm.
 * @param y         Byte pointer; presumably the output -- TODO confirm.
 * @param rstd      Byte pointer; presumably the reciprocal std-dev output --
 *                  TODO confirm.
 * @param tiling    Read-only RMSNormTilingData describing the tiling.
 */
void RmsNormKernel(
    uint32_t block_dim,
    void* stream,
    uint8_t* x,
    uint8_t* gamma,
    uint8_t* y,
    uint8_t* rstd,
    const RMSNormTilingData& tiling);
/**
 * Declaration of the GroupedMatmulKernel launcher (defined elsewhere).
 *
 * GroupedMatmulKernelV2 and GroupedMatmulKernelV3 in this header take the
 * same parameter list. All buffers are raw byte pointers; their element
 * types, shapes and which of them may be null are not visible from here.
 *
 * @param block_dim       Unsigned 32-bit launch dimension -- presumably the
 *                        block count; TODO confirm.
 * @param stream          Opaque stream handle, passed as an untyped pointer.
 * @param x               Byte pointer named for the left-hand operand.
 * @param weight          Byte pointer named for the weights.
 * @param bias            Byte pointer named for the bias.
 * @param scale           Byte pointer named for the scale.
 * @param offset          Byte pointer named for the offset.
 * @param antiquantScale  Byte pointer named for the anti-quant scale.
 * @param antiquantOffset Byte pointer named for the anti-quant offset.
 * @param groupList       Byte pointer named for the group list.
 * @param perTokenScale   Byte pointer named for the per-token scale.
 * @param y               Byte pointer; presumably the output -- TODO confirm.
 * @param workspace       Byte pointer named for scratch workspace.
 * @param tiling          Read-only GMMTilingData describing the tiling.
 */
void GroupedMatmulKernel(
    uint32_t block_dim,
    void* stream,
    uint8_t* x,
    uint8_t* weight,
    uint8_t* bias,
    uint8_t* scale,
    uint8_t* offset,
    uint8_t* antiquantScale,
    uint8_t* antiquantOffset,
    uint8_t* groupList,
    uint8_t* perTokenScale,
    uint8_t* y,
    uint8_t* workspace,
    const GMMTilingData& tiling);
/**
 * Declaration of the WeightQuantBatchMatmulV2Kernel launcher (defined
 * elsewhere).
 *
 * All buffers are raw byte pointers; their element types, shapes and which
 * of them may be null are not visible from this header.
 *
 * @param numBlocks       Unsigned 32-bit launch dimension -- presumably the
 *                        block count; TODO confirm.
 * @param stream          Opaque stream handle, passed as an untyped pointer.
 * @param x               Byte pointer named for the left-hand operand.
 * @param weight          Byte pointer named for the weights.
 * @param antiquantScale  Byte pointer named for the anti-quant scale.
 * @param antiquantOffset Byte pointer named for the anti-quant offset.
 * @param quantScale      Byte pointer named for the quant scale.
 * @param quantOffset     Byte pointer named for the quant offset.
 * @param bias            Byte pointer named for the bias.
 * @param y               Byte pointer; presumably the output -- TODO confirm.
 * @param workspace       Byte pointer named for scratch workspace.
 * @param tiling          Read-only WeightQuantBatchMatmulV2MsdTilingData
 *                        describing the tiling.
 */
void WeightQuantBatchMatmulV2Kernel(
    uint32_t numBlocks,
    void* stream,
    uint8_t* x,
    uint8_t* weight,
    uint8_t* antiquantScale,
    uint8_t* antiquantOffset,
    uint8_t* quantScale,
    uint8_t* quantOffset,
    uint8_t* bias,
    uint8_t* y,
    uint8_t* workspace,
    const WeightQuantBatchMatmulV2MsdTilingData& tiling);
/**
 * Declaration of the MatmulAdd launcher (defined elsewhere).
 *
 * @param numBlocks Unsigned 32-bit launch dimension -- presumably the block
 *                  count; TODO confirm.
 * @param stream    Opaque stream handle, passed as an untyped pointer.
 * @param x         Byte pointer; role (input/output) not visible from this
 *                  header -- TODO confirm.
 * @param y         Byte pointer; role (input/output) not visible from this
 *                  header -- TODO confirm.
 */
void MatmulAdd(
    uint32_t numBlocks,
    void* stream,
    uint8_t* x,
    uint8_t* y);
/**
 * Declaration of the DequantSwiGluQuantDynamicKernel launcher that takes a
 * SwiGluTilingData (defined elsewhere).
 *
 * This header also declares a second DequantSwiGluQuantDynamicKernel whose
 * only difference is the tiling type (DequantSwigluQuantBaseTilingData).
 * NOTE(review): the two are distinct overloads only if those tiling types are
 * distinct -- confirm they are not aliases of each other.
 *
 * All buffers are raw byte pointers; their element types, shapes and which
 * of them may be null are not visible from this header.
 *
 * @param numBlocks         Unsigned 32-bit launch dimension -- presumably the
 *                          block count; TODO confirm.
 * @param stream            Opaque stream handle, passed as an untyped pointer.
 * @param xGM               Byte pointer named for the input.
 * @param weightSscaleGM    Byte pointer; presumably the weight scale (the
 *                          name looks like a typo of "weightScaleGM").
 * @param activationScaleGM Byte pointer named for the activation scale.
 * @param biasGM            Byte pointer named for the bias.
 * @param quantScaleGM      Byte pointer named for the quant scale.
 * @param quantOffsetGM     Byte pointer named for the quant offset.
 * @param groupIndex        Byte pointer named for the group index.
 * @param yGM               Byte pointer; presumably the output -- TODO confirm.
 * @param scaleGM           Byte pointer; presumably the output scale --
 *                          TODO confirm.
 * @param workspace         Byte pointer named for scratch workspace.
 * @param tiling            Read-only SwiGluTilingData describing the tiling.
 */
void DequantSwiGluQuantDynamicKernel(
    uint32_t numBlocks,
    void* stream,
    uint8_t* xGM,
    uint8_t* weightSscaleGM,
    uint8_t* activationScaleGM,
    uint8_t* biasGM,
    uint8_t* quantScaleGM,
    uint8_t* quantOffsetGM,
    uint8_t* groupIndex,
    uint8_t* yGM,
    uint8_t* scaleGM,
    uint8_t* workspace,
    const SwiGluTilingData& tiling);
/**
 * Declaration of the GroupedMatmulKernelV2 launcher (defined elsewhere).
 *
 * Takes the same parameter list as GroupedMatmulKernel and
 * GroupedMatmulKernelV3 declared in this header; how the variants differ is
 * not visible from here. All buffers are raw byte pointers; their element
 * types, shapes and which of them may be null are not visible either.
 *
 * @param block_dim       Unsigned 32-bit launch dimension -- presumably the
 *                        block count; TODO confirm.
 * @param stream          Opaque stream handle, passed as an untyped pointer.
 * @param x               Byte pointer named for the left-hand operand.
 * @param weight          Byte pointer named for the weights.
 * @param bias            Byte pointer named for the bias.
 * @param scale           Byte pointer named for the scale.
 * @param offset          Byte pointer named for the offset.
 * @param antiquantScale  Byte pointer named for the anti-quant scale.
 * @param antiquantOffset Byte pointer named for the anti-quant offset.
 * @param groupList       Byte pointer named for the group list.
 * @param perTokenScale   Byte pointer named for the per-token scale.
 * @param y               Byte pointer; presumably the output -- TODO confirm.
 * @param workspace       Byte pointer named for scratch workspace.
 * @param tiling          Read-only GMMTilingData describing the tiling.
 */
void GroupedMatmulKernelV2(
    uint32_t block_dim,
    void* stream,
    uint8_t* x,
    uint8_t* weight,
    uint8_t* bias,
    uint8_t* scale,
    uint8_t* offset,
    uint8_t* antiquantScale,
    uint8_t* antiquantOffset,
    uint8_t* groupList,
    uint8_t* perTokenScale,
    uint8_t* y,
    uint8_t* workspace,
    const GMMTilingData& tiling);
/**
 * Declaration of the DequantSwiGluQuantDynamicKernel launcher that takes a
 * DequantSwigluQuantBaseTilingData (defined elsewhere).
 *
 * This header also declares a DequantSwiGluQuantDynamicKernel whose only
 * difference is the tiling type (SwiGluTilingData).
 * NOTE(review): the two are distinct overloads only if those tiling types are
 * distinct -- confirm they are not aliases of each other.
 *
 * All buffers are raw byte pointers; their element types, shapes and which
 * of them may be null are not visible from this header.
 *
 * @param numBlocks         Unsigned 32-bit launch dimension -- presumably the
 *                          block count; TODO confirm.
 * @param stream            Opaque stream handle, passed as an untyped pointer.
 * @param xGM               Byte pointer named for the input.
 * @param weightSscaleGM    Byte pointer; presumably the weight scale (the
 *                          name looks like a typo of "weightScaleGM").
 * @param activationScaleGM Byte pointer named for the activation scale.
 * @param biasGM            Byte pointer named for the bias.
 * @param quantScaleGM      Byte pointer named for the quant scale.
 * @param quantOffsetGM     Byte pointer named for the quant offset.
 * @param groupIndex        Byte pointer named for the group index.
 * @param yGM               Byte pointer; presumably the output -- TODO confirm.
 * @param scaleGM           Byte pointer; presumably the output scale --
 *                          TODO confirm.
 * @param workspace         Byte pointer named for scratch workspace.
 * @param tiling            Read-only DequantSwigluQuantBaseTilingData
 *                          describing the tiling.
 */
void DequantSwiGluQuantDynamicKernel(
    uint32_t numBlocks,
    void* stream,
    uint8_t* xGM,
    uint8_t* weightSscaleGM,
    uint8_t* activationScaleGM,
    uint8_t* biasGM,
    uint8_t* quantScaleGM,
    uint8_t* quantOffsetGM,
    uint8_t* groupIndex,
    uint8_t* yGM,
    uint8_t* scaleGM,
    uint8_t* workspace,
    const DequantSwigluQuantBaseTilingData& tiling);
/**
 * Declaration of the GroupedMatmulKernelV3 launcher (defined elsewhere).
 *
 * Takes the same parameter list as GroupedMatmulKernel and
 * GroupedMatmulKernelV2 declared in this header; how the variants differ is
 * not visible from here. All buffers are raw byte pointers; their element
 * types, shapes and which of them may be null are not visible either.
 *
 * @param block_dim       Unsigned 32-bit launch dimension -- presumably the
 *                        block count; TODO confirm.
 * @param stream          Opaque stream handle, passed as an untyped pointer.
 * @param x               Byte pointer named for the left-hand operand.
 * @param weight          Byte pointer named for the weights.
 * @param bias            Byte pointer named for the bias.
 * @param scale           Byte pointer named for the scale.
 * @param offset          Byte pointer named for the offset.
 * @param antiquantScale  Byte pointer named for the anti-quant scale.
 * @param antiquantOffset Byte pointer named for the anti-quant offset.
 * @param groupList       Byte pointer named for the group list.
 * @param perTokenScale   Byte pointer named for the per-token scale.
 * @param y               Byte pointer; presumably the output -- TODO confirm.
 * @param workspace       Byte pointer named for scratch workspace.
 * @param tiling          Read-only GMMTilingData describing the tiling.
 */
void GroupedMatmulKernelV3(
    uint32_t block_dim,
    void* stream,
    uint8_t* x,
    uint8_t* weight,
    uint8_t* bias,
    uint8_t* scale,
    uint8_t* offset,
    uint8_t* antiquantScale,
    uint8_t* antiquantOffset,
    uint8_t* groupList,
    uint8_t* perTokenScale,
    uint8_t* y,
    uint8_t* workspace,
    const GMMTilingData& tiling);
/**
 * Declaration of the DynamicQuantKernel launcher (defined elsewhere).
 *
 * All buffers are raw byte pointers; their element types, shapes and which
 * of them may be null are not visible from this header.
 *
 * @param block_dim     Unsigned 32-bit launch dimension -- presumably the
 *                      block count; TODO confirm.
 * @param stream        Opaque stream handle, passed as an untyped pointer.
 * @param x             Byte pointer named for the input.
 * @param smooth_scales Byte pointer named for the smooth scales.
 * @param group_index   Byte pointer named for the group index.
 * @param y             Byte pointer; presumably the quantized output --
 *                      TODO confirm.
 * @param scale         Byte pointer; presumably the computed scale output --
 *                      TODO confirm.
 * @param workSpace     Byte pointer named for scratch workspace.
 * @param tiling        Read-only DynamicQuantTilingData describing the tiling.
 */
void DynamicQuantKernel(
    uint32_t block_dim,
    void* stream,
    uint8_t* x,
    uint8_t* smooth_scales,
    uint8_t* group_index,
    uint8_t* y,
    uint8_t* scale,
    uint8_t* workSpace,
    const DynamicQuantTilingData& tiling);