/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
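/*!
 * \file flow_verifier.h
 * \brief Declares FlowVerifier and the CompareResult helpers used to compare output tensors against golden data.
 */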
#pragma once
#include "interface/interpreter/interpreter_log.h"
#include "interface/interpreter/raw_tensor_data.h"
#include "interface/tensor/tensor_slot.h"
#include "interface/operation/attribute.h"
#include "interface/function/function.h"
#include "interface/interpreter/function.h"
#include <float.h>
#include <cmath>
#include <utility>
namespace npu::tile_fwk {
// The value 1000 as a named constant. It is not referenced anywhere in this header;
// presumably it is consumed by the implementation file - TODO confirm.
constexpr int THREAD_THOUSAND = 1000;
/**
 * Verifier that checks tensor data produced for a Function against golden data.
 *
 * The element-wise comparison helpers are defined inline in this header; the remaining
 * members are only declared here and are defined out of line.
 */
class FlowVerifier {
public:
// Returns a reference to a FlowVerifier; presumably a process-wide singleton - TODO confirm in the definition.
static FlowVerifier& GetInstance();
// Verifies the tensor graph of `entry` using the supplied input, output and golden data views
// together with the slot manager. Defined out of line; exact behaviour is not visible here.
void VerifyTensorGraph(
Function* entry, const std::vector<std::shared_ptr<LogicalTensorData>>& inputDataViewList,
const std::vector<std::shared_ptr<LogicalTensorData>>& outputDataViewList,
const std::vector<std::shared_ptr<LogicalTensorData>>& goldenDataViewList,
const std::shared_ptr<TensorSlotManager>& slotManager);
// Verifies `func` for the pass identified by `passIndex` / `passIdentifier`. Defined out of line.
void VerifyPass(Function* func, int passIndex, const std::string& passIdentifier);
/**
 * Plain summary of one golden-vs-output comparison, as returned by CompareResult::Dump().
 *
 * Fields prefixed with `a` are filled from the golden-side statistics and fields prefixed with
 * `b` from the output-side statistics. Every member has a defined initial value so that a
 * default-constructed detail can be read safely.
 */
struct CompareResultDetail {
    // Element counters.
    size_t totalCnt = 0;
    size_t zeroCnt = 0;
    size_t toleranceCnt = 0;
    size_t warnNum = 0;
    size_t failNum = 0;

    // Averaged difference metrics (whole tensor, first 8 and first 1000 entries of the sorted mismatches).
    double mre = 0;
    double mreTop8 = 0;
    double mreTop1Permil = 0;
    double mae = 0;
    double maeTop8 = 0;
    double maeTop1Permil = 0;

    // Golden-side statistics.
    // Running extrema start at the opposite end of the double range. FLT_MIN is the smallest
    // *positive* float, so using it as the initial maximum breaks on all-negative data.
    double aMax = -DBL_MAX;
    double aMin = DBL_MAX;
    double aAvg = 0;
    double aAavg = 0;
    size_t aZero = 0;
    size_t aInfnan = 0;

    // Output-side statistics.
    double bMax = -DBL_MAX;
    double bMin = DBL_MAX;
    double bAvg = 0;
    double bAavg = 0;
    size_t bZero = 0;
    size_t bInfnan = 0;

    // Number of pairs whose difference was not finite.
    size_t infnanCnt = 0;
};
/**
 * One recorded golden/output pair together with its difference metrics.
 *
 * All members carry default initializers so that a default-constructed element
 * (for example a "no maximum found yet" placeholder) can be dumped without reading
 * indeterminate values. Member order is part of the interface: callers unpack
 * elements with structured bindings.
 */
struct CompareElement {
    bool isError = false;
    size_t index = 0;
    double goldenValue = 0;
    double outputValue = 0;
    double absDiff = 0;
    double relDiff = 0;
    double tolerance = 0;

    CompareElement() = default;
    CompareElement(const CompareElement&) = default;
    CompareElement& operator=(const CompareElement&) = default;
    CompareElement(
        bool isError_, size_t index_, double goldenValue_, double outputValue_, double absDiff_, double relDiff_,
        double tolerance_)
        : isError(isError_),
          index(index_),
          goldenValue(goldenValue_),
          outputValue(outputValue_),
          absDiff(absDiff_),
          relDiff(relDiff_),
          tolerance(tolerance_)
    {}

    /**
     * Formats the element as a single line of "key:value" pairs.
     * isError and tolerance are not part of the text.
     */
    std::string Dump() const
    {
        std::ostringstream oss;
        oss << "index:" << index << " golden:" << goldenValue << " output:" << outputValue << " absDiff:" << absDiff
            << " relDiff:" << relDiff;
        return oss.str();
    }
};
struct CompareResult : std::vector<CompareElement> {
CompareResult() {}
CompareResult(
int size, float rtol, float atol, size_t errorCountThreshold = 0, size_t failNum = 0, Shape shape = {})
: size_(size),
rtol_(rtol),
atol_(atol),
errorCountThreshold_(errorCountThreshold),
failNum_(failNum),
shape_(shape)
{}
template <typename... TyArgs>
void AppendError(TyArgs&&... args)
{
errorCount_++;
this->emplace_back(args...);
}
void AppendZero() { zeroCount_++; }
void AppendFail() { failNum_++; }
void UpdateErrorCountThreshold()
{
errorCountThreshold_ = static_cast<int>((size_ - zeroCount_) * std::min(rtol_, atol_));
size_t cnt_adj = static_cast<int>(std::pow((size_ - zeroCount_), 0.5)) / 2;
if (errorCountThreshold_ == 0) {
size_t cnt_normal = 16;
errorCountThreshold_ = std::min(cnt_normal, cnt_adj);
}
}
bool Check() const { return errorCount_ <= errorCountThreshold_ && failNum_ == 0; }
void sortByAbsAdesc()
{
std::sort(this->begin(), this->end(), [](const CompareElement& lhs, const CompareElement& rhs) {
return lhs.absDiff > rhs.absDiff;
});
}
void sortByRelAdesc()
{
std::sort(this->begin(), this->end(), [](const CompareElement& lhs, const CompareElement& rhs) {
return lhs.relDiff > rhs.relDiff;
});
}
double GetMeanTopN(size_t num, bool is_abs) const
{
double sum = 0;
size_t count = std::min(this->size(), num);
if (count == 0)
return 0;
if (is_abs) {
sum = std::accumulate(
this->begin(), this->begin() + count, 0.0,
[](double acc, const CompareElement& item) { return acc + item.absDiff; });
} else {
sum = std::accumulate(
this->begin(), this->begin() + count, 0.0,
[](double acc, const CompareElement& item) { return acc + item.relDiff; });
}
return sum / count;
}
Shape GetOffsetRaw(int64_t offset) const
{
if (shape_.empty()) {
return {};
}
int64_t total = 1;
for (int64_t dim : shape_) {
total *= dim;
}
ASSERT(VerifyResultScene::VERIFY_RESULT_INDEX_OUTOFBOUNDS, offset < total) << "Offset Out Of Bounds";
std::vector<int64_t> indices(shape_.size());
int64_t remaining = offset;
for (int i = static_cast<int>(shape_.size()) - 1; i >= 0; --i) {
indices[i] = remaining % shape_[i];
remaining /= shape_[i];
}
return indices;
}
void DumpDataDetail(std::ostringstream& oss, size_t topk = 64)
{
constexpr auto max_precision{std::numeric_limits<float>::digits10 + 1};
oss << std::setprecision(max_precision);
oss << "GROUP,INDEX,OFFSET,OFFSET_RAW,A>data,B>data,AB>ae,AB>re,AB>tol\n";
size_t count = std::min(topk, this->size());
for (size_t k = 0; k < count; k++) {
auto [isError, index, goldenValue, outputValue, absDiff, relDiff, tolerance] = (*this)[k];
(void)isError;
oss << "firstk," << k << "," << index << "," << FunctionInterpreter::ShapeToString(GetOffsetRaw(index))
<< "," << goldenValue << "," << outputValue << "," << absDiff << "," << relDiff << "," << tolerance
<< "\n";
}
sortByRelAdesc();
for (size_t k = 0; k < count; k++) {
auto [isError, index, goldenValue, outputValue, absDiff, relDiff, tolerance] = (*this)[k];
(void)isError;
oss << "topk_re," << k << "," << index << "," << FunctionInterpreter::ShapeToString(GetOffsetRaw(index))
<< "," << goldenValue << "," << outputValue << "," << absDiff << "," << relDiff << "," << tolerance
<< "\n";
}
}
CompareResultDetail Dump(int indent = 2, size_t maxPrint = 5)
{
double maxAbsDiff = 0;
double maxRelDiff = 0;
double totalAbsDiff = 0;
double totalRelDiff = 0;
CompareResultDetail compareResultDetail;
CompareElement maxAbsElement;
CompareElement maxRelElement;
std::ostringstream oss;
std::string space(indent, ' ');
std::string infoError =
"\n " + space + "Error rtol=" + std::to_string(rtol_) + " atol=" + std::to_string(atol_);
std::string infoZero = "\n " + space + "Zero";
size_t count_ = 0;
for (auto& element : *this) {
auto [isError, index, goldenValue, outputValue, absDiff, relDiff, tolerance] = element;
(void)index;
(void)tolerance;
if (absDiff > maxAbsDiff) {
maxAbsDiff = absDiff;
maxAbsElement = element;
}
if (relDiff > maxRelDiff) {
maxRelDiff = relDiff;
maxRelElement = element;
}
maxAbsDiff = std::max(maxAbsDiff, absDiff);
maxRelDiff = std::max(maxRelDiff, relDiff);
compareResultDetail.aMax = std::max(compareResultDetail.aMax, goldenValue);
compareResultDetail.bMax = std::max(compareResultDetail.bMax, outputValue);
compareResultDetail.aMin = std::min(compareResultDetail.aMin, goldenValue);
compareResultDetail.bMin = std::min(compareResultDetail.bMin, outputValue);
totalAbsDiff += absDiff;
totalRelDiff += relDiff;
std::string info = "";
if (isError) {
info = infoError.c_str();
} else {
continue;
}
if (count_ <= maxPrint) {
oss << space << info << " " << element.Dump() << "";
}
count_++;
}
oss << "\n"
<< space << "All size:" << size_ << " failNum:" << failNum_ << " maxAbsDiff:" << maxAbsDiff
<< " maxRelDiff:" << maxRelDiff << " averageAbsDiff:" << totalAbsDiff / size_
<< " averageRelDiff:" << totalRelDiff / size_ << " errorCount:" << errorCount_
<< " errorRatio:" << errorCount_ * 1.0 / size_ << " zeroCount:" << zeroCount_
<< " zeroRatio:" << zeroCount_ * 1.0 / size_ << "\n";
if (errorCount_ + zeroCount_ > 0) {
oss << space << "maxAbs-> " << maxAbsElement.Dump() << "\n"
<< space << "maxRel-> " << maxRelElement.Dump() << "\n";
}
if (!Check()) {
INTERPRETER_EVENT("%s", oss.str().c_str());
}
compareResultDetail.totalCnt = size_;
compareResultDetail.zeroCnt = zeroCount_;
compareResultDetail.toleranceCnt = errorCountThreshold_;
compareResultDetail.warnNum = errorCount_;
compareResultDetail.failNum = failNum_;
compareResultDetail.mre = totalAbsDiff / size_;
compareResultDetail.mae = totalRelDiff / size_;
constexpr size_t topNCount = 8;
constexpr size_t top1PermilCount = 1000;
compareResultDetail.mreTop8 = GetMeanTopN(topNCount, false);
compareResultDetail.mreTop1Permil = GetMeanTopN(top1PermilCount, false);
sortByAbsAdesc();
compareResultDetail.maeTop1Permil = GetMeanTopN(top1PermilCount, true);
compareResultDetail.maeTop8 = GetMeanTopN(topNCount, true);
compareResultDetail.aMax = goldenMax_;
compareResultDetail.aMin = goldenMin_;
compareResultDetail.aAvg = goldenSum_ / size_;
compareResultDetail.aAavg = goldenAbsSum_ / size_;
compareResultDetail.aZero = goldenZero_;
compareResultDetail.bMax = outputMax_;
compareResultDetail.bMin = outputMin_;
compareResultDetail.bAvg = outputSum_ / size_;
compareResultDetail.bAavg = outputAbsSum_ / size_;
compareResultDetail.bZero = outputZero_;
compareResultDetail.bInfnan = outputInfnan_;
compareResultDetail.aInfnan = goldenInfnan_;
compareResultDetail.infnanCnt = infnanCnt_;
return compareResultDetail;
}
float GetRtol() const { return rtol_; }
float GetAtol() const { return atol_; }
double goldenMax_ = -DBL_MAX;
double outputMax_ = -DBL_MAX;
double goldenMin_ = DBL_MAX;
double outputMin_ = DBL_MAX;
double goldenSum_ = 0;
double outputSum_ = 0;
double goldenAbsSum_ = 0;
double outputAbsSum_ = 0;
size_t goldenZero_ = 0;
size_t outputZero_ = 0;
size_t goldenInfnan_ = 0;
size_t outputInfnan_ = 0;
size_t infnanCnt_ = 0;
private:
size_t size_{0};
float rtol_{0};
float atol_{0};
size_t errorCountThreshold_{0};
size_t failNum_{0};
Shape shape_;
size_t errorCount_ = 0;
size_t zeroCount_ = 0;
};
private:
/**
 * Folds one golden/output scalar pair into `compareResult`.
 *
 * Updates the running min/max/sum/zero/non-finite statistics, then classifies the pair:
 *  - both magnitudes zero: counted through AppendZero() and nothing else is recorded;
 *  - |output - golden| above rtol * (|output| + |golden|) / 2 + atol: recorded through AppendError();
 *  - |output - golden| above 128 times that tolerance: additionally counted through AppendFail().
 *
 * NOTE(review): a NaN difference compares false against both tolerances, so such pairs only
 * show up in infnanCnt_ - confirm that this is intended.
 */
static void CompareScalarPair(
    CompareResult& compareResult, int64_t linearIndex, double goldenValue, double outputValue)
{
    const double outputMagnitude = std::abs(outputValue);
    const double goldenMagnitude = std::abs(goldenValue);
    const double diffMagnitude = std::abs(outputValue - goldenValue);

    // Golden-side statistics.
    compareResult.goldenMax_ = std::max(compareResult.goldenMax_, goldenValue);
    compareResult.goldenMin_ = std::min(compareResult.goldenMin_, goldenValue);
    compareResult.goldenSum_ += goldenValue;
    compareResult.goldenAbsSum_ += goldenMagnitude;
    if (goldenMagnitude <= 0) {
        compareResult.goldenZero_++;
    }
    if (!std::isfinite(goldenValue)) {
        compareResult.goldenInfnan_++;
    }

    // Output-side statistics.
    compareResult.outputMax_ = std::max(compareResult.outputMax_, outputValue);
    compareResult.outputMin_ = std::min(compareResult.outputMin_, outputValue);
    compareResult.outputSum_ += outputValue;
    compareResult.outputAbsSum_ += outputMagnitude;
    if (outputMagnitude <= 0) {
        compareResult.outputZero_++;
    }
    if (!std::isfinite(outputValue)) {
        compareResult.outputInfnan_++;
    }

    if (!std::isfinite(diffMagnitude)) {
        compareResult.infnanCnt_++;
    }

    // Both values are zero: the relative difference is undefined, so only count the pair.
    const double magnitudeSum = outputMagnitude + goldenMagnitude;
    if (magnitudeSum <= 0) {
        compareResult.AppendZero();
        return;
    }

    const double relDiff = diffMagnitude * 2 / magnitudeSum;
    const double warnTolerance = magnitudeSum * compareResult.GetRtol() / 2 + compareResult.GetAtol();
    const double failTolerance = warnTolerance * 128;

    if (diffMagnitude > warnTolerance) {
        compareResult.AppendError(
            true, static_cast<size_t>(linearIndex), goldenValue, outputValue, diffMagnitude, relDiff,
            warnTolerance);
    }
    if (diffMagnitude > failTolerance) {
        compareResult.AppendFail();
    }
}
static int64_t GetValidStride(Shape validShape, int64_t axis)
{
int64_t stride = 1;
int64_t dims = validShape.size();
for (int64_t i = axis + 1; i < dims; i++) {
stride = stride * validShape[i];
}
return stride;
}
/**
 * Walks every axis of the golden view's valid shape except the last one and invokes `leaf`
 * once per innermost row.
 *
 * @param axis          axis handled by this recursion level
 * @param goldenOffset  storage offset into the golden data for the current position
 * @param outputOffset  storage offset into the output data for the current position
 * @param index         linear index of the current position within the valid shape
 * @param leaf          callable invoked as
 *                      leaf(compareResult, rowLength, outputOffset, goldenOffset, index, goldenDataView, outputDataView)
 *
 * Storage offsets advance by each view's own stride, while `index` advances by the dense
 * stride of the valid shape. All offsets are kept in 64 bits.
 */
template <typename Leaf>
static void CompareDataRecursiveWithLeaf(
    CompareResult& compareResult, size_t axis, int64_t goldenOffset, int64_t outputOffset,
    const std::shared_ptr<LogicalTensorData>& goldenDataView,
    const std::shared_ptr<LogicalTensorData>& outputDataView, int64_t index, Leaf&& leaf)
{
    auto& validShape = goldenDataView->GetValidShape();
    // A rank-0 shape has no last axis to hand to the leaf; `size() - 1` would wrap around and the
    // loop below would index out of bounds.
    // NOTE(review): nothing is compared in that case - confirm whether a rank-0 tensor should be
    // treated as a single element instead.
    if (validShape.size() == 0) {
        return;
    }
    if (axis + 1 == validShape.size()) {
        leaf(compareResult, validShape[axis], outputOffset, goldenOffset, index, goldenDataView, outputDataView);
        return;
    }

    // Loop-invariant strides are looked up once per recursion level.
    const auto goldenStride = goldenDataView->GetData()->GetStride()[axis];
    const auto outputStride = outputDataView->GetData()->GetStride()[axis];
    const int64_t indexStride = GetValidStride(validShape, static_cast<int64_t>(axis));
    for (int64_t i = 0; i < validShape[axis]; i++) {
        const int64_t nGoldenOffset = goldenOffset + goldenStride * i;
        const int64_t nOutputOffset = outputOffset + outputStride * i;
        const int64_t nindex = index + indexStride * i;
        // `leaf` is passed on as an lvalue: it is reused on every iteration and must not be moved from.
        CompareDataRecursiveWithLeaf(
            compareResult, axis + 1, nGoldenOffset, nOutputOffset, goldenDataView, outputDataView, nindex, leaf);
    }
}
public:
/**
 * Compares `count` consecutive elements of two raw buffers.
 *
 * Each element is converted DataType -> T -> double before being handed to CompareScalarPair;
 * the k-th pair is reported with linear index `offset + k`.
 */
template <typename DataType, typename T>
static void CompareData(
    CompareResult& compareResult, size_t count, int64_t offset, const DataType* goldenValueList,
    const DataType* outputValueList)
{
    const DataType* goldenCursor = goldenValueList;
    const DataType* outputCursor = outputValueList;
    int64_t linearIndex = offset;
    for (size_t remaining = count; remaining > 0; --remaining) {
        const double goldenValue = static_cast<double>(static_cast<T>(*goldenCursor));
        const double outputValue = static_cast<double>(static_cast<T>(*outputCursor));
        CompareScalarPair(compareResult, linearIndex, goldenValue, outputValue);
        ++goldenCursor;
        ++outputCursor;
        ++linearIndex;
    }
}
/**
 * Compares two tensor views row by row, starting at `axis`.
 *
 * The traversal is delegated to CompareDataRecursiveWithLeaf; for every innermost row the
 * elements are read as DataType at the row's storage offsets and compared through
 * CompareData<DataType, T>.
 */
template <typename DataType, typename T>
static void CompareDataRecursive(
    CompareResult& compareResult, size_t axis, int64_t goldenOffset, int64_t outputOffset,
    const std::shared_ptr<LogicalTensorData>& goldenDataView,
    const std::shared_ptr<LogicalTensorData>& outputDataView, int64_t index)
{
    // Row comparator: note that the traversal passes the output offset before the golden offset.
    const auto compareRow = [](CompareResult& result, size_t rowLength, int64_t outputRowOffset,
                               int64_t goldenRowOffset, int64_t rowIndex,
                               const std::shared_ptr<LogicalTensorData>& goldenView,
                               const std::shared_ptr<LogicalTensorData>& outputView) {
        const DataType* goldenRow = &goldenView->Get<DataType>(goldenRowOffset);
        const DataType* outputRow = &outputView->Get<DataType>(outputRowOffset);
        CompareData<DataType, T>(result, rowLength, rowIndex, goldenRow, outputRow);
    };
    CompareDataRecursiveWithLeaf(
        compareResult, axis, goldenOffset, outputOffset, goldenDataView, outputDataView, index, compareRow);
}
/**
 * Compares an output view against a golden view element by element.
 *
 * @param goldenDataView       reference data; its valid shape defines the compared region
 * @param outputDataView       data under test
 * @param rtol                 relative tolerance
 * @param atol                 absolute tolerance
 * @param errorCountThreshold  initial tolerated mismatch count (recomputed before returning)
 * @param failNum              initial hard-failure count
 * @return the populated CompareResult, with its error-count threshold already updated
 *
 * The element count is accumulated in 64 bits. Seeding std::accumulate with a plain `1` would
 * perform the multiplication in `int`, which can overflow for large shapes.
 */
template <typename DataType, typename T>
static CompareResult CompareData(
    const std::shared_ptr<LogicalTensorData>& goldenDataView,
    const std::shared_ptr<LogicalTensorData>& outputDataView, float rtol, float atol, int errorCountThreshold = 0,
    int failNum = 0)
{
    const auto& validShape = goldenDataView->GetValidShape();
    const int64_t elementCount =
        std::accumulate(validShape.begin(), validShape.end(), int64_t{1}, std::multiplies<>());
    CompareResult compareResult(elementCount, rtol, atol, errorCountThreshold, failNum, validShape);
    CompareDataRecursive<DataType, T>(
        compareResult, 0, goldenDataView->GetStorageOffset(), outputDataView->GetStorageOffset(), goldenDataView,
        outputDataView, 0);
    compareResult.UpdateErrorCountThreshold();
    return compareResult;
}
// Compares one output view against one golden view with the given tolerances and returns the
// CompareResult. Declared here only; the definition is out of line.
static CompareResult VerifyResult(
const std::shared_ptr<LogicalTensorData>& goldenDataView,
const std::shared_ptr<LogicalTensorData>& outputDataView, float rtol, float atol);
// List overload returning a bool; presumably true when every tensor passes - TODO confirm in the definition.
// NOTE(review): `tensorNameList` is taken as a const std::string by value, which copies the string on
// every call; a const reference would avoid that, but the definition would have to change with it.
bool VerifyResult(
const std::vector<std::shared_ptr<LogicalTensor>>& tensorDatalist,
const std::vector<std::shared_ptr<LogicalTensor>>& goldenDatalist, const std::string& key,
const std::string tensorNameList, const std::vector<std::shared_ptr<LogicalTensorData>>& goldenDataViewList,
const std::vector<std::shared_ptr<LogicalTensorData>>& tensorDataViewList, float rtol, float atol);
// Takes an error message by value and returns a string; the transformation is defined out of line.
std::string ParseErrorMsg(std::string errorMsg);
// Presumably persists the user-supplied golden data views - TODO confirm in the definition.
void WriteUserGolden(const std::vector<std::shared_ptr<LogicalTensorData>>& goldenDataViewList);
// Defined out of line; behaviour is not visible in this header.
void WriteException();
private:
// Comparison entry point that additionally takes a DataType `fp8Format`; presumably used for
// FP8-encoded tensors - TODO confirm. The remaining parameters mirror the inline CompareData overload.
static CompareResult CompareFp8TensorData(
const std::shared_ptr<LogicalTensorData>& goldenDataView,
const std::shared_ptr<LogicalTensorData>& outputDataView, DataType fp8Format, float rtol, float atol,
int errorCountThreshold = 0, int failNum = 0);
// Defined out of line; behaviour is not visible in this header.
void UpdateInterpreterCache();
// Takes the same arguments as VerifyTensorGraph; presumably stores them in the members below - TODO confirm.
void Initialize(
Function* entry, const std::vector<std::shared_ptr<LogicalTensorData>>& inputDataViewList,
const std::vector<std::shared_ptr<LogicalTensorData>>& outputDataViewList,
const std::vector<std::shared_ptr<LogicalTensorData>>& goldenDataViewList,
const std::shared_ptr<TensorSlotManager>& slotManager);
private:
// Non-owning pointer to a Function. Initialised to nullptr because the class declares no
// constructor, so an instance without static storage would otherwise start with an
// indeterminate pointer.
Function* entry_{nullptr};
bool checkResult{true};
// Data views kept by the verifier; presumably the ones handed to Initialize - TODO confirm.
std::vector<std::shared_ptr<LogicalTensorData>> inputDataViewList_;
std::vector<std::shared_ptr<LogicalTensorData>> outputDataViewList_;
std::vector<std::shared_ptr<LogicalTensorData>> goldenDataViewList_;
// Interpreter-side state shared with the execution objects.
std::shared_ptr<FunctionInterpreter> functionInterpreter_;
std::shared_ptr<FunctionControlFlowExecution> controlFlowExecution_;
// Capture executions recorded per Function.
std::unordered_map<Function*, std::vector<std::shared_ptr<FunctionCaptureExecution>>> lastCaptureExecution_;
};
}