* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <fstream>
#include <regex>
#include <securec.h>
#include <unordered_set>
#include <sys/syscall.h>
#include <sys/mman.h>
#include <vector>
#include "common/checker.h"
#include "attribute_group/attr_group_symbolic_desc.h"
#include "attribute_group/attr_group_shape_env.h"
#include "code_printer.h"
#include "mmpa/mmpa_api.h"
#include "graph_metadef/graph/utils/file_utils.h"
#include "framework/common/framework_types_internal.h"
#include "mmpa/mmpa_api.h"
#include "guard_codegen.h"
#include "common/compile_profiling/ge_call_wrapper.h"
namespace ge {
namespace {
// Names given to the in-memory files (memfd) that hold the compiled guard library and its generated source.
const std::string kSymbolCheckSoName = "guard_check.so";
const std::string kSymbolCheckFileName = "guard_check.cc";
// Prefix prepended to every statement emitted into the generated function bodies.
// NOTE(review): the name suggests two spaces; confirm the width of the literal is what is intended.
const std::string k2Space = " ";
// Capacity of the buffers that hold the "/proc/self/fd/<fd>" paths.
constexpr uint32_t kMaxFileNameLen = 128U;
// Name of the graph attribute under which the compiled guard library bytes are stored.
constexpr char_t const *kGuardCheckSoDataResult = "_guard_check_so_data";
// Subtracted from MSG_LENGTH to obtain the chunk size used when logging the generated code.
// Presumably the space reserved for the log line header -- TODO confirm.
constexpr size_t kConstLogHeaderLen = 110U;
// Matches only non-empty strings built from [A-Za-z0-9./+-_] that contain no ".." sequence; used to vet the
// paths that are placed into the compile command.
const std::regex kSafePathRegex(R"(^(?!.*\.{2})[A-Za-z0-9./+\-_]+$)");
// For every entry of symbol_check_infos, emits one statement assigning the serialized expression to
// `check_result` and one statement appending `check_result` to `check_results`.
// Logs the number of processed entries and always returns SUCCESS.
Status GenCheckInfos(CodePrinter &printer, const std::vector<SymbolCheckInfo> &symbol_check_infos) {
  const std::string append_stmt = k2Space + "check_results.emplace_back(check_result);";
  for (const auto &check_info : symbol_check_infos) {
    const std::string expr_text(check_info.expr.Serialize().get());
    printer.AddLine(k2Space + "check_result = " + expr_text + ";");
    printer.AddLine(append_stmt);
  }
  GELOGI("GenCheckInfos, symbol_check_infos size: %zu", symbol_check_infos.size());
  return SUCCESS;
}
// Returns the fixed preamble of the generated guard source: the includes, the CompareEqual overloads, the
// macros that map serialized expression operators (Max, Min, Pow, ExpectEq, LogicAnd, ...) onto C++, the Mod
// overloads and the GraphInputsContext accessor struct.
//
// Differences from the previous preamble:
//  - GraphInputsContext::GetGraphInputTensor rejects data_index == num_inputs_ (it used to read one element
//    past the end of the tensor array) and a null tensor array.
//  - <cstdint>, <cstdlib>, <memory> and <vector> are included explicitly because the generated functions use
//    uint8_t, std::getenv/std::atoi, std::unique_ptr and std::vector.
std::string GetFileHeader() {
  return R"(
#include <cmath>
#include <unordered_map>
#include <type_traits>
#include <string>
#include <iostream>
#include <limits>
#include <algorithm>
#include <initializer_list>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <vector>
#include "exe_graph/runtime/tensor.h"
#include "securec.h"
using char_t = char;
namespace {
// integer comparison
template <typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
CompareEqual(T a, T b) {
return a == b;
}
// float comparison (absolute tolerance + relative tolerance)
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, bool>::type
CompareEqual(T a, T b) {
const T abs_diff = std::abs(a - b);
if (abs_diff <= std::numeric_limits<T>::epsilon()) {
return true;
}
const T rel_diff = abs_diff / std::max(std::abs(a), std::abs(b));
return rel_diff <= static_cast<T>(1e-6); // adjustable threshold
}
template <typename TA, typename TB,
typename std::enable_if<(std::is_floating_point<TA>::value || std::is_integral<TA>::value) &&
(std::is_floating_point<TB>::value || std::is_integral<TB>::value), bool>::type = 0>
bool CompareEqual(TA a, TB b) {
if (std::is_same<TA, TB>::value) {
return CompareEqual(a, b);
}
if ((std::is_floating_point<TA>::value) && (std::is_integral<TB>::value)) {
return CompareEqual(a, static_cast<TA>(b));
}
return CompareEqual(static_cast<TB>(a), b);
}
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
#define Pow(a, b) (std::pow(a, b))
#define Exp(a) (std::exp(a))
#define Log(a) (std::log(a))
#define Sqrt(a) (std::sqrt(a))
#define Ceiling(a) (std::ceil(a))
#define Floor(a) (std::floor(a))
#define Abs(a) (std::abs(a))
#define Rational(a, b) (static_cast<double>(a) / static_cast<double>(b))
#define ExpectEq(a, b) (CompareEqual(a, b))
#define ExpectNe(a, b) (!CompareEqual(a, b))
#define ExpectLe(a, b) (a <= b)
#define ExpectLt(a, b) (a < b)
#define True true
#define False false
#define LogicAnd(...) \
[&] { \
const bool values[] = {__VA_ARGS__}; \
return std::all_of(std::begin(values), std::end(values), [](bool x) { return x; }); \
}()
#define LogicOr(...) \
[&] { \
const bool values[] = {__VA_ARGS__}; \
return std::any_of(std::begin(values), std::end(values), [](bool x) { return x; }); \
}()
template <typename Ta, typename Tb>
auto Mod(Ta left, Tb right) -> decltype(left % right) {
return left % right;
}
// pecialized version for float (using std::fmod)
template <typename Ta, typename Tb>
auto Mod(Ta left, Tb right) ->
typename std::enable_if<std::is_floating_point<Ta>::value || std::is_floating_point<Tb>::value,
decltype(std::fmod(left, right))>::type {
return std::fmod(left, right);
}
struct GraphInputsContext {
GraphInputsContext(gert::Tensor **tensors, size_t num_tensors)
: input_tensors_(tensors), num_inputs_(num_tensors) {}
const gert::Tensor *GetGraphInputTensor(size_t data_index) const {
if ((input_tensors_ == nullptr) || (data_index >= num_inputs_)) {
return nullptr;
}
auto *tensor = reinterpret_cast<const gert::Tensor *>(input_tensors_[data_index]);
if (tensor == nullptr) {
return nullptr;
}
return tensor;
}
private:
gert::Tensor **input_tensors_{nullptr};
size_t num_inputs_{0U};
};
} // namespace
)";
}
// Escapes text so that it can be placed between double quotes in generated C++ source without terminating
// the literal or introducing an unintended escape sequence.
std::string EscapeForCppStringLiteral(const std::string &text) {
  std::string escaped;
  escaped.reserve(text.size());
  for (const char ch : text) {
    switch (ch) {
      case '\\':
        escaped += "\\\\";
        break;
      case '"':
        escaped += "\\\"";
        break;
      case '\n':
        escaped += "\\n";
        break;
      case '\r':
        escaped += "\\r";
        break;
      case '\t':
        escaped += "\\t";
        break;
      default:
        escaped += ch;
        break;
    }
  }
  return escaped;
}

// For every entry of symbol_check_infos, emits one `err_msgs.emplace_back("...")` statement whose message
// contains the entry's file, line, serialized expression and dfx_info.
// The message is escaped before being embedded, because file names, expressions and context text are
// free-form and a quote, backslash or line break would otherwise corrupt the generated source.
// Always returns SUCCESS.
Status GenErrorInfos(CodePrinter &printer, const std::vector<SymbolCheckInfo> &symbol_check_infos) {
  for (const auto &iter : symbol_check_infos) {
    const std::string message = std::string("[") + iter.file + ":" + std::to_string(iter.line) +
                                "] Check Symbol Check Expression: " + std::string(iter.expr.Serialize().get()) +
                                " Context Info: " + iter.dfx_info + " missed.";
    printer.AddLine(k2Space + "err_msgs.emplace_back(\"" + EscapeForCppStringLiteral(message) + "\");");
  }
  return SUCCESS;
}
// Emits the body of the generated GetErrorMsg() function: declares `err_msgs`, reserves room for one message
// per check info and per assert info, appends the messages (check infos first) and returns the vector.
// Fails if shape_env_attr is null or if emitting the messages fails.
Status GenGetErrorMsg(CodePrinter &printer, const ge::ShapeEnvAttr *shape_env_attr) {
  GE_ASSERT_NOTNULL(shape_env_attr);
  printer.AddLine(k2Space + "std::vector<std::string> err_msgs;");

  const auto &check_infos = shape_env_attr->GetAllSymbolCheckInfos();
  const auto &assert_infos = shape_env_attr->GetAllSymbolAssertInfos();
  const std::string reserve_stmt =
      k2Space + "err_msgs.reserve(" + std::to_string(check_infos.size() + assert_infos.size()) + ");";
  printer.AddLine(reserve_stmt);

  GE_ASSERT_SUCCESS(GenErrorInfos(printer, check_infos));
  GE_ASSERT_SUCCESS(GenErrorInfos(printer, assert_infos));

  printer.AddLine(k2Space + "return err_msgs;");
  return SUCCESS;
}
// Emits the body of the generated GetCheckResults() function:
//  1. one `const auto <symbol> = <source>;` definition per entry of GetAllSym2Src();
//  2. the `check_results` vector, reserved for one result per check info and per assert info;
//  3. the evaluation of every check info followed by every assert info;
//  4. the return of `check_results`.
// Fails if shape_env_attr is null or if emitting the checks fails.
Status GenGetCheckResults(CodePrinter &printer, ge::ShapeEnvAttr *shape_env_attr) {
  GE_ASSERT_NOTNULL(shape_env_attr);
  for (const auto &sym_and_src : shape_env_attr->GetAllSym2Src()) {
    const std::string symbol_name(sym_and_src.first.Serialize().get());
    const std::string definition = k2Space + "const auto " + symbol_name + " = " + sym_and_src.second->GetSourceStr() + ";";
    printer.AddLine(definition);
  }

  printer.AddLine(k2Space + "std::vector<uint8_t> check_results;");
  const auto total_guards =
      shape_env_attr->GetAllSymbolCheckInfos().size() + shape_env_attr->GetAllSymbolAssertInfos().size();
  printer.AddLine(k2Space + "check_results.reserve(" + std::to_string(total_guards) + ");");
  printer.AddLine(k2Space + "uint8_t check_result;");

  GE_ASSERT_SUCCESS(GenCheckInfos(printer, shape_env_attr->GetAllSymbolCheckInfos()));
  GE_ASSERT_SUCCESS(GenCheckInfos(printer, shape_env_attr->GetAllSymbolAssertInfos()));

  printer.AddLine(k2Space + "return check_results;");
  return SUCCESS;
}
// Emits the body of the generated `GuardCheckFunc`: it evaluates all guards through GetCheckResults() and
// returns false at the first guard whose result is 0, true otherwise. When ASCEND_GLOBAL_LOG_LEVEL is set to
// a value <= 1, the message of the failing guard is copied into `reason`.
//
// Differences from the previously generated body:
//  - the message is written through a "%s" format instead of being used as the format string itself, so a
//    '%' inside an expression or context text can no longer be interpreted as a conversion specifier;
//  - `reason` is only written when it is non-null and reason_size is non-zero (reason_size - 1 used to
//    wrap around for a zero size);
//  - the message index is bounds-checked against the message list.
// Always returns SUCCESS.
Status GenGuardCheckFunc(CodePrinter &printer) {
  const std::string check_code = R"( auto context = std::make_unique<GraphInputsContext>(tensors, num_tensors);
if (context == nullptr) {
return false;
}
auto check_results = GetCheckResults(context);
for (size_t i = 0UL; i < check_results.size(); i++) {
if (check_results[i] == 0U) {
const char *log_level = std::getenv("ASCEND_GLOBAL_LOG_LEVEL");
if ((reason != nullptr) && (reason_size > 0U) && (log_level != nullptr) && (std::atoi(log_level) <= 1)) {
auto err_msgs = GetErrorMsg();
if (i < err_msgs.size()) {
(void)snprintf_s(reason, reason_size, reason_size - 1, "%s", err_msgs[i].c_str());
reason[reason_size - 1] = '\0';
}
}
return false;
}
}
return true;)";
  printer.AddLine(check_code);
  return SUCCESS;
}
// Reads the ASCEND_OPP_PATH environment variable into opp_path and makes sure it ends with '/'.
// Fails when the variable is not set or when RealPath() cannot resolve its value; note that opp_path
// receives the raw environment value, not the resolved path.
Status GetOppPath(std::string &opp_path) {
  GELOGI("Enter get opp path schedule");
  const char_t *env_value = nullptr;
  MM_SYS_GET_ENV(MM_ENV_ASCEND_OPP_PATH, env_value);
  GE_ASSERT_TRUE(env_value != nullptr, "[Call][MM_SYS_GET_ENV] Failed to get env ASCEND_OPP_PATH.");

  opp_path.assign(env_value);
  // The resolved path is only used to validate the value.
  const std::string resolved_path = RealPath(opp_path.c_str());
  GE_ASSERT_TRUE(!resolved_path.empty(), "[Call][RealPath] File path %s is invalid.", opp_path.c_str());

  const bool ends_with_slash = (opp_path.back() == '/');
  if (!ends_with_slash) {
    opp_path.push_back('/');
  }
  GELOGI("Get opp path from env: %s", opp_path.c_str());
  return SUCCESS;
}
// Writes the generated code to the info log between a start and an end marker, split into chunks of
// (MSG_LENGTH - kConstLogHeaderLen) characters, one log record per chunk.
void DebugCodeGenResult(const CodePrinter &printer) {
  GELOGI("Guard Func Start:");
  const std::string generated_code = printer.GetOutputStr();
  const size_t chunk_len = MSG_LENGTH - kConstLogHeaderLen;
  for (size_t offset = 0U; offset < generated_code.size(); offset += chunk_len) {
    const std::string chunk = generated_code.substr(offset, chunk_len);
    GELOGI("\n%s", chunk.c_str());
  }
  GELOGI("Guard Func End");
}
// Returns true only when all three paths match kSafePathRegex. The paths are checked in the order
// so_path, file_path, ascend_include_path and the check stops at the first mismatch.
bool CheckPathInCmdIsValid(const std::string &so_path, const std::string &file_path,
                           const std::string &ascend_include_path) {
  if (!std::regex_match(so_path, kSafePathRegex)) {
    return false;
  }
  if (!std::regex_match(file_path, kSafePathRegex)) {
    return false;
  }
  return std::regex_match(ascend_include_path, kSafePathRegex);
}
// Compiles the generated guard source into a shared library and stores the library bytes on the graph under
// the kGuardCheckSoDataResult attribute.
//
// Both the source and the output library live in anonymous in-memory files created with memfd_create; g++ is
// pointed at them through their /proc/self/fd/<fd> paths, and the include directory is derived from
// ASCEND_OPP_PATH. All paths placed in the command line are validated with CheckPathInCmdIsValid before the
// command is executed.
//
// Returns GRAPH_SUCCESS on success. Fails when a memfd cannot be created, the source cannot be written, a
// path is invalid, the compiler exits with a non-zero status, the library cannot be read back, or the
// attribute cannot be set.
graphStatus CompileGuardCheckFunc(const CodePrinter &printer, const ComputeGraphPtr &graph) {
  // In-memory file holding the generated source.
  auto src_fd = static_cast<int32_t>(syscall(__NR_memfd_create, kSymbolCheckFileName.c_str(), 0));
  GE_ASSERT_TRUE(src_fd >= 0, "Failed to create memfd for %s.", kSymbolCheckFileName.c_str());
  // Registered immediately after creation so that no early return below can leak the descriptor.
  GE_MAKE_GUARD(close_src_fd, [&src_fd]() {
    mmClose(src_fd);
  });
  std::string output_str = printer.GetOutputStr();
  const auto written = mmWrite(src_fd, &output_str[0], output_str.size());
  GE_ASSERT_TRUE((written >= 0) && (static_cast<size_t>(written) == output_str.size()),
                 "Failed to write generated guard source, size: %zu.", output_str.size());
  (void)lseek(src_fd, 0, SEEK_SET);
  char_t file_path[kMaxFileNameLen] = {"\0"};
  auto ret = snprintf_s(file_path, kMaxFileNameLen, kMaxFileNameLen - 1U, "/proc/self/fd/%d", src_fd);
  GE_ASSERT_TRUE(ret >= 0, "snprintf_s failed, ret: %d", ret);

  // In-memory file receiving the compiled library.
  auto so_fd = static_cast<int32_t>(syscall(__NR_memfd_create, kSymbolCheckSoName.c_str(), 0));
  GE_ASSERT_TRUE(so_fd >= 0, "Failed to create memfd for %s.", kSymbolCheckSoName.c_str());
  GE_MAKE_GUARD(close_so_fd, [&so_fd]() {
    mmClose(so_fd);
  });
  char_t so_path[kMaxFileNameLen] = {"\0"};
  ret = snprintf_s(so_path, kMaxFileNameLen, kMaxFileNameLen - 1U, "/proc/self/fd/%d", so_fd);
  GE_ASSERT_TRUE(ret >= 0, "snprintf_s failed, ret: %d", ret);

  std::string opp_path;
  GE_ASSERT_SUCCESS(GetOppPath(opp_path), "Failed to get environment variable ASCEND_OPP_PATH, please set it.");
  std::string ascend_include = RealPath((opp_path + "/../include/").c_str());
  GE_ASSERT_TRUE(!ascend_include.empty(), "[Call][RealPath] ascend include path %s is invalid.",
                 ascend_include.c_str());

  const std::string so_path_dir(so_path);
  const std::string file_path_dir(file_path);
  // NOTE(review): "-L c_sec" adds a library search directory named c_sec; if linking against the c_sec
  // library was intended this should probably be "-lc_sec" -- confirm before changing.
  std::string command = "g++ -O2 -fstack-protector-all -shared -fPIC -Wl,-z,now -Wl,-z,noexecstack -s -o " +
                        so_path_dir + " -I " + ascend_include + " -L c_sec" + " -x c++ " + file_path_dir;
  GELOGI("To compile guard_check.so, command: %s", command.c_str());
  // The command is executed by a shell, so every path embedded in it must pass the whitelist first.
  GE_ASSERT_TRUE(CheckPathInCmdIsValid(so_path_dir, file_path_dir, ascend_include),
                 "CheckPathInCmdIsValid failed, so_path_dir = %s, file_path_dir = %s, ascend_include = %s.",
                 so_path_dir.c_str(), file_path_dir.c_str(), ascend_include.c_str());
  GE_TIMESTAMP_START(GuardCheckCompile);
  GE_ASSERT_TRUE(system(command.c_str()) == 0, "Failed to compile guard_check.so");
  GE_TIMESTAMP_END(GuardCheckCompile, "GuardCheckCompile");

  // Read the compiled library back; the stream closes itself when it goes out of scope.
  std::ifstream file(so_path, std::ios::binary | std::ios::ate);
  GE_ASSERT_TRUE(file.is_open());
  const std::streamsize file_size = file.tellg();
  // tellg() reports failure as -1, which must not be used as a buffer size.
  GE_ASSERT_TRUE(file_size >= 0, "Failed to get the size of compiled guard_check.so.");
  file.seekg(0, std::ios::beg);
  std::string so_data(static_cast<size_t>(file_size), '\0');
  file.read(&so_data[0], file_size);
  GE_ASSERT_TRUE(file.gcount() == file_size, "Failed to read compiled guard_check.so, expect %zu bytes.",
                 so_data.size());
  GE_ASSERT_TRUE(AttrUtils::SetStr(graph, kGuardCheckSoDataResult, so_data));
  return GRAPH_SUCCESS;
}
}
// Generates the guard-check source for the graph and compiles it.
// The generated translation unit consists of the fixed file header, an unnamed namespace containing
// GetErrorMsg() and GetCheckResults(), and the extern "C" entry point GuardCheckFunc(). When the log level
// of GE_MODULE_NAME is DLOG_INFO or lower, the generated code is dumped to the log before compiling.
// Fails if graph or its ShapeEnvAttr group is null, or if any generation or compilation step fails.
Status GuardCodegen::GuardFuncCodegenAndCompile(const ComputeGraphPtr &graph) const {
  GE_ASSERT_NOTNULL(graph);
  auto env_attr = graph->GetAttrsGroup<ShapeEnvAttr>();
  GE_ASSERT_NOTNULL(env_attr);

  CodePrinter printer;
  printer.AddLine(GetFileHeader());
  printer.AddLine("using char_t = char;");

  // Helper functions, kept in an unnamed namespace of the generated source.
  printer.AddNamespaceBegin("");
  {
    printer.DefineFuncBegin("std::vector<std::string>", "GetErrorMsg", "");
    GE_ASSERT_SUCCESS(GenGetErrorMsg(printer, env_attr));
    printer.DefineFuncEnd();
  }
  {
    printer.DefineFuncBegin("std::vector<uint8_t>", "GetCheckResults",
                            "const std::unique_ptr<GraphInputsContext> &context");
    GE_ASSERT_SUCCESS(GenGetCheckResults(printer, env_attr));
    printer.DefineFuncEnd();
  }
  printer.AddNamespaceEnd("");

  // Exported entry point.
  printer.AddLine("extern \"C\"");
  printer.DefineFuncBegin("bool", "GuardCheckFunc",
                          "gert::Tensor **tensors, size_t num_tensors, char_t *reason, size_t reason_size");
  GE_ASSERT_SUCCESS(GenGuardCheckFunc(printer));
  printer.DefineFuncEnd();

  const bool info_log_enabled = (dlog_getlevel(GE_MODULE_NAME, nullptr) <= DLOG_INFO);
  if (info_log_enabled) {
    DebugCodeGenResult(printer);
  }

  GE_ASSERT_SUCCESS(CompileGuardCheckFunc(printer, graph));
  return GRAPH_SUCCESS;
}
}