* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file test_data_loader.h
* \brief Loads test parameters, input tensors and golden outputs described by a JSON test config file.
*/
#pragma once
#include <cstdint>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>
#include <nlohmann/json.hpp>
#include "tilefwk/tensor.h"
#include "interface/interpreter/raw_tensor_data.h"
#include "interface/tensor/float.h"
#include "test_dev_func_runner.h"
#include "test_data_prepare.h"
// Value type of a test parameter read from the config's "parameters" object.
using Scalar = std::variant<int, float, double, bool, std::string>;
// Base column width of the tensor table printed by TestDataLoader::Dump(); narrow columns use half of it.
// `inline constexpr` gives every including translation unit the same single definition (C++17).
inline constexpr int COLUMN_WIDTH = 20;
// NOTE(review): an unnamed namespace in a header gives every including translation unit its own
// private copy of these helpers. Kept as-is so existing unqualified callers keep resolving to them.
namespace {
/** Formats `shape` as "(d0, d1, ...)"; an empty shape yields "()". */
std::string ShapeToString(const Shape& shape)
{
    std::ostringstream oss;
    oss << "(";
    for (size_t i = 0; i < shape.size(); ++i) {
        if (i != 0) {
            oss << ", ";
        }
        oss << shape[i];
    }
    oss << ")";
    return oss.str();
}

/**
 * Prints one table row: index, name, shape, dtype and storage format of `tensor`.
 * Column widths match the divider printed by PrintTableDivider().
 * Any format other than TILEOP_ND is printed as "TILEOP_NZ".
 */
void PrintTableRow(int index, const Tensor& tensor)
{
    std::string name = tensor.GetName();
    std::string shapeStr = ShapeToString(tensor.GetShape());
    std::string typeStr = DataType2String(tensor.GetDataType());
    std::string formatStr = tensor.GetStorage()->Format() == TileOpFormat::TILEOP_ND ? "TILEOP_ND" : "TILEOP_NZ";
    std::cout << "| " << std::setw(COLUMN_WIDTH / 2) << std::left << index << " | " << std::setw(COLUMN_WIDTH)
              << std::left << name << " | " << std::setw(COLUMN_WIDTH) << std::left << shapeStr << " | "
              << std::setw(COLUMN_WIDTH / 2) << std::left << typeStr << " | " << std::setw(COLUMN_WIDTH / 2)
              << std::left << formatStr << " |" << std::endl;
}

/** Prints the horizontal divider line of the tensor table. */
void PrintTableDivider()
{
    std::cout << "+------------+----------------------+----------------------+------------+------------+" << std::endl;
}

/**
 * Converts a JSON parameter value to a Scalar.
 *
 * Integers become `int`, floating-point numbers become `double` (the `float` alternative is never
 * produced here), booleans become `bool` and strings become `std::string`.
 *
 * @throws std::out_of_range     if an integer does not fit in `int` (previously it was silently truncated).
 * @throws std::invalid_argument for any other JSON type (null, array, object).
 */
Scalar ConvertJsonToScalar(const nlohmann::json& value)
{
    if (value.is_number_integer()) {
        // nlohmann stores non-negative integers as unsigned, so both representations must be range-checked.
        if (value.is_number_unsigned()) {
            const auto unsignedValue = value.get<std::uint64_t>();
            if (unsignedValue > static_cast<std::uint64_t>(std::numeric_limits<int>::max())) {
                throw std::out_of_range("Integer parameter out of int range: " + value.dump());
            }
            return static_cast<int>(unsignedValue);
        }
        const auto signedValue = value.get<std::int64_t>();
        if (signedValue < std::numeric_limits<int>::min() || signedValue > std::numeric_limits<int>::max()) {
            throw std::out_of_range("Integer parameter out of int range: " + value.dump());
        }
        return static_cast<int>(signedValue);
    } else if (value.is_number_float()) {
        return value.get<double>();
    } else if (value.is_boolean()) {
        return value.get<bool>();
    } else if (value.is_string()) {
        return value.get<std::string>();
    } else {
        throw std::invalid_argument("Unsupported JSON type for Scalar");
    }
}

/** Returns true if `value` is the integer value of TILEOP_ND or TILEOP_NZ. */
bool IsValidWeightFormat(int value)
{
    return value == static_cast<int>(TileOpFormat::TILEOP_ND) || value == static_cast<int>(TileOpFormat::TILEOP_NZ);
}
}
class TestDataLoader {
public:
TestDataLoader(const std::string& path) : configPath(path)
{
std::ifstream configFile(this->configPath);
if (!configFile.is_open()) {
throw std::runtime_error("Failed to open config file: " + this->configPath);
}
try {
meta = nlohmann::json::parse(configFile);
} catch (const nlohmann::json::exception& e) {
throw std::runtime_error("Failed to parse config file: " + std::string(e.what()));
}
LoadParameters();
LoadInputTensors();
LoadGoldenTensors();
}
auto GetParams() { return params; }
auto Param(const std::string& name) { return params.at(name); }
Tensor& InputTensor(const std::string& name) { return inputTensors.at(name); }
Tensor& OutputTensor(const std::string& name) { return outputTensors.at(name); }
const std::vector<std::reference_wrapper<const Tensor>>& GetInputTensorList() const { return inputTensorList; }
const std::vector<std::reference_wrapper<const Tensor>>& GetOutputTensorList() const { return outputTensorList; }
std::vector<RawTensorDataPtr> GetInputDataList() { return inputDataList; }
std::vector<RawTensorDataPtr> GetOutputDataList() { return outputDataList; }
RawTensorDataPtr GoldenData(const std::string& name) { return goldens.at(name); }
int GetInputNameToIdx(const std::string& name) { return inputNameToIdx.at(name); }
int GetOutputNameToIdx(const std::string& name) { return outputNameToIdx.at(name); }
Tensor& InputTensorCheck(
const std::string& name, const DataType dtype, const Shape& shape,
const TileOpFormat format = TileOpFormat::TILEOP_ND)
{
auto it = this->inputTensors.find(name);
if (it == this->inputTensors.end()) {
throw std::runtime_error("Tensor " + name + " not found in input tensors.");
}
const DataType curDType = it->second.GetDataType();
const Shape& curShape = it->second.GetShape();
const TileOpFormat curOpFormat = it->second.GetStorage()->Format();
if (dtype != curDType) {
throw std::runtime_error("Data type mismatch for inputTensor " + name);
}
if (shape.size() != curShape.size()) {
throw std::runtime_error("Shape size mismatch for inputTensor " + name);
}
for (size_t i = 0; i < shape.size(); ++i) {
if (curShape[i] != -1 && shape[i] != curShape[i]) {
throw std::runtime_error("Shape mismatch for inputTensor " + name);
}
}
if (format != curOpFormat) {
throw std::runtime_error("Weight format mismatch for inputTensor " + name);
}
return inputTensors.at(name);
}
Tensor& OutputTensorCheck(
const std::string& name, const DataType dtype, const Shape& shape,
const TileOpFormat format = TileOpFormat::TILEOP_ND)
{
auto it = this->outputTensors.find(name);
if (it == this->outputTensors.end()) {
throw std::runtime_error("Tensor " + name + " not found in output tensors.");
}
const DataType curDType = it->second.GetDataType();
const Shape& curShape = it->second.GetShape();
const TileOpFormat curOpFormat = it->second.GetStorage()->Format();
if (dtype != curDType) {
throw std::runtime_error("Data type mismatch for outputTensor " + name);
}
if (shape.size() != curShape.size()) {
throw std::runtime_error("Shape size mismatch for outputTensor " + name);
}
for (size_t i = 0; i < shape.size(); ++i) {
if (curShape[i] != -1 && shape[i] != curShape[i]) {
throw std::runtime_error("Shape mismatch for outputTensor " + name);
}
}
if (format != curOpFormat) {
throw std::runtime_error("Weight format mismatch for outputTensor " + name);
}
return outputTensors.at(name);
}
RawTensorDataPtr GoldenDataCheck(const std::string& name, const DataType dtype, const Shape& shape)
{
auto it = this->goldens.find(name);
if (it == this->goldens.end()) {
throw std::runtime_error("Golden tensor " + name + " not found.");
}
const DataType curDType = it->second->GetDataType();
const Shape& curShape = it->second->GetShape();
if (dtype != curDType) {
throw std::runtime_error("Data type mismatch for Golden " + name);
}
if (shape.size() != curShape.size()) {
throw std::runtime_error("Shape size mismatch for Golden " + name);
}
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != curShape[i]) {
throw std::runtime_error("Shape mismatch for Golden " + name);
}
}
return goldens.at(name);
}
void Dump()
{
int index = 0;
std::cout << "Input tensors: ";
PrintTableDivider();
std::cout << "| " << std::setw(COLUMN_WIDTH / 2) << std::left << "Index"
<< " | " << std::setw(COLUMN_WIDTH) << std::left << "Name"
<< " | " << std::setw(COLUMN_WIDTH) << std::left << "Shape"
<< " | " << std::setw(COLUMN_WIDTH / 2) << std::left << "Datatype"
<< " | " << std::setw(COLUMN_WIDTH / 2) << std::left << "OpFormat"
<< " |" << std::endl;
;
PrintTableDivider();
for (const auto& tensor : inputTensorList) {
PrintTableRow(index, tensor);
index++;
}
PrintTableDivider();
index = 0;
std::cout << "Output tensors: ";
PrintTableDivider();
std::cout << "| " << std::setw(COLUMN_WIDTH / 2) << std::left << "Index"
<< " | " << std::setw(COLUMN_WIDTH) << std::left << "Name"
<< " | " << std::setw(COLUMN_WIDTH) << std::left << "Shape"
<< " | " << std::setw(COLUMN_WIDTH / 2) << std::left << "Datatype"
<< " | " << std::setw(COLUMN_WIDTH / 2) << std::left << "OpFormat"
<< " |" << std::endl;
PrintTableDivider();
for (const auto& tensor : outputTensorList) {
PrintTableRow(index, tensor);
index++;
}
PrintTableDivider();
}
static std::pair<Tensor, RawTensorDataPtr> CreateTensor(
const std::string& name, const std::string& dtype, const Shape& shape, const std::string& fileName,
const TileOpFormat format = TileOpFormat::TILEOP_ND)
{
Tensor t(CostModel::ToDataType(const_cast<string&>(dtype)), shape, name, format);
auto dataPtr = CreateHelper(dtype, t, fileName);
return std::make_pair(t, dataPtr);
}
static RawTensorDataPtr CreateHelper(const std::string& dtype, const Tensor& tensor, const std::string& fileName)
{
static const std::unordered_map<std::string, std::function<RawTensorDataPtr(const Tensor&, const std::string&)>>
creators = {
{"DT_INT8", TestDataLoader::CreateTensorData<int8_t>},
{"DT_INT32", TestDataLoader::CreateTensorData<int32_t>},
{"DT_INT64", TestDataLoader::CreateTensorData<int64_t>},
{"DT_FP16", TestDataLoader::CreateTensorData<npu::tile_fwk::float16>},
{"DT_FP32", TestDataLoader::CreateTensorData<float>},
{"DT_BF16", TestDataLoader::CreateTensorData<npu::tile_fwk::bfloat16>},
};
auto it = creators.find(dtype);
if (it != creators.end()) {
return it->second(tensor, fileName);
}
std::cerr << "Unsupported type : " << dtype;
return nullptr;
}
template <typename T>
static RawTensorDataPtr CreateTensorData(const Tensor& tensor, const std::string& fileName);
Tensor& SetInputDynAxis(const std::string& name, const std::vector<int>& dynAxises)
{
auto& tensor = this->inputTensors.at(name);
Shape dynamicShape = tensor.GetShape();
for (int axis : dynAxises) {
ASSERT(axis >= 0 && (size_t)axis < dynamicShape.size());
dynamicShape[axis] = -1;
}
Tensor dynamicT(tensor.GetDataType(), dynamicShape, tensor.GetName(), tensor.GetStorage()->Format());
auto [it, is_inserted] = this->inputTensors.insert_or_assign(name, dynamicT);
ASSERT(is_inserted);
this->inputTensorList[this->inputNameToIdx.at(name)] = std::cref(it->second);
return it->second;
}
Tensor& SetOutputDynAxis(const std::string& name, const std::vector<SymbolicScalar>& dynShape)
{
auto& tensor = this->outputTensors.at(name);
ASSERT(tensor.GetShape().size() == dynShape.size());
Tensor dynamicT(tensor.GetDataType(), dynShape, tensor.GetName(), tensor.GetStorage()->Format());
auto [it, is_inserted] = this->outputTensors.insert_or_assign(name, dynamicT);
ASSERT(is_inserted);
this->outputTensorList[this->outputNameToIdx.at(name)] = std::cref(it->second);
return it->second;
}
private:
void LoadParameters()
{
if (!meta.contains("parameters") || !meta["parameters"].is_object()) {
throw std::runtime_error("Config file does not contain a valid 'parameters' object.");
}
const auto& paramsJson = meta["parameters"];
for (const auto& [key, value] : paramsJson.items()) {
params[key] = ConvertJsonToScalar(value);
}
}
void LoadInputTensors()
{
if (!meta.contains("inputs") || !meta["inputs"].is_object()) {
throw std::runtime_error("Config file does not contain a valid 'inputs' object.");
}
const auto& inputsJson = meta["inputs"];
size_t index = 0;
for (const auto& [tensorName, tensorConfig] : inputsJson.items()) {
if (!tensorConfig.is_object()) {
throw std::runtime_error("Invalid tensor configuration for tensor: " + tensorName);
}
std::string binFile = tensorConfig["bin_file"];
if (binFile.empty()) {
throw std::runtime_error("Missing 'bin_file' in tensor configuration: " + tensorName);
}
binFile = GetFullPath(binFile);
Shape shape;
if (tensorConfig.contains("shape") && tensorConfig["shape"].is_array()) {
shape = tensorConfig["shape"].get<Shape>();
} else {
throw std::runtime_error("Missing or invalid 'shape' in tensor configuration: " + tensorName);
}
TileOpFormat opFormat;
if (tensorConfig.contains("opFormat") && tensorConfig["opFormat"].is_number_integer()) {
int value = tensorConfig["opFormat"].get<int>();
if (IsValidWeightFormat(value)) {
opFormat = static_cast<TileOpFormat>(value);
} else {
throw std::runtime_error(
"Undefined 'opFormat' value in tensor configuration: '" + tensorName +
"': " + std::to_string(value));
}
} else {
throw std::runtime_error("Missing or invalid 'opFormat' in tensor configuration: " + tensorName);
}
std::string dtypeStr;
if (tensorConfig.contains("dtype") && tensorConfig["dtype"].is_string()) {
dtypeStr = tensorConfig["dtype"];
}
auto [tensor, dataPtr] = CreateTensor(tensorName, dtypeStr, shape, binFile, opFormat);
auto [it, is_inserted] = inputTensors.emplace(tensorName, tensor);
ASSERT(is_inserted);
inputTensorList.push_back(std::cref(it->second));
this->inputDataList.push_back(dataPtr);
this->inputNameToIdx.emplace(tensorName, index++);
}
}
void LoadGoldenTensors()
{
size_t index = 0;
for (auto& [name, value] : this->meta["golden_outputs"].items()) {
auto binFile = value["bin_file"].get<std::string>();
auto shape = value["shape"].get<Shape>();
auto dtype = value["dtype"].get<std::string>();
binFile = GetFullPath(binFile);
auto [tensor, dataPtr] = CreateTensor(name, dtype, shape, binFile);
(void)tensor;
this->goldens.emplace(name, dataPtr);
auto [output, outputData] = CreateTensor(name, dtype, shape, std::string());
auto [it, is_inserted] = this->outputTensors.emplace(name, output);
ASSERT(is_inserted);
this->outputTensorList.push_back(std::cref(it->second));
this->outputDataList.push_back(outputData);
this->outputNameToIdx.emplace(name, index++);
}
}
std::string GetFullPath(const std::string& relativePath)
{
size_t pos = this->configPath.find_last_of('/');
ASSERT(pos != std::string::npos);
std::string fullPath = this->configPath.substr(0, pos + 1) + relativePath;
return fullPath;
}
nlohmann::json meta;
std::string configPath;
std::unordered_map<std::string, Scalar> params;
std::unordered_map<std::string, Tensor> inputTensors;
std::unordered_map<std::string, Tensor> outputTensors;
std::unordered_map<std::string, RawTensorDataPtr> goldens;
std::unordered_map<std::string, int> inputNameToIdx;
std::vector<std::reference_wrapper<const Tensor>> inputTensorList;
std::vector<RawTensorDataPtr> inputDataList;
std::unordered_map<std::string, int> outputNameToIdx;
std::vector<std::reference_wrapper<const Tensor>> outputTensorList;
std::vector<RawTensorDataPtr> outputDataList;
};
/**
 * Creates RawTensorData with element type T for `tensor`.
 *
 * The buffer holds one zero-initialised element per entry of the shape's product. When `fileName`
 * is non-empty, readInput<T> then fills the buffer (presumably from that binary file).
 *
 * @throws std::invalid_argument if the shape has a negative dimension (e.g. a dynamic -1 axis),
 *         because no concrete buffer size can be derived from it.
 */
template <typename T>
RawTensorDataPtr TestDataLoader::CreateTensorData(const Tensor& tensor, const std::string& fileName)
{
    const auto& shape = tensor.GetShape();
    // Count elements in 64 bits: seeding std::accumulate with an `int` truncated large tensors.
    int64_t elementCount = 1;
    for (const auto dim : shape) {
        if (dim < 0) {
            throw std::invalid_argument(
                "Cannot create tensor data for a shape with negative dimension " +
                std::to_string(static_cast<long long>(dim)));
        }
        elementCount *= static_cast<int64_t>(dim);
    }
    std::vector<T> values(static_cast<size_t>(elementCount), 0);
    if (!fileName.empty()) {
        readInput<T>(fileName, values);
    }
    return RawTensorData::CreateTensor<T>(tensor, values);
}