* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ASDOPS_PARAMS_NORM_H
#define ASDOPS_PARAMS_NORM_H
#include <string>
#include <sstream>
#include <mki/utils/SVector/SVector.h>
#include <mki/utils/compare/compare.h>
namespace AsdOps {
namespace OpParam {
/**
 * @brief Parameter block for normalization operators (layer norm and the RMS-norm family).
 *
 * operator== compares every data member, so two parameter sets that differ in any field
 * are never reported as equal.
 */
struct Norm {
    /// Which normalization variant these parameters describe.
    enum NormType {
        NORM_UNDEFINED = 0,
        LAYER_NORM,
        RMS_NORM,
        RMS_NORM_FORWARD,
        RMS_NORM_BACKWARD,
        GATHER_PRE_RMS_NORM,
    };

    // Given a default so a default-initialized Norm carries no indeterminate member;
    // operator== reads this field. NORM_UNDEFINED (0) is also what Norm{} already produced.
    NormType normType = NORM_UNDEFINED;
    int32_t beginNormAxis = 0;
    int32_t beginParamsAxis = 0;
    size_t opsMode = 0;
    // NOTE(review): 0.1f is an unusually large default for a normalization epsilon;
    // confirm that callers always set it explicitly.
    float epsilon = 0.1f;
    float zoomScaleValue = 1.0f;
    uint32_t precisionMode = 0;
    uint32_t gemmaMode = 0;

    // Feature flags. Presumably in* mark optional input tensors and out* mark optional
    // output tensors -- confirm against the kernel implementations.
    bool inGamma = false;
    bool inBeta = false;
    bool inRes = false;
    bool inNormBias = false;
    bool outMean = false;
    bool outVarience = false;  // spelling kept as-is: renaming would break existing callers
    bool outResQuant = false;
    bool outRes = false;
    bool isDynamicQuant = false;
    bool isSymmetric = true;

    /**
     * @brief Member-wise equality over every field of Norm.
     *
     * Float members (epsilon, zoomScaleValue) are compared with
     * Mki::Utils::Compare<float>::IsEqual instead of raw ==. All other members use ==.
     * isSymmetric is included so that parameter sets differing only in it are distinct.
     *
     * @param other Parameter set to compare against.
     * @return true when every member matches.
     */
    bool operator==(const Norm &other) const
    {
        return this->normType == other.normType && this->beginNormAxis == other.beginNormAxis &&
               this->beginParamsAxis == other.beginParamsAxis && this->opsMode == other.opsMode &&
               Mki::Utils::Compare<float>::IsEqual(this->epsilon, other.epsilon) &&
               Mki::Utils::Compare<float>::IsEqual(this->zoomScaleValue, other.zoomScaleValue) &&
               this->inGamma == other.inGamma &&
               this->inBeta == other.inBeta &&
               this->inRes == other.inRes &&
               this->inNormBias == other.inNormBias &&
               this->outMean == other.outMean &&
               this->outVarience == other.outVarience &&
               this->outResQuant == other.outResQuant &&
               this->outRes == other.outRes &&
               this->precisionMode == other.precisionMode &&
               this->gemmaMode == other.gemmaMode &&
               this->isDynamicQuant == other.isDynamicQuant &&
               this->isSymmetric == other.isSymmetric;
    }
};
}
}
#endif