* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tilefwk_op.h
* \brief
*/
#pragma once
#include <array>
#include <sstream>
#include "tilefwk/tensor.h"
#include "tilefwk/element.h"
namespace npu::tile_fwk {
class Function;
class SymbolicScalar;
class Element;
constexpr const int TILE_VEC_DIMS = 2;
constexpr const int TILE_CUBE_DIMS = 6;
enum class OpType {
EQ,
NE,
LT,
LE,
GT,
GE,
};
enum class OutType {
BOOL,
BIT,
};
enum class TopKAlgo {
MERGE_SORT,
RADIX_SELECT,
};
enum class AtomicRMWMode { ADD, MAX, MIN };
enum class DequantScaleRoundingMode : int64_t {
ROUND_UP = 0,
ROUND_DOWN = 1,
};
enum class SaturationMode : uint8_t {
ON = 0,
OFF = 1,
};
enum class PrecisionType : uint8_t {
INTRINSIC = 0,
HIGH_PRECISION = 1
};
namespace experimental {
struct PrintHelper {
SymbolicScalar cond;
std::vector<Tensor> tensors;
std::vector<SymbolicScalar> scalars;
std::stringstream ss;
template <typename T>
void Append(T t)
{
if constexpr (std::is_same_v<T, Tensor>) {
tensors.push_back(t);
ss << "{T}";
} else if constexpr (std::is_same_v<T, SymbolicScalar>) {
scalars.push_back(t);
ss << "{S}";
} else {
ss << t;
}
}
};
void Print(
SymbolicScalar cond, const std::string& format, const std::vector<Tensor>& tensors,
const std::vector<SymbolicScalar>& scalars);
template <bool isB, bool isTrans>
Tensor GatherInL1(const Tensor& src, const Tensor& offsets, const Tensor& blockTable, int blockSize, int size);
Tensor GatherInUB(const Tensor& params, const Tensor& indices, const Tensor& blockTable, int blockSize, int axis);
}
template <typename... Args>
void Print(Args... args)
{
experimental::PrintHelper helper;
(helper.Append(args), ...);
experimental::Print(1, helper.ss.str(), helper.tensors, helper.scalars);
}
template <typename... Args>
void PrintIf(SymbolicScalar cond, Args... args)
{
experimental::PrintHelper helper;
(helper.Append(args), ...);
experimental::Print(cond, helper.ss.str(), helper.tensors, helper.scalars);
}
* \brief Dump a tensor to file
*
* \param cond Dump the tensor only `cond` evaluate result is none zero
* \param operand tensor to dump
* \param fname filename, {S} can be used as scalar placeholder
* \param scalars scalars to dump
*/
void ToFile(
const Tensor& operand, const std::string& fname, const std::vector<SymbolicScalar>& scalars = {},
SymbolicScalar cond = 1);
Tensor View(const Tensor& operand, const std::vector<int64_t>& shapes, const std::vector<int64_t>& offsets);
Tensor View(const Tensor& operand, const DataType dstDataType);
Tensor View(const Tensor& operand, const std::vector<int64_t>& shapes, const std::vector<SymbolicScalar>& newOffsets);
Tensor View(
const Tensor& operand, const std::vector<int64_t>& shapes, const std::initializer_list<SymbolicScalar>& newOffsets);
Tensor View(
const Tensor& operand, const std::vector<int64_t>& shapes, const std::vector<SymbolicScalar>& newValidShapes,
const std::vector<SymbolicScalar>& newOffsets);
Tensor Assemble(const std::vector<std::pair<Tensor, std::vector<int64_t>>>& tensors);
void Assemble(const Tensor& tensor, const std::vector<SymbolicScalar>& dynOffset, Tensor& dest);
struct AssembleItem {
Tensor tensor;
std::vector<SymbolicScalar> offsets;
};
void Assemble(const std::vector<AssembleItem>& items, Tensor& src, bool parallelInAssemble = false);
void AtomicRMW(const Tensor& tensor, const std::vector<SymbolicScalar>& dynOffset, Tensor& dest, AtomicRMWMode mode);
Tensor Reshape(
const Tensor& operand, const std::vector<int64_t>& dstshape, const std::vector<SymbolicScalar>& validShape = {},
const bool inplace = false);
Tensor Reshape(
const Tensor& operand, const std::initializer_list<int64_t>& dstshape,
const std::initializer_list<SymbolicScalar>& validShape = {}, const bool inplace = false);
Tensor Reshape(const Tensor& operand, const std::vector<SymbolicScalar>& dstShape, const bool inplace);
void Reshape(const Tensor& operand, Tensor& dst);
Tensor Full(
const Element& src, DataType dtype, const std::vector<int64_t>& dstShape,
std::vector<SymbolicScalar> validShape = {});
Tensor Full(
const SymbolicScalar& src, DataType dtype, const std::vector<int64_t>& dstShape,
std::vector<SymbolicScalar> validShape = {});
Tensor Transpose(const Tensor& self, std::vector<int> perm);
Tensor TransData(const Tensor& self, TileOpFormat transDataType, int group = 1);
std::shared_ptr<LogicalTensor> TransData(
Function& function, const std::shared_ptr<LogicalTensor>& self, const std::shared_ptr<LogicalTensor>& fakeDstTensor, TileOpFormat transDataType, int group = 1);
Tensor Cast(
const Tensor& self, DataType dstDataType, CastMode mode = CAST_NONE, SaturationMode satmode = SaturationMode::OFF);
Tensor Permute(const Tensor& self, std::vector<int> perm);
Tensor Permute(Function& function, const Tensor& self, std::vector<int> perm);
Tensor Exp(const Tensor& self, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Exp2(const Tensor& self);
Tensor Expm1(const Tensor& self);
Tensor Erfc(const Tensor& self);
Tensor Atan(const Tensor& self);
Tensor Sin(const Tensor& self);
Tensor Cos(const Tensor& self);
Tensor Erf(const Tensor& self);
Tensor Neg(const Tensor& self);
Tensor Round(const Tensor& self, const int& decimals = 0);
Tensor Rsqrt(const Tensor& self, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Relu(const Tensor& self);
Tensor Pad(
const Tensor& self, const std::vector<int64_t>& padding, std::string mode = "constant",
const Element& value = Element(DT_FP32, 0.0));
Tensor FillPad(const Tensor& self, std::string mode = "constant", const Element& value = Element(DT_FP32, 0.0));
Tensor BitwiseNot(const Tensor& self);
Tensor Sqrt(const Tensor& self, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Ceil(const Tensor& self);
Tensor CeilDiv(const Tensor& self, const Tensor& other);
Tensor CeilDiv(const Tensor& self, const Element& other);
Tensor Floor(const Tensor& self);
Tensor FloorDiv(const Tensor& self, const Tensor& other);
Tensor FloorDiv(const Tensor& self, const Element& other);
Tensor Trunc(const Tensor& self);
Tensor Reciprocal(const Tensor& operand, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Abs(const Tensor& self);
Tensor Ln(const Tensor& operand, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Hub(const Tensor& operand);
Tensor Sign(const Tensor& operand);
Tensor Tan(const Tensor& operand);
Tensor Signbit(const Tensor& operand);
Tensor Sinh(const Tensor& self);
Tensor Cosh(const Tensor& self);
Tensor Tanh(const Tensor& operand);
Tensor Asin(const Tensor& self);
Tensor Acos(const Tensor& self);
Tensor ASinh(const Tensor& self);
Tensor ACosh(const Tensor& self);
Tensor Atanh(const Tensor& self);
Tensor Duplicate(const Tensor& operand);
Tensor Gather(const Tensor& params, const Tensor& indices, int axis);
Tensor GatherElements(const Tensor& params, const Tensor& indices, int axis);
Tensor GatherMask(const Tensor& self, const uint8_t patternMode);
enum class ScatterMode {
NONE,
ADD,
MULTIPLY,
UNKNOWN,
};
* \brief Write the scalar value of src into self Tensor, with the write position specified by the indices Tensor.
*
* \param self : Tensor to write into.
* \param indices : the index Tensor of element to be dispersed.
* \param src : scalar value or tensor to be dispersed.
* \param axis : axis to be indexed.
* \param reduce : scatter reduction mode to be applied. Support NONE, ADD, MULTIPLY. NONE is default.
* \return Tensor
*/
Tensor Scatter(
const Tensor& self, const Tensor& indices, const Element& src, int axis, ScatterMode reduce = ScatterMode::NONE);
Tensor Scatter(
const Tensor& self, const Tensor& indices, const Tensor& src, int axis, ScatterMode reduce = ScatterMode::NONE);
void IndexPut_(Tensor& self, const std::vector<Tensor>& indices, const Tensor& values, bool accumulate = false);
Tensor IndexAddUB(
const Tensor& self, const Tensor& src, const Tensor& indices, int axis,
const Element& alpha = Element(DT_FP32, 1.0f));
void IndexAdd_(
Tensor& self, const Tensor& src, const Tensor& indices, int axis, const Element& alpha = Element(DT_FP32, 1.0f));
Tensor RowSumExpand(const Tensor& operand);
Tensor RowMaxExpand(const Tensor& operand);
Tensor Sum(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor Amax(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor ArgMax(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor ArgMin(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor Amin(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor Prod(const Tensor& self, int axis = -1, bool keepDim = false);
Tensor Compact(const Tensor& operand);
Tensor Add(const Tensor& self, const Tensor& other);
Tensor Sub(const Tensor& self, const Tensor& other);
Tensor Div(const Tensor& self, const Tensor& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Mul(const Tensor& self, const Tensor& other);
Tensor Hypot(const Tensor& self, const Tensor& other);
Tensor Fmod(const Tensor& self, const Tensor& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Maximum(const Tensor& operand1, const Tensor& operand2);
Tensor Minimum(const Tensor& operand1, const Tensor& operand2);
Tensor BitwiseAnd(const Tensor& self, const Tensor& other);
Tensor BitwiseOr(const Tensor& self, const Tensor& other);
Tensor BitwiseXor(const Tensor& self, const Tensor& other);
Tensor ExpandExpDif(const Tensor& input, const Tensor& other);
Tensor Add(const Tensor& self, const Element& other);
Tensor Sub(const Tensor& self, const Element& other);
Tensor Div(const Tensor& self, const Element& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Mul(const Tensor& self, const Element& other);
Tensor Fmod(const Tensor& self, const Element& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor BitwiseAnd(const Tensor& self, const Element& other);
Tensor BitwiseOr(const Tensor& self, const Element& other);
Tensor BitwiseXor(const Tensor& self, const Element& other);
Tensor Minimum(const Tensor& operand1, const Element& operand2);
Tensor Maximum(const Tensor& operand1, const Element& operand2);
Tensor Compare(const Tensor& self, const Tensor& other, OpType op, OutType mode);
Tensor Compare(const Tensor& self, const Element& other, OpType op, OutType mode);
Tensor Compare(const Element& self, const Tensor& other, OpType op, OutType mode);
Tensor Pow(const Tensor& self, const Tensor& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Pow(const Tensor& self, const Element& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Remainder(const Tensor& self, const Tensor& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Remainder(const Tensor& self, const Element& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Remainder(const Element& self, const Tensor& other, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor CopySign(const Tensor& self, const Tensor& other);
Tensor PReLU(const Tensor& self, const Tensor& weight);
Tensor Atan2(const Tensor& y, const Tensor& x);
Tensor Axpy(const Tensor& self, const Tensor& other, float alpha);
Tensor BitwiseRightShift(const Tensor& self, const Tensor& other);
Tensor BitwiseRightShift(const Tensor& self, const Element& other);
Tensor BitwiseRightShift(const Element& self, const Tensor& other);
Tensor BitwiseLeftShift(const Tensor& self, const Tensor& other);
Tensor BitwiseLeftShift(const Tensor& self, const Element& other);
Tensor BitwiseLeftShift(const Element& self, const Tensor& other);
Tensor Where(const Tensor& condition, const Tensor& input, const Tensor& other);
Tensor Where(const Tensor& condition, const Tensor& input, const Element& other);
Tensor Where(const Tensor& condition, const Element& input, const Tensor& other);
Tensor Where(const Tensor& condition, const Element& input, const Element& other);
Tensor LReLU(const Tensor& self, const Element& negative_slope);
Tensor Unsqueeze(const Tensor& old, int unsqueezeDimNum);
Tensor Squeeze(const Tensor& input, const std::vector<int>& dim = {});
Tensor TensorIndex(const Tensor& params, const Tensor& indices);
Tensor ScatterUpdate(
const Tensor& dst, const Tensor& index, const Tensor& src, int axis = -2, std::string cacheMode = "PA_BNSD",
int chunkSize = 1);
Tensor Expand(const Tensor& self, const std::vector<int64_t>& dstShape, std::vector<SymbolicScalar> validShape = {});
Tensor Var(const Tensor& input, const std::vector<int>& dim = {}, float correction = 1.0f, bool keepDim = false);
Tensor Softmax(const Tensor& operand);
Tensor RmsNorm(const Tensor& operand);
Tensor RmsNorm(const Tensor& operand, const Tensor& gamma, float epsilon = 1e-05f);
Tensor Cat(const std::vector<Tensor>& tensors, int axis);
Tensor NewCompact(const Tensor& operand);
Tensor LogicalNot(const Tensor& self);
Tensor Range(const Element& start, const Element& end, const Element& step);
Tensor LogicalAnd(const Tensor& self, const Tensor& other);
Tensor IsFinite(const Tensor& self);
Tensor Assign(const Tensor& operand);
Tensor Uniform(
const Element& key, const SymbolicScalar& counter0, const Element& counter1, const std::vector<int64_t>& shape,
const Element& rounds, DataType dtype = DT_FP32);
Tensor Clip(const Tensor& self, const Tensor& min = {}, const Tensor& max = {});
Tensor Clip(const Tensor& self, const Element& min = {}, const Element& max = {});
std::tuple<Tensor, Tensor> TopK(
const Tensor& self, int k, int axis = -1, bool isLargest = true, TopKAlgo algo = TopKAlgo::MERGE_SORT);
Tensor ArgSort(const Tensor& self, int axis = -1, bool descending = false);
Tensor Sort32(const Tensor& self, int idxStart = 0);
Tensor MrgSort(const Tensor& self, int mergeSize);
Tensor Quantize(const Tensor& input, const Tensor& scale, DataType dtype, int axis, const Tensor& zeroPoints);
Tensor Dequantize(const Tensor& input, const Tensor& scale, DataType otype, int axis, const Tensor& zeroPoints);
* @brief Sort a tensor with shape (1, n) along the last dimension, n must be orders of 2.
* The vecTile (1, t), t must be orders of 2, maximum is 8K.
* @param x The input tensor to be sorted, the indices are initialized to 0123...
* @param descending If true, sorts in descending order; otherwise ascending order (default: true).
* @return std::tuple<Tensor, Tensor> A tuple containing two tensors:
* - First tensor: The sorted data.
* - Second tensor: The corresponding indices.
*/
std::tuple<Tensor, Tensor> Sort(const Tensor& x, bool descending = true);
* @brief Sort a tensor & indices with shape (1, n) along the last dimension, n must be orders of 2.
* The vecTile (1, t), t must be orders of 2, maximum is 8K.
* @param x The input tensor to be sorted.
* @param idx The input indices corresponding to x.
* @param descending If true, sorts in descending order; otherwise ascending order (default: true).
* @return std::tuple<Tensor, Tensor> A tuple containing two tensors:
* - First tensor: The sorted data.
* - Second tensor: The corresponding indices.
*/
std::tuple<Tensor, Tensor> SortWithIndex(const Tensor& x, const Tensor& idx, bool descending = true);
Tensor SoftmaxNew(const Tensor& operand);
void SoftmaxDynamic(Tensor& input, Tensor& output);
Tensor RotateHalf(const Tensor& input);
Tensor Sigmoid(Tensor& input);
std::tuple<Tensor, Tensor> Quant(
const Tensor& input, bool isSymmetry = true, bool hasSmoothFactor = false, const Tensor& smoothFactor = Tensor());
std::tuple<Tensor, Tensor> QuantMX(
const Tensor& input, DataType quantDtype = DataType::DT_FP8E4M3,
DequantScaleRoundingMode mode = DequantScaleRoundingMode::ROUND_DOWN, int64_t axis = -1,
bool performanceMode = true);
Tensor ScalarDivS(const Tensor& operand, const Element& value, bool reverseOperand = false);
Tensor ScalarAddS(const Tensor& operand, const Element& value, bool reverseOperand = false);
Tensor ScalarMaxS(const Tensor& operand, const Element& value, bool reverseOperand = false);
Tensor ScalarSubS(const Tensor& operand, const Element& value, bool reverseOperand = false);
Tensor ScalarMulS(const Tensor& operand, const Element& value, bool reverseOperand = false);
Tensor ScalarSub(const Tensor& operand1, const Tensor& operand2);
Tensor ScalarDiv(const Tensor& operand1, const Tensor& operand2);
Tensor CumSum(const Tensor& input, const int& axis);
Tensor CumProd(const Tensor& input, const int& axis);
Tensor Gcd(const Tensor& input, const Tensor& other);
Tensor Gcd(const Tensor& input, const Element& other);
Tensor TriU(const Tensor& input, const SymbolicScalar& diagonal);
Tensor TriL(const Tensor& input, const SymbolicScalar& diagonal);
struct PaTileShapeConfig {
int headNumQTile;
std::array<int, TILE_VEC_DIMS> v0TileShape;
std::array<int, TILE_CUBE_DIMS> c1TileShape;
std::array<int, TILE_VEC_DIMS> v1TileShape;
std::array<int, TILE_CUBE_DIMS> c2TileShape;
std::array<int, TILE_VEC_DIMS> v2TileShape;
};
enum class ReduceMode {
ATOMIC_ADD,
};
Tensor Reduce(const std::vector<Tensor>& aggregation, const ReduceMode reduceMode);
Tensor Maxpool(
const Tensor& operand, const std::vector<int>& pools, const std::vector<int>& strides,
const std::vector<int>& paddings);
enum class LogBaseType {
LOG_E,
LOG_2,
LOG_10,
};
Tensor Log(
const Tensor& self, LogBaseType base = LogBaseType::LOG_E, PrecisionType precisionType = PrecisionType::INTRINSIC);
Tensor Log1p(const Tensor& self);
Tensor OneHot(const Tensor& self, int numClasses);
struct IfaTileShapeConfig {
int blockSize;
int headNumQTile;
std::array<int, TILE_VEC_DIMS> v0TileShape;
std::array<int, TILE_CUBE_DIMS> c1TileShape;
std::array<int, TILE_VEC_DIMS> v1TileShape;
std::array<int, TILE_CUBE_DIMS> c2TileShape;
std::array<int, TILE_VEC_DIMS> v2TileShape;
};
struct RoPETileShapeConfig {
std::vector<int64_t> twoDimsTileShape;
std::vector<int64_t> threeDimsTileShape;
std::vector<int64_t> fourDimsTileShape;
std::vector<int64_t> fiveDimsTileShape;
};
struct RoPETileShapeConfigNew {
std::vector<int64_t> threeDimsTileShape;
std::vector<int64_t> fourDimsTileShapeQ;
std::vector<int64_t> fourDimsTileShapeK;
std::vector<int64_t> fiveDimsTileShape;
};
void ApplyRotaryPosEmb(
const Tensor& q, const Tensor& k, const Tensor& cos, const Tensor& sin, const Tensor& positionIds, Tensor& qEmbed,
Tensor& kEmbed, const int unsqueezeDim = 1, const RoPETileShapeConfig& ropeTileShapeConfig = {});
void ApplyRotaryPosEmbV2(
const Tensor& q, const Tensor& k, const Tensor& cos, const Tensor& sin, Tensor& qEmbed, Tensor& kEmbed,
const int unsqueezeDim = 2, const RoPETileShapeConfigNew& ropeTileShapeConfig = {});
void IncreFlashAttention(
Tensor& qNope, Tensor& kNopeCache, Tensor& vNopeCache, Tensor& qRope, Tensor& kRopeCache,
std::vector<std::vector<int>>& blockTable, std::vector<int>& actSeqs, float softmaxScale, Tensor& attentionOut,
IfaTileShapeConfig& tileConfig);
void PageAttentionAddS(
Tensor& qNope, Tensor& kNopeCache, Tensor& vNopeCache, Tensor& qRope, Tensor& kRopeCache, Tensor& blockTable,
Tensor& actSeqs, int blockSize, float softmaxScale, Tensor& attentionOut, Tensor& postOut,
PaTileShapeConfig& tileConfig, int maxUnrollTimes = 1);
void PageAttentionAddSSingleOutput(
Tensor& qNope, Tensor& kNopeCache, Tensor& vNopeCache, Tensor& qRope, Tensor& kRopeCache, Tensor& blockTable,
Tensor& actSeqs, int blockSize, float softmaxScale, Tensor& attentionOut, Tensor& postOut,
PaTileShapeConfig& tileConfig, int maxUnrollTimes = 1);
void PrologPost(
Tensor& qNope, Tensor& kNopeCache, Tensor& vNopeCache, Tensor& qRope, Tensor& kRopeCache, Tensor& blockTable,
Tensor& actSeqs, Tensor& weightUV, Tensor& weightO, int blockSize, float softmaxScale, Tensor& postOut,
PaTileShapeConfig& tileConfig);
namespace Matrix {
enum class ReLuType : int64_t { NoReLu = 0, ReLu = 1 };
enum class TransMode : int64_t { CAST_NONE = 0, CAST_RINT = 1, CAST_ROUND = 2 };
struct MatmulExtendParam {
Tensor biasTensor{Tensor()};
Tensor scaleTensor{Tensor()};
float scaleValue{0.0f};
ReLuType reluType{ReLuType::NoReLu};
TransMode transMode{TransMode::CAST_NONE};
MatmulExtendParam(Tensor bias, Tensor scale, float scaleVal, ReLuType relu, TransMode mode = TransMode::CAST_NONE)
: biasTensor(std::move(bias)),
scaleTensor(std::move(scale)),
scaleValue(scaleVal),
reluType(relu),
transMode(mode)
{}
MatmulExtendParam() = default;
};
Tensor Matmul(
DataType outType, const Tensor& aMatrix, const Tensor& bMatrix, bool isATrans = false, bool isBTrans = false,
bool isCMatrixNZ = false);
Tensor Matmul(
DataType outType, const Tensor& aMatrix, const Tensor& bMatrix, const MatmulExtendParam& extendParam,
bool isATrans = false, bool isBTrans = false, bool isCMatrixNZ = false);
Tensor MatmulMX(
DataType outType, const Tensor& aMatrix, const Tensor& aScale, const Tensor& bMatrix, const Tensor& bScale,
bool isATrans = false, bool isAScaleTrans = false, bool isBTrans = false, bool isBScaleTrans = false,
bool isCMatrixNZ = false);
Tensor MatmulMX(
DataType outType, const Tensor& aMatrix, const Tensor& aScale, const Tensor& bMatrix, const Tensor& bScale,
const MatmulExtendParam& extendParam, bool isATrans = false, bool isAScaleTrans = false, bool isBTrans = false,
bool isBScaleTrans = false, bool isCMatrixNZ = false);
Tensor BatchMatmul(
DataType dataType, const Tensor& aMatrix, const Tensor& bMatrix, bool isATrans = false, bool isBTrans = false,
bool isCMatrixNZ = false);
Tensor BatchMatmul(
DataType dataType, const Tensor& aMatrix, const Tensor& bMatrix, const MatmulExtendParam& extendParam,
bool isATrans = false, bool isBTrans = false, bool isCMatrixNZ = false);
Tensor BatchMatmulMX(
DataType dataType, const Tensor& aMatrix, const Tensor& aScale, const Tensor& bMatrix, const Tensor& bScale,
bool isTransA = false, bool isAScaleTrans = false, bool isTransB = false, bool isBScaleTrans = false,
bool isCMatrixNZ = false);
Tensor BatchMatmulMX(
DataType dataType, const Tensor& aMatrix, const Tensor& aScale, const Tensor& bMatrix, const Tensor& bScale,
const MatmulExtendParam& extendParam, bool isTransA = false, bool isAScaleTrans = false, bool isTransB = false,
bool isBScaleTrans = false, bool isCMatrixNZ = false);
Tensor TransposedBatchMatmul(DataType dataType, const Tensor& aMatrix, const Tensor& bMatrix);
Tensor QuantMM(const Tensor& operand1, const Tensor& operand2, const Tensor& dequantScaleW);
}
namespace Conv {
struct TileL1Info {
int64_t tileHin{0};
int64_t tileHout{0};
int64_t tileWin{0};
int64_t tileWout{0};
int64_t tileCinFmap{0};
int64_t tileCinWeight{0};
int64_t tileN{0};
int64_t tileBatch{0};
TileL1Info(
int64_t hin, int64_t hout, int64_t win, int64_t wout, int64_t cinFmap, int64_t cinWeight, int64_t cout,
int64_t n)
: tileHin(hin),
tileHout(hout),
tileWin(win),
tileWout(wout),
tileCinFmap(cinFmap),
tileCinWeight(cinWeight),
tileN(cout),
tileBatch(n)
{}
TileL1Info() = default;
};
struct TileL0Info {
int64_t tileH{0};
int64_t tileW{0};
int64_t tileK{0};
int64_t tileN{0};
TileL0Info(int64_t h, int64_t w, int64_t k, int64_t n) : tileH(h), tileW(w), tileK(k), tileN(n) {}
TileL0Info() = default;
};
enum class ReLuType : int64_t { NoReLu = 0, ReLu = 1 };
struct ConvExtendParam {
Tensor biasTensor{Tensor()};
Tensor scaleTensor{Tensor()};
float scaleValue{0.0f};
ReLuType reluType{ReLuType::NoReLu};
ConvExtendParam(Tensor bias, Tensor scale, float scaleVal, ReLuType relu)
: biasTensor(std::move(bias)), scaleTensor(std::move(scale)), scaleValue(scaleVal), reluType(relu)
{}
ConvExtendParam() = default;
};
Tensor Conv(
DataType outType, const Tensor& inputTensor, const Tensor& weightTensor, const std::vector<int64_t>& strides,
const std::vector<SymbolicScalar>& paddings, const std::vector<int64_t>& dilations,
const ConvExtendParam& extendParam, const int64_t groups = 1);
}
namespace FakeTrans {
Tensor FakeTrans(const Tensor& input, const Tensor& output);
}
namespace Distributed {
enum class DistReduceType {
DIST_REDUCE_ADD,
DIST_REDUCE_MAX,
DIST_REDUCE_MIN,
};
enum class AtomicType { SET, ADD };
struct MoeConfig {
int32_t routedExpertNum{0};
int32_t expertNumPerRank{0};
int32_t rankNum{0};
};
void MoeDistributedDispatch(
const Tensor& tokenTensor, const Tensor& tokenExpertTable, Tensor& expandX, Tensor& validCnt, Tensor& combineInfo,
const char* group, const MoeConfig& moeConfig);
void MoeDistributedCombine(
const Tensor& expandX, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& expertScales,
const char* group, uint32_t epWorldSize, uint32_t moeExpertNum, uint32_t sharedExpertNum,
uint32_t sharedExpertRankNum, Tensor& out);
struct ShmemTensor {
std::string group;
int64_t worldSize{0};
Tensor data;
Tensor signal;
void* signalOp{nullptr};
};
ShmemTensor CreateShmemTensor(const char* group, int64_t worldSize, DataType dataType, const Shape& shape);
void CreateShmemTensor(const char* group, int64_t worldSize, DataType dataType, const Shape& shape, ShmemTensor& t);
ShmemTensor CreateShmemSignal(const char* group, int64_t worldSize);
void CreateShmemSignal(const char* group, int64_t worldSize, ShmemTensor& t);
ShmemTensor ShmemView(
const ShmemTensor& operand, const std::vector<int64_t>& shapes, const std::vector<int64_t>& offsets);
ShmemTensor ShmemView(
const ShmemTensor& operand, const std::vector<int64_t>& shapes, const std::vector<SymbolicScalar>& offsets);
ShmemTensor ShmemView(
const ShmemTensor& operand, const std::vector<int64_t>& shapes, const std::vector<SymbolicScalar>& newValidShapes,
const std::vector<SymbolicScalar>& newOffsets);
ShmemTensor ShmemView(
const ShmemTensor& operand, const std::vector<int64_t>& shapes,
const std::initializer_list<SymbolicScalar>& newOffsets);
Tensor ShmemPut(
const Tensor& src, const ShmemTensor& dst, const SymbolicScalar& dstRank, AtomicType putOp, const Tensor& pred);
Tensor ShmemGet(
const ShmemTensor& src, const SymbolicScalar& srcRank, const Tensor& pred,
DataType targetDataType = DataType::DT_BOTTOM);
Tensor ShmemSignal(
const ShmemTensor& src, const SymbolicScalar& srcRank, const SymbolicScalar& targetRank, int32_t signal,
AtomicType sigOp, const Tensor& pred);
Tensor ShmemSignalAll(
const ShmemTensor& src, const SymbolicScalar& srcRank, int32_t signal, AtomicType sigOp, const Tensor& pred);
Tensor ShmemWaitUntil(
const ShmemTensor& src, const SymbolicScalar& srcRank, OpType cmp, int32_t cmpValue, bool clearSignal,
const Tensor& pred);
Tensor ShmemClearData(const ShmemTensor& src, Tensor& pred);
Tensor ShmemClearSignal(const ShmemTensor& src, Tensor& pred);
Tensor ShmemBarrier(const ShmemTensor& src, const Tensor& pred);
Tensor ShmemLoad(
const ShmemTensor& src, const SymbolicScalar& srcRank, const Tensor& pred,
DataType targetDataType = DataType::DT_BOTTOM);
Tensor ShmemStore(
const Tensor& src, const ShmemTensor& dst, const SymbolicScalar& dstRank, AtomicType putOp, const Tensor& pred);
void AllGather(const Tensor& predToken, const Tensor& in, ShmemTensor& shmemTensor, Tensor& out);
void ReduceScatter(
const Tensor& predToken, const Tensor& in, ShmemTensor& shmemTensor, DistReduceType reduceType, Tensor& out);
void OneShotAllReduce(const Tensor& predToken, const Tensor& in, ShmemTensor& shmemTensor, Tensor& out);
void TwoShotAllReduce(const Tensor& predToken, const Tensor& in, ShmemTensor& shmemTensor, Tensor& out);
void MoeDistributedDispatchV2(
const Tensor& x, const Tensor& expertIds, const char* group, uint32_t epWorldSize, uint32_t moeExpertNum,
uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, Tensor& expandX, Tensor& assistInfoForCombine,
Tensor& expertTokenNums, Tensor& recvCounts);
void MoeDistributedCombineV2(
const Tensor& expandX, const Tensor& assistInfoForCombine, const Tensor& recvCounts, const Tensor& expertScales,
const char* group, uint32_t epWorldSize, uint32_t moeExpertNum, uint32_t sharedExpertNum,
uint32_t sharedExpertRankNum, Tensor& out);
}
std::tuple<Tensor, Tensor> TopKSort(const Tensor& x, int idxStart);
std::tuple<Tensor, Tensor> TopKSort(const Tensor& x, const SymbolicScalar& idxStart);
Tensor TopKExtract(const Tensor& x, int k, bool isIndex);
Tensor TopKMerge(const Tensor& x, int mergeSize);
Tensor Nop(const std::vector<Tensor>& inTensors);
}