* This file is part of the MindStudio project.
* Copyright (c) 2025 Huawei Technologies Co.,Ltd.
*
* MindStudio is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*
* http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
* ------------------------------------------------------------------------- */
#ifndef STATISTICS_H
#define STATISTICS_H
#include <limits>
#include <string>
#include <complex>
#include <cmath>
#include <limits>
namespace LLM {
class StatisticsBase {
public:
virtual ~StatisticsBase() = default;
virtual std::string GetMaxStr() const = 0;
virtual std::string GetMinStr() const = 0;
virtual std::string GetMeanStr() const = 0;
virtual std::string GetL2NormStr() const = 0;
};
template<typename T>
class Statistics : public StatisticsBase {
public:
T maxValue_;
T minValue_;
double average_;
double l2norm_;
double sumValue_;
double sumOfSquares_;
size_t count_;
Statistics() : maxValue_(std::numeric_limits<T>::lowest()),
minValue_(std::numeric_limits<T>::max()),
average_(0.0),
l2norm_(0.0),
sumValue_(0),
sumOfSquares_(0),
count_(0) {}
void operator+=(const Statistics& other)
{
if (other.count_ != 0) {
maxValue_ = std::max(maxValue_, other.maxValue_);
minValue_ = std::min(minValue_, other.minValue_);
sumValue_ += other.sumValue_;
sumOfSquares_ += other.sumOfSquares_;
count_ += other.count_;
}
}
void Compute(T value)
{
maxValue_ = std::max(maxValue_, value);
minValue_ = std::min(minValue_, value);
sumValue_ += value;
sumOfSquares_ += static_cast<double>(value) * value;
++count_;
}
void ComputeAverage()
{
average_ = (count_ > 0) ? (sumValue_ / count_) : 0.0;
}
std::string GetMaxStr() const override
{
return std::to_string(maxValue_);
}
std::string GetMinStr() const override
{
return std::to_string(minValue_);
}
std::string GetMeanStr() const override
{
return std::to_string(average_);
}
std::string GetL2NormStr() const override
{
return std::to_string(l2norm_);
}
};
template<>
class Statistics<std::complex<float>> : public StatisticsBase {
public:
std::complex<float> maxValue_;
std::complex<float> minValue_;
std::complex<double> average_;
double l2norm_;
std::complex<double> sumValue_;
double sumOfSquares_;
size_t count_;
Statistics() : maxValue_(std::complex<float>(0, 0)),
minValue_(std::complex<float>(std::numeric_limits<float>::max(),
std::numeric_limits<float>::max())),
average_(std::complex<double>(0, 0)),
l2norm_(0),
sumValue_(std::complex<double>(0, 0)),
sumOfSquares_(0),
count_(0) {}
void operator+=(const Statistics& other)
{
if (other.count_ != 0) {
double thisMag = std::norm(maxValue_);
double otherMag = std::norm(other.maxValue_);
maxValue_ = (thisMag > otherMag) ? maxValue_ : other.maxValue_;
thisMag = std::norm(minValue_);
otherMag = std::norm(other.minValue_);
minValue_ = (thisMag < otherMag) ? minValue_ : other.minValue_;
sumValue_ += other.sumValue_;
sumOfSquares_ += other.sumOfSquares_;
count_ += other.count_;
}
}
void Compute(const std::complex<float>& value)
{
float valMag = std::norm(value);
if (count_ == 0 || valMag > std::norm(maxValue_)) {
maxValue_ = value;
}
if (count_ == 0 || valMag < std::norm(minValue_)) {
minValue_ = value;
}
sumValue_ += value;
sumOfSquares_ += std::norm(value);
count_++;
}
void ComputeAverage()
{
bool notEmpty = count_ > 0;
average_ = std::complex<double>(notEmpty ? sumValue_.real() / (static_cast<double>(count_)) : 0,
notEmpty ? sumValue_.imag() / (static_cast<double>(count_)) : 0);
}
std::string GetMaxStr() const override
{
return "(" + std::to_string(maxValue_.real()) + "," + std::to_string(maxValue_.imag()) + ")";
}
std::string GetMinStr() const override
{
return "(" + std::to_string(minValue_.real()) + "," + std::to_string(minValue_.imag()) + ")";
}
std::string GetMeanStr() const override
{
return "(" + std::to_string(average_.real()) + "," + std::to_string(average_.imag()) + ")";
}
std::string GetL2NormStr() const override
{
return std::to_string(l2norm_);
}
};
template<>
class Statistics<std::string> : public StatisticsBase {
public:
std::string maxValue_;
std::string minValue_;
std::string average_;
std::string l2norm_;
Statistics() : maxValue_("N/A"),
minValue_("N/A"),
average_("N/A"),
l2norm_("N/A") {}
std::string GetMaxStr() const override
{
return maxValue_;
}
std::string GetMinStr() const override
{
return minValue_;
}
std::string GetMeanStr() const override
{
return average_;
}
std::string GetL2NormStr() const override
{
return l2norm_;
}
};
}
#endif