* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file function.cpp
* \brief Python bindings for Function class
*/
#include "pybind_common.h"
#include "interface/function/function.h"
#include "interface/program/program.h"
using namespace npu::tile_fwk;
namespace pypto {
void BindFunction(py::module_& m)
{
py::class_<Function, std::shared_ptr<Function>>(m, "Function")
.def("GetMagicName", &Function::GetMagicName, "Get the magic name of the function")
.def("GetRawName", &Function::GetRawName, "Get the raw name of the function")
.def("GetFuncMagic", &Function::GetFuncMagic, "Get the function magic number")
.def("Dump", &Function::Dump, "Dump the function in brief format")
.def("DumpSSA", &Function::DumpSSA, "Dump the function in SSA format")
.def("GetFunctionType", &Function::GetFunctionType, "Get the function type")
.def("GetFunctionTypeStr", &Function::GetFunctionTypeStr, "Get the function type as string")
.def("GetGraphType", &Function::GetGraphType, "Get the graph type")
.def("IsEager", &Function::IsEager, "Check if function is eager")
.def("IsStatic", &Function::IsStatic, "Check if function is static")
.def("IsExplicit", &Function::IsExplicit, "Check if function is explicit")
.def(
"IsFunctionType", py::overload_cast<FunctionType>(&Function::IsFunctionType, py::const_), py::arg("type"),
"Check if function is of given type")
.def(
"IsGraphType", py::overload_cast<GraphType>(&Function::IsGraphType, py::const_), py::arg("type"),
"Check if function has given graph type")
.def(
"IsFunctionTypeAndGraphType",
py::overload_cast<FunctionType, GraphType>(&Function::IsFunctionTypeAndGraphType, py::const_),
py::arg("func_type"), py::arg("graph_type"), "Check if function has given function type and graph type")
.def("HasParent", &Function::HasParent, "Check if function has a parent")
.def("GetRootFunction", &Function::GetRootFunction, py::return_value_policy::reference, "Get the root function")
.def(
"GetIncast",
[](const Function& self) -> std::vector<Tensor> {
return std::vector<Tensor>(self.GetIncast().begin(), self.GetIncast().end());
},
py::return_value_policy::reference_internal, "Get input casts")
.def(
"GetOutcast",
[](const Function& self) -> std::vector<Tensor> {
return std::vector<Tensor>(self.GetOutcast().begin(), self.GetOutcast().end());
},
py::return_value_policy::reference_internal, "Get output casts")
.def(
"GetOriginIncast",
[](const Function& self) -> std::vector<Tensor> {
return std::vector<Tensor>(self.GetOriginIncast().begin(), self.GetOriginIncast().end());
},
py::return_value_policy::reference_internal, "Get original input casts")
.def(
"GetOriginOutcast",
[](const Function& self) -> std::vector<Tensor> {
return std::vector<Tensor>(self.GetOriginOutcast().begin(), self.GetOriginOutcast().end());
},
py::return_value_policy::reference_internal, "Get original output casts")
.def("DumpFile", &Function::DumpFile, py::arg("file_path"), "Dump the function to a file")
.def(
"DumpJsonFile", [](Function& self, const std::string& fileName) { self.DumpJsonFile(fileName); },
py::arg("file_name") = "", "Dump the function to a JSON file")
.def("__repr__", [](const Function& self) {
return "<Function '" + self.GetRawName() + "' (magic: " + std::to_string(self.GetFuncMagic()) + ")>";
});
m.def(
"GetLastFunction", []() -> Function* { return Program::GetInstance().GetLastFunction(); },
py::return_value_policy::reference, "Get the last compiled function from the Program");
m.def(
"GetCurrentFunction", []() -> Function* { return Program::GetInstance().GetCurrentFunction(); },
py::return_value_policy::reference, "Get the current function being built in the Program");
}
}