* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <iostream>
#include <fstream>
#include <sstream>
#include "nlohmann/json.hpp"
#include <cstdlib>
#include "acl/acl_prof.h"
#include "graph/ge_error_codes.h"
#include "experiment/msprof/toolchain/prof_api.h"
#include "fused_graph_test.h"
static const int32_t SUCCESS = 0;
static const int32_t FAILED = -1;
using json = nlohmann::json;
using TilingFunc = ge::graphStatus(*)(void *, uint32_t *, uint32_t *, void *);
using LaunchFunc = uint32_t(*)(uint32_t, void *, void **, int32_t, void **, int32_t, void *, void *);
using GetTilingSizeFunc = size_t(*)();
std::map<std::string, ge::DataType> kDataTypeMap {
{"float32", ge::DT_FLOAT},
{"float16", ge::DT_FLOAT16},
{"int8", ge::DT_INT8},
{"int16", ge::DT_INT16},
{"int32", ge::DT_INT32},
{"int64", ge::DT_INT64},
{"uint8", ge::DT_UINT8},
{"uint16", ge::DT_UINT16},
{"uint32", ge::DT_UINT32},
{"uint64", ge::DT_UINT64},
{"double", ge::DT_DOUBLE},
{"bool", ge::DT_BOOL},
{"string", ge::DT_STRING}
};
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
int64_t shapeSize = 1;
for (auto i : shape) {
shapeSize *= i;
}
return shapeSize;
}
int EnvInit(int32_t deviceId, aclrtStream* stream) {
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
ret = aclrtSetDevice(deviceId);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
ret = aclrtCreateStream(stream);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return 0;
}
int32_t CreateDeviceData(const std::vector<int8_t>& host_data,
const std::vector<int64_t>& shape,
ge::DataType dtype, void** device_addr) {
auto size = GetShapeSize(shape) * ge::GetSizeByDataType(dtype);
auto ret = aclrtMalloc(device_addr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
ret = aclrtMemcpy(*device_addr, size, host_data.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
return SUCCESS;
}
int32_t CreateDeviceDataV2(const void* host_addr, const std::vector<int64_t>& shape, ge::DataType dtype, void** device_addr) {
auto size = GetShapeSize(shape) * ge::GetSizeByDataType(dtype);
auto ret = aclrtMalloc(device_addr, size, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);
ret = aclrtMemcpy(*device_addr, size, host_addr, size, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);
return SUCCESS;
}
void Finalize(int32_t deviceId, aclrtStream stream) {
aclrtDestroyStream(stream);
aclrtResetDevice(deviceId);
aclFinalize();
}
std::unique_ptr<int8_t[]> GenFromBin(const std::vector<int64_t>& shape, const std::string& file_path, ge::DataType data_type) {
int64_t data_len = GetShapeSize(shape) * ge::GetSizeByDataType(data_type);
std::ifstream bin_file(file_path.c_str(), std::ifstream::binary);
if (!bin_file.is_open()) {
LOG_PRINT("ERROR: open file %s failed.\n", file_path.c_str());
return nullptr;
}
bin_file.seekg(0, std::ifstream::end);
uint32_t file_size = bin_file.tellg();
bin_file.seekg(0, std::ifstream::beg);
if (data_len != file_size) {
bin_file.close();
LOG_PRINT("ERROR: data_len != file_size. \n");
return nullptr;
}
std::unique_ptr<int8_t[]> data_holder = MakeUnique<int8_t[]>(data_len);
bin_file.read(reinterpret_cast<char *>(data_holder.get()), data_len);
bin_file.close();
return data_holder;
}
int32_t SaveBinToFile(const std::string &file_path, void *data, size_t data_len) {
std::ofstream file(file_path, std::ios::out | std::ios::binary);
if (!file.is_open()) {
LOG_PRINT("ERROR: open file %s failed.\n", file_path.c_str());
return FAILED;
}
file.write(reinterpret_cast<char *>(data), data_len);
file.close();
return SUCCESS;
}
int32_t ReadConfigJson(const std::string &json_path, json &config_json) {
std::ifstream config_json_file(json_path);
if (!config_json_file.is_open()) {
LOG_PRINT("ERROR: open file %s failed.\n", json_path.c_str());
return FAILED;
}
try {
config_json_file >> config_json;
} catch (const nlohmann::json::exception &e) {
LOG_PRINT("ERROR: fail to read file %s, error msg: %s.\n", json_path.c_str(), e.what());
return FAILED;
}
return SUCCESS;
}
int32_t ParseMapConfig(const json &config_json, const std::string &key, std::map<std::string, std::string> &configs) {
if (config_json.contains(key)) {
try {
std::map<std::string, std::string> buff = config_json[key];
for (auto &it : buff) {
configs[it.first.c_str()] = it.second.c_str();
}
} catch (const nlohmann::json::exception &e) {
LOG_PRINT("ERROR: fail to parse %s from json %s, err msg: %s\n", key.c_str(), config_json[key].dump().c_str(), e.what());
return FAILED;
}
} else {
LOG_PRINT("ERROR: fail to parse %s from json %s\n", key.c_str(), config_json[key].dump().c_str());
return FAILED;
}
return SUCCESS;
}
std::vector<std::vector<int64_t>> ParseDataShape(const std::string &data_shape) {
std::vector<std::vector<int64_t>> result;
std::string temp;
std::istringstream ss(data_shape);
while (getline(ss, temp, ';')) {
std::vector<int64_t> inner_result;
std::istringstream inner_ss(temp);
while (getline(inner_ss, temp, ',')) {
inner_result.emplace_back(std::stol(temp));
}
result.emplace_back(inner_result);
}
return result;
}
std::vector<ge::DataType> ParseDataType(std::string &data_type) {
std::istringstream iss(data_type);
std::string token;
std::vector<ge::DataType> data_types;
while (getline(iss, token, ',')) {
token.erase(0, token.find_first_not_of(" "));
token.erase(token.find_last_not_of(" ") + 1);
if (kDataTypeMap.find(token) != kDataTypeMap.end()) {
data_types.emplace_back(kDataTypeMap[token]);
}
}
return data_types;
}
int32_t AutofuseKernelInfo::Init(void *stream, const std::string &config_path) {
stream_ = stream;
auto ret = InitKernelConfig(config_path);
if (ret != SUCCESS) {
LOG_PRINT("ERROR: InitKernelConfig fail\n.");
return FAILED;
}
(void)InitDeviceData();
return SUCCESS;
}
int32_t AutofuseKernelInfo::InitKernelConfig(const std::string &config_file_path) {
json config_json;
if (ReadConfigJson(config_file_path, config_json) != SUCCESS) {
LOG_PRINT("ERROR: Fail to read config json: %s\n.", config_file_path.c_str());
return FAILED;
}
const std::string kernel_config_key = "kernel_config";
std::map<std::string, std::string> json_kernel_config;
if (ParseMapConfig(config_json, kernel_config_key, json_kernel_config) != SUCCESS) {
LOG_PRINT("ERROR: Fail to parse config: %s\n.", kernel_config_key.c_str());
return FAILED;
}
for (auto &config : json_kernel_config) {
LOG_PRINT("kernel_config: %s : %s\n.", config.first.c_str(), config.second.c_str());
}
if (SetKernelConfig(json_kernel_config) != SUCCESS) {
LOG_PRINT("ERROR: Fail to set kernel config: %s\n.", kernel_config_key.c_str());
return FAILED;
}
return SUCCESS;
}
int32_t AutofuseKernelInfo::SetKernelConfig(std::map<std::string, std::string> &src_kernel_config) {
graph_name_snake_ = src_kernel_config["graph_name"];
kernel_config_.input_num = std::stoi(src_kernel_config["input_num"]);
kernel_config_.output_num = std::stoi(src_kernel_config["output_num"]);
kernel_config_.input_shape = ParseDataShape(src_kernel_config["input_shapes"]);
kernel_config_.output_shape = ParseDataShape(src_kernel_config["output_shapes"]);
kernel_config_.input_data_type = ParseDataType(src_kernel_config["input_data_types"]);
kernel_config_.output_data_type = ParseDataType(src_kernel_config["output_data_types"]);
if (kernel_config_.input_num != kernel_config_.input_shape.size()) {
LOG_PRINT("ERROR: input_num[%d] != input_shape size[%d]\n.", kernel_config_.input_num, kernel_config_.input_shape.size());
return FAILED;
}
return SUCCESS;
}
int32_t AutofuseKernelInfo::InitDeviceData() {
input_addr_ = MakeUnique<void *[]>(kernel_config_.input_num);
output_addr_ = MakeUnique<void *[]>(kernel_config_.output_num);
std::vector<std::unique_ptr<int8_t[]>> input_host_data;
for (int i = 0; i < kernel_config_.input_num; ++i) {
std::string file_path = "input/Data_" + std::to_string(i) + ".bin";
auto tmp_input = GenFromBin(kernel_config_.input_shape[i], file_path, kernel_config_.input_data_type[i]);
input_host_data.emplace_back(std::move(tmp_input));
}
std::vector<std::vector<int8_t>> output_host_data;
for (int i = 0; i < kernel_config_.output_num; ++i) {
int64_t data_size = GetShapeSize(kernel_config_.output_shape[i]) * ge::GetSizeByDataType(kernel_config_.output_data_type[i]);
std::vector<int8_t> tmp_host_data;
for (int64_t j = 0; j < data_size; ++j) {
tmp_host_data.emplace_back((int8_t)1);
}
output_host_data.emplace_back(tmp_host_data);
}
for (int i = 0; i < kernel_config_.input_num; ++i) {
auto ret = CreateDeviceDataV2((void *)input_host_data[i].get(), kernel_config_.input_shape[i], kernel_config_.input_data_type[i], &input_addr_[i]);
CHECK_RET(ret == SUCCESS, LOG_PRINT("Create input device data failed. input: %d, ERROR: %d\n", i, ret); return FAILED);
}
for (int i = 0; i < kernel_config_.output_num; ++i) {
auto ret = CreateDeviceData(output_host_data[i], kernel_config_.output_shape[i], kernel_config_.output_data_type[i], &output_addr_[i]);
CHECK_RET(ret == SUCCESS, LOG_PRINT("Create output device data failed. output: %d, ERROR: %d\n", i, ret); return FAILED);
}
return SUCCESS;
}
int32_t AutofuseKernelInfo::LoadSoHandles() {
std::string so_path = graph_name_snake_ + ".so";
LOG_PRINT("so path: %s\n", so_path.c_str());
char real_path[MMPA_MAX_PATH] = {};
auto ret = mmRealPath(so_path.c_str(), &real_path[0], MMPA_MAX_PATH);
if (ret != EN_OK) {
LOG_PRINT("mmRealPath failed, ERROR: %d, so path: %s\n", ret, so_path.c_str());
return FAILED;
}
handles_ = mmDlopen(real_path, static_cast<int32_t>(MMPA_RTLD_NOW));
if (handles_ == nullptr) {
LOG_PRINT("so %s open failed\n", so_path.c_str());
return FAILED;
}
LOG_PRINT("so %s open success, handles_:%p\n", so_path.c_str(), handles_);
return SUCCESS;
}
int32_t AutofuseKernelInfo::ParseTaskRunParam() {
std::string get_tiling_size_func_name = "GetTilingDataSize";
const auto get_tiling_size_func = reinterpret_cast<GetTilingSizeFunc>(mmDlsym(handles_, get_tiling_size_func_name.c_str()));
if (get_tiling_size_func == nullptr) {
LOG_PRINT("ERROR: get_tiling_size_func is nullptr, %s\n", get_tiling_size_func_name.c_str());
return FAILED;
}
tiling_size_ = get_tiling_size_func();
LOG_PRINT("name: %s, tiling_size_: %zu, func_name: %s\n", graph_name_snake_.c_str(), tiling_size_, get_tiling_size_func_name.c_str());
return SUCCESS;
}
int32_t AutofuseKernelInfo::DoTiling(std::unique_ptr<uint8_t[]> &tiling_data_holder, uint32_t &workspace_size) {
std::string tiling_func_name = "AutofuseTiling";
const auto tiling_func = reinterpret_cast<TilingFunc>(mmDlsym(handles_, tiling_func_name.c_str()));
if (tiling_func == nullptr) {
LOG_PRINT("ERROR: tiling_func is nullptr, %s\n", tiling_func_name.c_str());
return FAILED;
}
auto tiling_data = static_cast<void *>(tiling_data_holder.get());
std::vector<uint32_t> symbols_value;
auto ret = tiling_func(tiling_data, &workspace_size, &block_dim_, nullptr);
if (ret != 0) {
LOG_PRINT("ERROR: tiling_func failed, ret: %d\n", ret);
return FAILED;
}
return SUCCESS;
}
int32_t AutofuseKernelInfo::MallocWorkSpace(uint32_t &size) {
if (size == 0) {
LOG_PRINT("WARN: workspace size is %u\n", size);
return SUCCESS;
}
auto ret = aclrtMalloc(&workspace_, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (ret != ACL_SUCCESS) {
LOG_PRINT("ERROR: allocate workspace failed, ret: %d\n", ret);
return FAILED;
}
ret = aclrtMemset(workspace_, size, 0U, size);
if (ret != ACL_SUCCESS) {
LOG_PRINT("ERROR: memset workspace failed, ret: %d\n", ret);
(void)aclrtFree(workspace_);
return FAILED;
}
return SUCCESS;
}
void ReportApiInfo(const uint64_t beginTime) {
MsprofApi info{};
info.type = MSPROF_REPORT_NODE_LAUNCH_TYPE;
info.itemId = 0UL;
info.level = MSPROF_REPORT_NODE_LEVEL;
info.threadId = gettid();
info.beginTime = beginTime;
info.endTime = MsprofSysCycleTime();
info.magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM;
info.reserve = 0U;
const int32_t res = MsprofReportApi(true, &info);
}
void ReportNodeBasicInfo(const uint64_t beginTime, uint32_t block_dim, std::string graph_name) {
struct MsprofCompactInfo nodeInfo;
nodeInfo.level = MSPROF_REPORT_NODE_LEVEL;
nodeInfo.type = MSPROF_REPORT_NODE_BASIC_INFO_TYPE;
nodeInfo.threadId = gettid();
nodeInfo.timeStamp = beginTime;
nodeInfo.data.nodeBasicInfo.opName = MsprofGetHashId(graph_name.c_str(), graph_name.length());
nodeInfo.data.nodeBasicInfo.opType = MsprofGetHashId(graph_name.c_str(), graph_name.length());
nodeInfo.data.nodeBasicInfo.taskType = MSPROF_GE_TASK_TYPE_AIV;
nodeInfo.data.nodeBasicInfo.blockDim = block_dim;
MsprofReportCompactInfo(false, static_cast<void *>(&nodeInfo), sizeof(MsprofCompactInfo));
}
int32_t AutofuseKernelInfo::Distribute(uint64_t *tiling_data) {
const std::string launch_func_name = "AutofuseLaunchV2";
const auto launch_func = reinterpret_cast<LaunchFunc>(mmDlsym(handles_, launch_func_name.c_str()));
if (launch_func == nullptr) {
LOG_PRINT("ERROR: launch_func is nullptr, %s\n", launch_func_name.c_str());
return FAILED;
}
const uint64_t startTime = MsprofSysCycleTime();
auto ret = launch_func(block_dim_, stream_, input_addr_.get(), kernel_config_.input_num,
output_addr_.get(), kernel_config_.output_num, workspace_, reinterpret_cast<void*>(tiling_data));
CHECK_RET(ret == SUCCESS, LOG_PRINT("launch failed, ERROR: %d\n", ret); return FAILED);
ReportNodeBasicInfo(startTime, block_dim_, graph_name_snake_);
ReportApiInfo(startTime);
ret = aclrtSynchronizeStream(stream_);
CHECK_RET(ret == SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed, ERROR: %d\n", ret); return FAILED);
for (int i = 0; i < kernel_config_.output_num; ++i) {
auto output_size = GetShapeSize(kernel_config_.output_shape[i]) * ge::GetSizeByDataType(kernel_config_.output_data_type[i]);
std::vector<int8_t> result_data(output_size, 0);
ret = aclrtMemcpy(result_data.data(), result_data.size() * sizeof(result_data[0]),
output_addr_[i], output_size, ACL_MEMCPY_DEVICE_TO_HOST);
auto file_path = "out/Output_" + std::to_string(i) + ".bin";
SaveBinToFile(file_path, (void *)result_data.data(), output_size);
}
return SUCCESS;
}
int FuseGraphTest(int deviceId, aclrtStream &stream, const std::string config_path) {
auto ret = EnvInit(deviceId, &stream);
CHECK_FREE_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return FAILED);
AutofuseKernelInfo fuse_kernel;
fuse_kernel.Init(stream, config_path);
ret = fuse_kernel.LoadSoHandles();
CHECK_FREE_RET(ret == ACL_SUCCESS, LOG_PRINT("LoadSoHandles failed. ERROR: %d\n", ret); return FAILED);
ret = fuse_kernel.ParseTaskRunParam();
CHECK_FREE_RET(ret == ACL_SUCCESS, LOG_PRINT("ParseTaskRunParam failed. ERROR: %d\n", ret); return FAILED);
size_t tiling_size = fuse_kernel.GetTilingSize();
std::unique_ptr<uint8_t[]> tiling_data_holder = MakeUnique<uint8_t[]>(tiling_size);
CHECK_FREE_RET(tiling_data_holder != nullptr, LOG_PRINT("ERROR: tiling_data_holder is nullptr\n"); return FAILED);
uint32_t workspace_size = 0U;
ret = fuse_kernel.DoTiling(tiling_data_holder, workspace_size);
CHECK_FREE_RET(ret == SUCCESS, LOG_PRINT("ERROR: DoTiling failed.\n"); return FAILED);
LOG_PRINT("workspace_size: %u\n", workspace_size);
ret = fuse_kernel.MallocWorkSpace(workspace_size);
CHECK_FREE_RET(ret == SUCCESS, LOG_PRINT("ERROR: MallocWorkSpace failed.\n"); return FAILED);
uint64_t *tiling_data = reinterpret_cast<uint64_t *>(tiling_data_holder.get());
ret = fuse_kernel.Distribute(tiling_data);
CHECK_FREE_RET(ret == SUCCESS, LOG_PRINT("ERROR: Distribute failed.\n"); return FAILED);
LOG_PRINT("Distribute success.\n");
return ACL_SUCCESS;
}
int main(int argc, char *argv[]) {
if ((argc != 2) || (argv[1] == nullptr)) {
std::string cmd_demo = "./fused_graph_test /home/autofuse/asc_graph_test/config/fused_graph_kernel_config.json";
LOG_PRINT("ERROR: Please input config path. The cmd like: %s\n", cmd_demo.c_str());
return FAILED;
}
std::string kernel_config_path = argv[1];
LOG_PRINT("DEBUG: kernel_config_path: %s\n", kernel_config_path.c_str());
int32_t deviceId = 0;
const char* env_var = std::getenv("ASCEND_DEVICE_ID");
if (env_var != nullptr) {
try {
deviceId = std::stoi(env_var);
} catch (const std::exception &e) {
std::cerr << "ASCEND_DEVICE_ID : [" << env_var << "] is error" << std::endl;
}
}
LOG_PRINT("DEBUG: using device id : %d\n", deviceId);
aclrtStream stream;
auto ret = FuseGraphTest(deviceId, stream, kernel_config_path);
CHECK_FREE_RET(ret == ACL_SUCCESS, LOG_PRINT("FuseGraphTest failed. ERROR: %d\n", ret); return ret);
Finalize(deviceId, stream);
return SUCCESS;
}