* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "include/Utils.h"
#include <sstream>
using namespace llvm;
namespace mlir {
namespace asc {
/// Extracts the op class portion of a TableGen def name.
///
/// The name is split at its last '_' and the text after it is returned.
/// When that trailing text is empty (no '_' present, or the name ends
/// with '_'), the original name is returned unchanged.
///
/// @param defName  Def name to inspect.
/// @return View into defName; no storage is allocated.
StringRef fetchOpClass(StringRef defName)
{
    const StringRef suffix = defName.rsplit('_').second;
    return suffix.empty() ? defName : suffix;
}
void fetchResults(const DagInit* resultsDag, std::vector<VirtualArg>& dest)
{
const auto* outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
assert(outsOp && outsOp->getDef()->getName() == "outs");
for (unsigned i = 0, e = resultsDag->getNumArgs(); i < e; ++i) {
VirtualArg result;
auto name = resultsDag->getArgNameStr(i);
if (name.empty()) {
result.name = "arg" + std::to_string(i);
} else {
result.name = name;
}
result.substitution = result.name;
const auto* init = dyn_cast<DefInit>(resultsDag->getArg(i));
assert(init && "argument must have defined types");
const auto* resultDef = init->getDef();
if (resultDef->isSubClassOf("Variadic")) {
result.cppType = "::std::vector< ::mlir::Type >";
} else {
result.cppType = "::mlir::Type";
}
dest.push_back(result);
}
}
void fetchArguments(const DagInit* argsDag, std::vector<VirtualArg>& dest)
{
const auto* insOp = dyn_cast<DefInit>(argsDag->getOperator());
assert(insOp && insOp->getDef()->getName() == "ins");
for (unsigned i = 0, e = argsDag->getNumArgs(); i < e; ++i) {
VirtualArg arg;
auto name = argsDag->getArgNameStr(i);
if (name.empty()) {
arg.name = "arg" + std::to_string(i);
} else {
arg.name = name;
}
arg.substitution = arg.name;
const auto* init = dyn_cast<DefInit>(argsDag->getArg(i));
assert(init && "argument must have defined types");
const auto* argDef = init->getDef();
if (argDef->isSubClassOf("TypeConstraint")) {
if (argDef->isSubClassOf("Variadic")) {
arg.cppType = "::std::vector< ::mlir::Value >";
} else {
arg.cppType = "::mlir::Value";
}
if (argDef->isSubClassOf("Optional")) {
arg.optional = true;
std::stringstream str;
str << arg.name << ".value_or(" << arg.cppType << "{})";
arg.substitution = str.str();
arg.cppType = "::std::optional< " + arg.cppType + " >";
arg.defaultValue = "py::none()";
}
} else if (argDef->isSubClassOf("AttrConstraint")) {
arg.cppType = argDef->getValueAsString("returnType");
if (argDef->getName() == "UnitAttr") {
arg.optional = true;
arg.substitution = arg.name;
arg.defaultValue = "false";
} else if (argDef->isSubClassOf("OptionalAttr")) {
arg.optional = true;
std::stringstream str;
str << arg.name << ".value_or(" << argDef->getValueAsString("storageType").str() << "{})";
arg.substitution = str.str();
arg.defaultValue = "py::none()";
}
}
dest.push_back(arg);
}
}
/// Strips a leading "<dialectName>_" from fullName.
///
/// @param fullName     Name that may carry the dialect prefix.
/// @param dialectName  Dialect name, without the trailing underscore.
/// @return View into fullName past the prefix when fullName begins with
///         dialectName immediately followed by '_'; otherwise fullName
///         unchanged.
StringRef removeDialectPrefix(StringRef fullName, StringRef dialectName)
{
    StringRef remainder = fullName;
    // Both pieces must match in sequence; on a partial match the original
    // view is returned, not the half-consumed one.
    if (remainder.consume_front(dialectName) && remainder.consume_front("_")) {
        return remainder;
    }
    return fullName;
}
/// Strips a leading ascCppNamespace from fullName.
///
/// @param fullName         Name that may begin with the namespace text.
/// @param ascCppNamespace  Exact prefix text to remove; no separator is
///                         added or expected beyond what it contains.
/// @return View into fullName past the prefix when fullName begins with
///         ascCppNamespace; otherwise fullName unchanged.
StringRef removeAscDialectNameSpace(StringRef fullName, StringRef ascCppNamespace)
{
    return fullName.starts_with(ascCppNamespace) ? fullName.drop_front(ascCppNamespace.size()) : fullName;
}
}
}