* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file calc_api.h
* \brief Calculator API
*/
#pragma once
#include <cstdint>
#include <ostream>
#include <string>
#include <vector>
#include "tilefwk/data_type.h"
#include "tilefwk/element.h"
namespace npu::tile_fwk {
/**
 * Describes a tensor handed to the calculator operations in CalcOps.
 *
 * The struct has no destructor and never frees dataPtr; it only describes the
 * buffer. NOTE(review): who owns that buffer, and exactly how rawShape, shape,
 * stride and storageOffset combine to address elements, is defined by the
 * calculator implementation, not by this header. Confirm there.
 *
 * Every member has a default initializer, so a default-initialized
 * `TensorData t;` is fully deterministic. Previously storageOffset and dtype
 * were left indeterminate in that case, and reading them was UB.
 */
struct TensorData {
    void* dataPtr = nullptr;        // start of the described buffer (not released by this struct)
    std::vector<int64_t> rawShape;  // presumably the extent of the underlying storage (TODO confirm)
    std::vector<int64_t> shape;     // presumably the logical/view extent (TODO confirm)
    std::vector<int64_t> stride;    // presumably per-dimension strides; units (elements vs bytes) TODO confirm
    int64_t storageOffset = 0;      // 0 = no offset; previously uninitialized on default-init
    DataType dtype{};               // value-initialized (zero value); TODO confirm that is an acceptable default dtype
    bool isAxisCombine = false;
};
/**
 * Option bundle passed as the last argument of the CalcOps::MatMul entry.
 *
 * All members default to "off"/zero/absent, so a default-constructed
 * MatMulParam requests no transposes, no scaling inputs and no bias.
 * NOTE(review): the exact meaning of each flag is defined by the MatMul
 * implementation. The descriptions below are inferred from the names.
 */
struct MatMulParam {
    // Presumably: operand A / B (and their scale tensors) are stored transposed.
    bool aTrans{false};
    bool bTrans{false};
    bool aScaleTrans{false};
    bool bScaleTrans{false};

    int64_t kStep{0};   // presumably a K-dimension step/tile size; 0 = unset (TODO confirm)
    uint64_t scale{0};  // scalar scale value; semantics defined by the implementation
    int relu{0};        // presumably a ReLU mode/flag; 0 = disabled (TODO confirm)

    // Optional auxiliary inputs; nullptr is the default and presumably means "not supplied".
    const TensorData* scalePtr{nullptr};
    const TensorData* biasPtr{nullptr};
    const TensorData* aScalePtr{nullptr};
    const TensorData* bScalePtr{nullptr};
};
/**
 * Relational operator selector for the Compare / Cmps calculator entries.
 *
 * Values are spelled out explicitly because the enum travels by value through
 * the CalcOps function table. Keep both the values and the (default, int)
 * underlying type stable. The names follow the usual EQ/NE/LT/LE/GT/GE
 * convention; the comparison each one performs is implemented by the backend.
 */
enum class CmpOperationType {
    EQ = 0,  // presumably ==
    NE = 1,  // presumably !=
    LT = 2,  // presumably <
    LE = 3,  // presumably <=
    GT = 4,  // presumably >
    GE = 5,  // presumably >=
};
/**
 * Output-format selector for the Compare / Cmps calculator entries.
 *
 * Values are explicit and must stay stable, since the enum is passed by value
 * through the CalcOps function table. NOTE(review): presumably BOOL writes one
 * boolean per element and BIT writes a packed bitmask. Confirm in the backend.
 */
enum class CmpModeType {
    BOOL = 0,
    BIT = 1,
};
/**
 * Table of function pointers implementing the calculator operations; a
 * pointer to it is returned by GetCalcOps().
 *
 * Every entry defaults to nullptr. Without these initializers a
 * default-initialized `CalcOps ops;` held indeterminate pointers, and even a
 * `if (ops.X)` guard was undefined behavior. Positional or designated
 * aggregate initialization still works unchanged.
 *
 * Member order and signatures are part of the interface. The table is an
 * aggregate handed out through a C-linkage accessor, so reordering or
 * retyping an entry changes the layout seen by any separately compiled
 * implementation. Append new entries instead of reordering.
 *
 * NOTE(review): the argument roles (which TensorData is the output, meaning of
 * the trailing int/bool/Element parameters) are defined by the implementation
 * and are not documented here.
 */
struct CalcOps {
    void (*Random)(const TensorData&) = nullptr;
    bool (*AllClose)(const TensorData&, const TensorData&, double, double) = nullptr;
    void (*Cast)(const TensorData&, const TensorData&, CastMode) = nullptr;
    void (*Exp)(const TensorData&, const TensorData&) = nullptr;
    void (*Exp2)(const TensorData&, const TensorData&) = nullptr;
    void (*Expm1)(const TensorData&, const TensorData&) = nullptr;
    void (*Sin)(const TensorData&, const TensorData&) = nullptr;
    void (*Cos)(const TensorData&, const TensorData&) = nullptr;
    void (*Erf)(const TensorData&, const TensorData&) = nullptr;
    void (*Sinh)(const TensorData&, const TensorData&) = nullptr;
    void (*Cosh)(const TensorData&, const TensorData&) = nullptr;
    void (*Erfc)(const TensorData&, const TensorData&) = nullptr;
    void (*Asin)(const TensorData&, const TensorData&) = nullptr;
    void (*Acos)(const TensorData&, const TensorData&) = nullptr;
    void (*ASinh)(const TensorData&, const TensorData&) = nullptr;
    void (*ACosh)(const TensorData&, const TensorData&) = nullptr;
    void (*Atanh)(const TensorData&, const TensorData&) = nullptr;
    void (*Neg)(const TensorData&, const TensorData&) = nullptr;
    void (*Rsqrt)(const TensorData&, const TensorData&) = nullptr;
    void (*Sign)(const TensorData&, const TensorData&) = nullptr;
    void (*Signbit)(const TensorData&, const TensorData&) = nullptr;
    void (*Tanh)(const TensorData&, const TensorData&) = nullptr;
    void (*Tan)(const TensorData&, const TensorData&) = nullptr;
    void (*Sqrt)(const TensorData&, const TensorData&) = nullptr;
    void (*Ceil)(const TensorData&, const TensorData&) = nullptr;
    void (*Floor)(const TensorData&, const TensorData&) = nullptr;
    void (*Trunc)(const TensorData&, const TensorData&) = nullptr;
    void (*Round)(const TensorData&, const TensorData&, int) = nullptr;
    void (*Reciprocal)(const TensorData&, const TensorData&) = nullptr;
    void (*Relu)(const TensorData&, const TensorData&) = nullptr;
    void (*Log1p)(const TensorData&, const TensorData&) = nullptr;
    void (*Pad)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*FillPad)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*BitwiseNot)(const TensorData&, const TensorData&) = nullptr;
    void (*Abs)(const TensorData&, const TensorData&) = nullptr;
    void (*Brcb)(const TensorData&, const TensorData&) = nullptr;
    void (*WhereTT)(const TensorData&, const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*WhereTS)(const TensorData&, const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*WhereST)(const TensorData&, const TensorData&, const Element&, const TensorData&) = nullptr;
    void (*WhereSS)(const TensorData&, const TensorData&, const Element&, const Element&) = nullptr;
    void (*LReLU)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*Ln)(const TensorData&, const TensorData&) = nullptr;
    void (*IsFinite)(const TensorData&, const TensorData&) = nullptr;
    void (*LogicalNot)(const TensorData&, const TensorData&) = nullptr;
    void (*Range)(const TensorData&, const Element&, const Element&, const Element&) = nullptr;
    void (*Compare)(const TensorData&, const TensorData&, const TensorData&, CmpOperationType, CmpModeType) = nullptr;
    void (*Cmps)(const TensorData&, const TensorData&, const Element&, CmpOperationType, CmpModeType) = nullptr;
    void (*Hypot)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PReLU)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*LogicalAnd)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Uniform)(const TensorData&, const Element&, const Element&, const Element&, const Element&, DataType) = nullptr;
    void (*AddS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*SubS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*MulS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*DivS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*FloorDivS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*FmodS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*RemainderS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*RemainderRS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*PowS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*BitwiseAndS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*BitwiseOrS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*BitwiseXorS)(const TensorData&, const TensorData&, const Element&, bool) = nullptr;
    void (*GcdS)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*Add)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Sub)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Mul)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Div)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*FloorDiv)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Fmod)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Remainder)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Pow)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitwiseAnd)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitwiseOr)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitwiseXor)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*ExpandExpDif)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*CopySign)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Gcd)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PairSum)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PairMax)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PairMin)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PairProd)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Min)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Max)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*MinS)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*MaxS)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*RowSumExpand)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMinExpand)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMaxExpand)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowSumSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMinSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMaxSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowProdSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMinLine)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowMaxLine)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowProdLine)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMaxSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMinSingle)(const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMaxWithValueSingle)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMinWithValueSingle)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMaxWithValueLine)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*RowArgMinWithValueLine)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*PairArgMax)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*PairArgMin)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*OneHot)(const TensorData&, const TensorData&, int) = nullptr;
    void (*ExpandS)(const TensorData&, const Element&) = nullptr;
    void (*Expand)(const TensorData&, const TensorData&) = nullptr;
    void (*GatherElements)(const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*GatherMask)(const TensorData&, const TensorData&, int) = nullptr;
    void (*IndexAdd)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int, const Element&) = nullptr;
    void (*TriU)(const TensorData&, const TensorData&, int) = nullptr;
    void (*TriL)(const TensorData&, const TensorData&, int) = nullptr;
    void (*CumSum)(const TensorData&, const TensorData&, int) = nullptr;
    void (*CumProd)(const TensorData&, const TensorData&, int) = nullptr;
    void (*IndexPut)(const TensorData&, const TensorData&, const std::vector<TensorData>&, const TensorData&, bool) = nullptr;
    void (*Atan)(const TensorData&, const TensorData&) = nullptr;
    void (*Atan2)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Reshape)(const TensorData&, const TensorData&) = nullptr;
    void (*Permute)(const TensorData&, const TensorData&, const std::vector<int64_t>&) = nullptr;
    void (*Transpose)(const TensorData&, const TensorData&, int64_t, int64_t) = nullptr;
    void (*ReduceAcc)(const TensorData&, const std::vector<TensorData>&) = nullptr;
    void (*Copy)(const TensorData&, const TensorData&, bool, bool) = nullptr;
    // NOTE(review): takes std::string by value; kept as-is because changing the
    // parameter type would change this entry's signature for existing implementations.
    void (*ScatterUpdate)(
        const TensorData&, const TensorData&, const TensorData&, const TensorData&, int, std::string, int) = nullptr;
    void (*ScatterElement)(const TensorData&, const TensorData&, const TensorData&, const Element&, int, int) = nullptr;
    void (*Scatter)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int, int) = nullptr;
    void (*FormatND2NZ)(const TensorData&, const TensorData&) = nullptr;
    void (*FormatNZ2ND)(const TensorData&, const TensorData&) = nullptr;
    void (*QuantPreCompute)(const TensorData&, const TensorData&, const TensorData*, uint64_t, int) = nullptr;
    void (*MatMul)(const TensorData&, const TensorData&, const TensorData&, const TensorData*, MatMulParam&) = nullptr;
    void (*Quantize)(const TensorData&, const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*Dequantize)(const TensorData&, const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitSort)(const TensorData&, const TensorData&, int64_t, bool, int64_t) = nullptr;
    void (*TiledMrgSort)(
        const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&, int, int) = nullptr;
    void (*Extract)(const TensorData&, const TensorData&, int, bool) = nullptr;
    void (*MrgSort)(const TensorData&, const TensorData&, int64_t, int64_t) = nullptr;
    void (*TopK)(const TensorData&, const TensorData&, const TensorData&, int, int, bool) = nullptr;
    void (*QuantMX)(
        const TensorData&, const TensorData&, const TensorData&, const TensorData&, const TensorData&, bool, int64_t) = nullptr;
    void (*TopkSort)(const TensorData&, const TensorData&, const TensorData&, int) = nullptr;
    void (*TopkMerge)(const TensorData&, const TensorData&, int) = nullptr;
    void (*TopkExtract)(const TensorData&, const TensorData&, int, bool) = nullptr;
    void (*TwoTileMrgSort)(const TensorData&, const TensorData&) = nullptr;
    void (*Sort)(const TensorData&, const TensorData&, const TensorData&, int64_t, bool) = nullptr;
    void (*Gather)(const TensorData&, const TensorData&, const TensorData&, int64_t) = nullptr;
    void (*GatherINUB)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int64_t, int64_t) = nullptr;
    void (*GatherInL1)(const TensorData&, const TensorData&, const TensorData&, const TensorData&, int64_t) = nullptr;
    void (*BitwiseRightShift)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitwiseLeftShift)(const TensorData&, const TensorData&, const TensorData&) = nullptr;
    void (*BitwiseRightShiftS)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*BitwiseLeftShiftS)(const TensorData&, const TensorData&, const Element&) = nullptr;
    void (*SBitwiseRightShift)(const TensorData&, const Element&, const TensorData&) = nullptr;
    void (*SBitwiseLeftShift)(const TensorData&, const Element&, const TensorData&) = nullptr;
};
extern "C" struct CalcOps* GetCalcOps();
}