/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
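/*!
 * \file kernel_operator_mm_bitmode_intf.h
 * \brief Bit-mode parameter classes for Load2D, Load3D, SetFMatrix and (on __NPU_ARCH__ 3101 only) Mmad,
 *        each exposing its settings as raw 64-bit configuration words.
 */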
#ifndef ASCENDC_MODULE_OPERATOR_MM_BITMODE_INTERFACE_H
#define ASCENDC_MODULE_OPERATOR_MM_BITMODE_INTERFACE_H
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3101 || __NPU_ARCH__ == 5102)
#include "kernel_struct_mm.h"
#include "../../impl/basic_api/kernel_operator_mm_bitmode_struct.h"
namespace AscendC {
/**
 * @brief Bit-mode parameter holder for 2D loads.
 *
 * Holds two raw 64-bit configuration words, `config0` and `config1`. Each word
 * sits in an anonymous union with a structured view (`Load2DBitModeConfig0` /
 * `Load2DBitModeConfig1`), so the raw word and the view share storage.
 * A separate `ifTranspose` flag is kept outside the config words and defaults to false.
 *
 * Only the raw-word accessors and `GetIfTranspose()` are defined inline here.
 * The constructors and field accessors are declared only; their definitions are
 * expected in kernel_operator_mm_bitmode_intf_impl.h, which is included at the
 * end of this header.
 */
class Load2DBitModeParam {
public:
    // Construction (default, or from a LoadData2DParamsV2 description).
    __aicore__ inline Load2DBitModeParam();
    __aicore__ inline Load2DBitModeParam(const LoadData2DParamsV2 &loadData2DParams_);

    // Raw access to the packed configuration words.
    __aicore__ inline void SetConfig0(uint64_t value) { config0 = value; }
    __aicore__ inline void SetConfig1(uint64_t value) { config1 = value; }
    __aicore__ inline uint64_t GetConfig0() const { return config0; }
    __aicore__ inline uint64_t GetConfig1() const { return config1; }

    // Start positions.
    __aicore__ inline void SetMStartPosition(uint32_t mStartPosition_);
    __aicore__ inline uint32_t GetMStartPosition() const;
    __aicore__ inline void SetKStartPosition(uint32_t kStartPosition_);
    __aicore__ inline uint32_t GetKStartPosition() const;

    // Steps.
    __aicore__ inline void SetMStep(uint16_t mStep_);
    __aicore__ inline uint16_t GetMStep() const;
    __aicore__ inline void SetKStep(uint16_t kStep_);
    __aicore__ inline uint16_t GetKStep() const;

    // Strides (the source stride is signed, the destination stride is unsigned).
    __aicore__ inline void SetSrcStride(int32_t srcStride_);
    __aicore__ inline int32_t GetSrcStride() const;
    __aicore__ inline void SetDstStride(uint16_t dstStride_);
    __aicore__ inline uint16_t GetDstStride() const;

    // Transpose flag, stored as a plain bool outside config0/config1.
    __aicore__ inline void SetIfTranspose(bool ifTranspose_);
    // Returns the flag widened to uint64_t (false -> 0, true -> 1).
    // NOTE(review): the return type does not match the bool member; kept for API compatibility.
    __aicore__ inline uint64_t GetIfTranspose() const { return ifTranspose; }

private:
    // Keep member order unchanged: it defines the object layout.
    union {
        uint64_t config0;
        struct Load2DBitModeConfig0 config0BitMode;
    };
    union {
        uint64_t config1;
        struct Load2DBitModeConfig1 config1BitMode;
    };
    bool ifTranspose = false;
};
/**
 * @brief Bit-mode parameter holder for 3D loads.
 *
 * Holds two raw 64-bit configuration words, `config0` and `config1`. Each word
 * sits in an anonymous union with a structured view (`Load3DBitModeConfig0` /
 * `Load3DBitModeConfig1`), so the raw word and the view share storage.
 *
 * Only the raw-word accessors are defined inline here. The constructors and
 * field accessors are expected to be defined in
 * kernel_operator_mm_bitmode_intf_impl.h. NOTE(review): which config word
 * carries which field is decided there and in the bit-mode struct header;
 * confirm before relying on it.
 */
class Load3DBitModeParam {
public:
    // Construction (default, or from a LoadData3DParamsV2<T> description).
    __aicore__ inline Load3DBitModeParam();
    template <typename T>
    __aicore__ inline Load3DBitModeParam(const LoadData3DParamsV2<T> &loadData3DParams_);

    // Raw access to the packed configuration words.
    __aicore__ inline void SetConfig0(uint64_t value) { config0 = value; }
    __aicore__ inline void SetConfig1(uint64_t value) { config1 = value; }
    __aicore__ inline uint64_t GetConfig0() const { return config0; }
    __aicore__ inline uint64_t GetConfig1() const { return config1; }

    // K/M extension fields.
    __aicore__ inline void SetKExtension(uint16_t kStep_);
    __aicore__ inline uint16_t GetKExtension() const;
    __aicore__ inline void SetMExtension(uint16_t mStep_);
    __aicore__ inline uint16_t GetMExtension() const;

    // K/M start-point fields.
    __aicore__ inline void SetKStartPt(uint16_t kPos_);
    __aicore__ inline uint16_t GetKStartPt() const;
    __aicore__ inline void SetMStartPt(uint16_t mPos_);
    __aicore__ inline uint16_t GetMStartPt() const;

    // Stride fields.
    __aicore__ inline void SetStrideW(uint8_t strideW_);
    __aicore__ inline uint8_t GetStrideW() const;
    __aicore__ inline void SetStrideH(uint8_t strideH_);
    __aicore__ inline uint8_t GetStrideH() const;

    // Filter fields.
    __aicore__ inline void SetFilterW(uint8_t Wk_);
    __aicore__ inline uint8_t GetFilterW() const;
    __aicore__ inline void SetFilterH(uint8_t Hk_);
    __aicore__ inline uint8_t GetFilterH() const;

    // Filter dilation fields.
    __aicore__ inline void SetDilationFilterW(uint8_t dilationW_);
    __aicore__ inline uint8_t GetDilationFilterW() const;
    __aicore__ inline void SetDilationFilterH(uint8_t dilationH_);
    __aicore__ inline uint8_t GetDilationFilterH() const;

    // Filter-size flags.
    __aicore__ inline void SetFilterSizeW(bool filterSizeW_);
    __aicore__ inline bool GetFilterSizeW() const;
    __aicore__ inline void SetFilterSizeH(bool filterSizeH_);
    __aicore__ inline bool GetFilterSizeH() const;

    // Control flags.
    __aicore__ inline void SetTranspose(bool transpose_);
    __aicore__ inline bool GetTranspose() const;
    __aicore__ inline void SetFMatrixCtrl(bool fmatrixCtrl_);
    __aicore__ inline bool GetFMatrixCtrl() const;

    // Channel size field.
    __aicore__ inline void SetChannelSize(uint16_t sizeChannel_);
    __aicore__ inline uint16_t GetChannelSize() const;

private:
    // Keep member order unchanged: it defines the object layout.
    union {
        uint64_t config0;
        struct Load3DBitModeConfig0 config0BitMode;
    };
    union {
        uint64_t config1;
        struct Load3DBitModeConfig1 config1BitMode;
    };
};
/**
 * @brief Bit-mode parameter holder for SetFMatrix.
 *
 * Holds a single raw 64-bit configuration word, `config0`. It sits in an
 * anonymous union with a structured view (`SetFMatrixBitModeConfig0`), so the
 * raw word and the view share storage.
 *
 * No field getters are declared; only the raw word can be read back through
 * GetConfig0(). The constructors and field setters are expected to be defined
 * in kernel_operator_mm_bitmode_intf_impl.h.
 */
class SetFMatrixBitModeParams {
public:
    // Construction (default, or from a LoadData3DParamsV2<T> description).
    __aicore__ inline SetFMatrixBitModeParams();
    template <typename T>
    __aicore__ inline SetFMatrixBitModeParams(const LoadData3DParamsV2<T> &loadData3DParams_);

    // Raw access to the packed configuration word.
    __aicore__ inline void SetConfig0(uint64_t value) { config0 = value; }
    __aicore__ inline uint64_t GetConfig0() const { return config0; }

    // L1 height/width fields.
    __aicore__ inline void SetL1H(uint16_t l1H_);
    __aicore__ inline void SetL1W(uint16_t l1W_);

    // Padding list. NOTE: the array parameter is adjusted to `const uint8_t *`.
    // The bound 4 documents the expected length only; the compiler does not check it.
    __aicore__ inline void SetPadList(const uint8_t padList_[4]);

private:
    union {
        uint64_t config0;
        struct SetFMatrixBitModeConfig0 config0BitMode;
    };
};
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3101)
/**
 * @brief Bit-mode parameter holder for Mmad (only compiled when __NPU_ARCH__ == 3101).
 *
 * Holds a single raw 64-bit configuration word, `config0`. It sits in an
 * anonymous union with a structured view (`MmadBitModeConfig0`), so the raw
 * word and the view share storage.
 *
 * Only the raw-word accessors are defined inline here. The constructors and
 * field accessors are expected to be defined in
 * kernel_operator_mm_bitmode_intf_impl.h.
 */
class MmadBitModeParams {
public:
    // Construction (default, or from an MmadParams description).
    __aicore__ inline MmadBitModeParams();
    __aicore__ inline MmadBitModeParams(const MmadParams &mmadParams_);

    // Raw access to the packed configuration word.
    __aicore__ inline void SetConfig0(uint64_t value) { config0 = value; }
    __aicore__ inline uint64_t GetConfig0() const { return config0; }

    // M/K/N size fields.
    __aicore__ inline void SetM(uint16_t m_);
    __aicore__ inline uint16_t GetM() const;
    __aicore__ inline void SetK(uint16_t k_);
    __aicore__ inline uint16_t GetK() const;
    __aicore__ inline void SetN(uint16_t n_);
    __aicore__ inline uint16_t GetN() const;

    // Unit flag field.
    __aicore__ inline void SetUnitFlag(uint8_t unitFlag_);
    __aicore__ inline uint8_t GetUnitFlag() const;

    // Boolean control fields.
    __aicore__ inline void SetDisableGemv(bool disableGemv_);
    __aicore__ inline bool GetDisableGemv() const;
    __aicore__ inline void SetCmatrixSource(bool cmatrixSource_);
    __aicore__ inline bool GetCmatrixSource() const;
    __aicore__ inline void SetCmatrixInitVal(bool cmatrixInitVal_);
    __aicore__ inline bool GetCmatrixInitVal() const;

private:
    union {
        uint64_t config0;
        struct MmadBitModeConfig0 config0BitMode;
    };
};
#endif  // __NPU_ARCH__ == 3101
}
#include "../../impl/basic_api/kernel_operator_mm_bitmode_intf_impl.h"
#endif
#endif