* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <string>
#include <vector>
#include <sstream>
#include <nlohmann/json.hpp>
#include "atbops/params/params.h"
#include "mki/utils/SVector/SVector.h"
#include "mki/utils/stringify/stringify.h"
#include "mki/utils/any/any.h"
#include "mki/utils/log/log.h"
using namespace Mki;
namespace AtbOps {
template <typename T> std::vector<T> SVectorToVector(const SVector<T> &svector)
{
std::vector<T> tmpVec;
tmpVec.resize(svector.size());
for (size_t i = 0; i < svector.size(); i++) {
tmpVec.at(i) = svector.at(i);
}
return tmpVec;
};
std::string FastSoftMaxGradToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::FastSoftMaxGrad specificParam = AnyCast<OpParam::FastSoftMaxGrad>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["headNum"] = specificParam.headNum;
return paramsJson.dump();
}
std::string FastSoftMaxToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::FastSoftMax specificParam = AnyCast<OpParam::FastSoftMax>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["headNum"] = specificParam.headNum;
return paramsJson.dump();
}
std::string FusionToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::Fusion specificParam = AnyCast<OpParam::Fusion>(param);
paramsJson["fusion"] = specificParam.fusionType;
return paramsJson.dump();
}
std::string GatingToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::Gating specificParam = AnyCast<OpParam::Gating>(param);
paramsJson["headSize"] = specificParam.headSize;
paramsJson["headNum"] = specificParam.headNum;
return paramsJson.dump();
}
std::string GenAttentionMaskToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::GenAttentionMask specificParam = AnyCast<OpParam::GenAttentionMask>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["headNum"] = specificParam.headNum;
return paramsJson.dump();
}
std::string KVCacheToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::KVCache specificParam = AnyCast<OpParam::KVCache>(param);
paramsJson["type"] = specificParam.type;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["batchRunStatus"] = specificParam.batchRunStatus;
paramsJson["seqLen"] = specificParam.seqLen;
paramsJson["tokenOffset"] = specificParam.tokenOffset;
return paramsJson.dump();
}
std::string PadWithHiddenStateToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::PadWithHiddenState specificParam = AnyCast<OpParam::PadWithHiddenState>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["maxSeqLen"] = specificParam.maxSeqLen;
return paramsJson.dump();
}
std::string PadToJson(const Any ¶m)
{
(void)param;
return "{}";
}
std::string RopeQConcatToJson(const Any ¶m)
{
(void)param;
return "{}";
}
std::string SwigluQuantToJson(const Any ¶m)
{
(void)param;
return "{}";
}
std::string PagedAttentionToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::PagedAttention specificParam = AnyCast<OpParam::PagedAttention>(param);
paramsJson["type"] = specificParam.type;
paramsJson["isTriuMask"] = specificParam.isTriuMask;
paramsJson["identityM"] = specificParam.identityM;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["tor"] = specificParam.tor;
paramsJson["kvHead"] = specificParam.kvHead;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["maskType"] = specificParam.maskType;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["batchRunStatus"] = specificParam.batchRunStatus;
paramsJson["compressHead"] = specificParam.compressHead;
paramsJson["quantType"] = specificParam.quantType;
paramsJson["outDataType"] = specificParam.outDataType;
paramsJson["scaleType"] = specificParam.scaleType;
paramsJson["headDimV"] = specificParam.headDimV;
return paramsJson.dump();
}
std::string MLAToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::MLA specificParam = AnyCast<OpParam::MLA>(param);
paramsJson["type"] = specificParam.type;
paramsJson["tor"] = specificParam.tor;
paramsJson["kvHead"] = specificParam.kvHead;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["isRing"] = specificParam.isRing;
paramsJson["windowSize"] = specificParam.windowSize;
return paramsJson.dump();
}
std::string ReshapeAndCacheToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::ReshapeAndCache specificParam = AnyCast<OpParam::ReshapeAndCache>(param);
paramsJson["type"] = specificParam.type;
return paramsJson.dump();
}
std::string PagedCacheLoadToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::PagedCacheLoad specificParam = AnyCast<OpParam::PagedCacheLoad>(param);
paramsJson["type"] = specificParam.type;
paramsJson["cuSeqLens"] = specificParam.cuSeqLens;
paramsJson["hasSeqStarts"] = specificParam.hasSeqStarts;
return paramsJson.dump();
}
std::string RopeGradToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::RopeGrad specificParam = AnyCast<OpParam::RopeGrad>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
return paramsJson.dump();
}
std::string BlockCopyToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::BlockCopy specificParam = AnyCast<OpParam::BlockCopy>(param);
paramsJson["type"] = specificParam.type;
return paramsJson.dump();
}
std::string RopeToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::Rope specificParam = AnyCast<OpParam::Rope>(param);
paramsJson["rotaryCoeff"] = specificParam.rotaryCoeff;
paramsJson["cosFormat"] = specificParam.cosFormat;
return paramsJson.dump();
}
std::string StridedBatchMatmulToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::StridedBatchMatmul specificParam = AnyCast<OpParam::StridedBatchMatmul>(param);
paramsJson["headNum"] = specificParam.headNum;
paramsJson["batch"] = specificParam.batch;
paramsJson["transA"] = specificParam.transA;
paramsJson["transB"] = specificParam.transB;
paramsJson["m"] = specificParam.m;
paramsJson["k"] = specificParam.k;
paramsJson["v"] = specificParam.n;
paramsJson["lda"] = specificParam.lda;
paramsJson["ldb"] = specificParam.ldb;
paramsJson["ldc"] = specificParam.ldc;
paramsJson["strideA"] = specificParam.strideA;
paramsJson["strideB"] = specificParam.strideB;
paramsJson["strideC"] = specificParam.strideC;
return paramsJson.dump();
}
std::string ToppsampleToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::Toppsample specificParam = AnyCast<OpParam::Toppsample>(param);
paramsJson["randSeed"] = specificParam.randSeed;
return paramsJson.dump();
}
std::string UnpadFlashAttentionNzToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::UnpadFlashAttentionNz specificParam = AnyCast<OpParam::UnpadFlashAttentionNz>(param);
paramsJson["type"] = specificParam.type;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["tor"] = specificParam.tor;
paramsJson["kvHead"] = specificParam.kvHead;
paramsJson["batchRunStatus"] = specificParam.batchRunStatus;
paramsJson["isTriuMask"] = specificParam.isTriuMask;
paramsJson["maskType"] = specificParam.maskType;
paramsJson["alibiLeftAlign"] = specificParam.alibiLeftAlign;
paramsJson["isAlibiMaskSqrt"] = specificParam.isAlibiMaskSqrt;
paramsJson["compressHead"] = specificParam.compressHead;
paramsJson["dataDimOrder"] = specificParam.dataDimOrder;
paramsJson["scaleType"] = specificParam.scaleType;
paramsJson["windowSize"] = specificParam.windowSize;
paramsJson["cacheType"] = specificParam.cacheType;
paramsJson["precType"] = specificParam.precType;
std::stringstream ss;
for (size_t i = 0; i < specificParam.kTensorList.size(); ++i) {
ss << "\nkTensorList[" << i << "]: " << specificParam.kTensorList.at(i).ToString();
}
for (size_t i = 0; i < specificParam.vTensorList.size(); ++i) {
ss << "\nvTensorList[" << i << "]: " << specificParam.vTensorList.at(i).ToString();
}
return paramsJson.dump() + ss.str();
}
std::string UnpadFlashAttentionToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::UnpadFlashAttention specificParam = AnyCast<OpParam::UnpadFlashAttention>(param);
paramsJson["type"] = specificParam.type;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["tor"] = specificParam.tor;
paramsJson["kvHead"] = specificParam.kvHead;
paramsJson["batchRunStatus"] = specificParam.batchRunStatus;
paramsJson["isClamp"] = specificParam.isClamp;
paramsJson["clampMin"] = specificParam.clampMin;
paramsJson["clampMax"] = specificParam.clampMax;
paramsJson["maskType"] = specificParam.maskType;
paramsJson["alibiLeftAlign"] = specificParam.alibiLeftAlign;
paramsJson["isAlibiMaskSqrt"] = specificParam.isAlibiMaskSqrt;
paramsJson["compressHead"] = specificParam.compressHead;
paramsJson["isTriuMask"] = specificParam.isTriuMask;
paramsJson["quantType"] = specificParam.quantType;
paramsJson["outDataType"] = specificParam.outDataType;
paramsJson["scaleType"] = specificParam.scaleType;
paramsJson["windowSize"] = specificParam.windowSize;
paramsJson["cacheType"] = specificParam.cacheType;
paramsJson["headDimV"] = specificParam.headDimV;
paramsJson["kvShareMap"] = specificParam.kvShareMap;
paramsJson["kvShareLen"] = specificParam.kvShareLen;
std::stringstream ss;
for (size_t i = 0; i < specificParam.kTensorList.size(); ++i) {
ss << "\nkTensorList[" << i << "]: " << specificParam.kTensorList.at(i).ToString();
}
for (size_t i = 0; i < specificParam.vTensorList.size(); ++i) {
ss << "\nvTensorList[" << i << "]: " << specificParam.vTensorList.at(i).ToString();
}
for (size_t i = 0; i < specificParam.kShareTensorList.size(); ++i) {
ss << "\nkShareTensorList[" << i << "]: " << specificParam.kShareTensorList.at(i).ToString();
}
for (size_t i = 0; i < specificParam.vShareTensorList.size(); ++i) {
ss << "\nvShareTensorList[" << i << "]: " << specificParam.vShareTensorList.at(i).ToString();
}
return paramsJson.dump() + ss.str();
}
std::string UnpadWithHiddenStateToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::UnpadWithHiddenState specificParam = AnyCast<OpParam::UnpadWithHiddenState>(param);
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["maxSeqLen"] = specificParam.maxSeqLen;
return paramsJson.dump();
}
std::string UnpadToJson(const Any ¶m)
{
(void)param;
return "{}";
}
std::string RINGMLAToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::RINGMLA specificParam = AnyCast<OpParam::RINGMLA>(param);
paramsJson["type"] = specificParam.type;
paramsJson["tor"] = specificParam.tor;
paramsJson["kvHead"] = specificParam.kvHead;
paramsJson["headSize"] = specificParam.headSize;
paramsJson["qSeqLen"] = specificParam.qSeqLen;
paramsJson["kvSeqLen"] = specificParam.kvSeqLen;
paramsJson["isRing"] = specificParam.isRing;
return paramsJson.dump();
}
std::string GmmDeqSwigluQuantGmmDeqToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::GmmDeqSwigluQuantGmmDeq specificParam = AnyCast<OpParam::GmmDeqSwigluQuantGmmDeq>(param);
paramsJson["outputType"] = specificParam.outputType;
paramsJson["groupListType"] = specificParam.groupListType;
paramsJson["weightUpPermuteType"] = specificParam.weightUpPermuteType;
paramsJson["transposeWeightUp"] = specificParam.transposeWeightUp;
paramsJson["transposeWeightDown"] = specificParam.transposeWeightDown;
return paramsJson.dump();
}
std::string MmDeqSwigluQuantMmDeqToJson(const Any ¶m)
{
nlohmann::json paramsJson;
OpParam::MmDeqSwigluQuantMmDeq specificParam = AnyCast<OpParam::MmDeqSwigluQuantMmDeq>(param);
paramsJson["outputType"] = specificParam.outputType;
paramsJson["weightUpPermuteType"] = specificParam.weightUpPermuteType;
paramsJson["transposeWeightUp"] = specificParam.transposeWeightUp;
paramsJson["transposeWeightDown"] = specificParam.transposeWeightDown;
return paramsJson.dump();
}
REG_STRINGIFY(OpParam::BlockCopy, BlockCopyToJson);
REG_STRINGIFY(OpParam::FastSoftMaxGrad, FastSoftMaxGradToJson);
REG_STRINGIFY(OpParam::FastSoftMax, FastSoftMaxToJson);
REG_STRINGIFY(OpParam::Gating, GatingToJson);
REG_STRINGIFY(OpParam::GenAttentionMask, GenAttentionMaskToJson);
REG_STRINGIFY(OpParam::KVCache, KVCacheToJson);
REG_STRINGIFY(OpParam::PadWithHiddenState, PadWithHiddenStateToJson);
REG_STRINGIFY(OpParam::Pad, PadToJson);
REG_STRINGIFY(OpParam::PagedAttention, PagedAttentionToJson);
REG_STRINGIFY(OpParam::MLA, MLAToJson);
REG_STRINGIFY(OpParam::RINGMLA, RINGMLAToJson);
REG_STRINGIFY(OpParam::ReshapeAndCache, ReshapeAndCacheToJson);
REG_STRINGIFY(OpParam::PagedCacheLoad, PagedCacheLoadToJson);
REG_STRINGIFY(OpParam::RopeGrad, RopeGradToJson);
REG_STRINGIFY(OpParam::Rope, RopeToJson);
REG_STRINGIFY(OpParam::StridedBatchMatmul, StridedBatchMatmulToJson);
REG_STRINGIFY(OpParam::Toppsample, ToppsampleToJson);
REG_STRINGIFY(OpParam::UnpadFlashAttentionNz, UnpadFlashAttentionNzToJson);
REG_STRINGIFY(OpParam::UnpadFlashAttention, UnpadFlashAttentionToJson);
REG_STRINGIFY(OpParam::UnpadWithHiddenState, UnpadWithHiddenStateToJson);
REG_STRINGIFY(OpParam::Unpad, UnpadToJson);
REG_STRINGIFY(OpParam::RopeQConcat, RopeQConcatToJson);
REG_STRINGIFY(OpParam::SwigluQuant, SwigluQuantToJson);
REG_STRINGIFY(OpParam::GmmDeqSwigluQuantGmmDeq, GmmDeqSwigluQuantGmmDeqToJson);
REG_STRINGIFY(OpParam::MmDeqSwigluQuantMmDeq, MmDeqSwigluQuantMmDeqToJson);
}