import numpy as np
from ge.es.graph_builder import GraphBuilder, TensorHolder
from ge.graph import Tensor
from ge.graph.types import DataType, Format
from ge.graph import Graph, DumpFormat
from ge.ge_global import GeApi
from ge.session import Session
from ge.es.all import If
def build_if_graph():
    """Build the "MakeIfGraph" graph, which contains a single If operator.

    The main graph declares two inputs: a float tensor of shape [2] at
    index 0 and an int32 tensor with an empty shape at index 1, used as the
    condition. Two sub-graphs are built and handed to ``If`` as its branches;
    each takes one float input of shape [2] and outputs the sum of that input
    and an int64 constant (5 for the then branch, 2 for the else branch).

    Returns:
        Whatever ``build_and_reset()`` returns for the main builder.
    """
    main_builder = GraphBuilder("MakeIfGraph")

    # Inputs of the main graph: the data tensor and the branch condition.
    data_input = main_builder.create_input(
        index=0, name="test_input", data_type=DataType.DT_FLOAT, shape=[2]
    )
    condition = main_builder.create_input(
        index=1, name="cond", data_type=DataType.DT_INT32, shape=[]
    )

    # Then branch: constant 5 on the left-hand side of the addition.
    then_builder = GraphBuilder("then_branch")
    then_in = then_builder.create_input(
        index=0, name="then_input", data_type=DataType.DT_FLOAT, shape=[2]
    )
    then_const = then_builder.create_const_int64(5)
    then_out = then_const + then_in
    then_builder.set_graph_output(then_out, 0)
    then_graph = then_builder.build_and_reset()

    # Else branch: constant 2 on the right-hand side of the addition.
    # NOTE(review): the input is also named "then_input" here; this looks like
    # a copy-paste leftover, but confirm before renaming it.
    else_builder = GraphBuilder("else_branch")
    else_in = else_builder.create_input(
        index=0, name="then_input", data_type=DataType.DT_FLOAT, shape=[2]
    )
    else_const = else_builder.create_const_int64(2)
    else_out = else_in + else_const
    else_builder.set_graph_output(else_out, 0)
    else_graph = else_builder.build_and_reset()

    # The third argument (1) is presumably the number of If outputs - TODO
    # confirm against the If operator signature.
    # NOTE(review): the value returned by If is discarded and no output is set
    # on the main builder; verify that this is intended.
    If(condition, [data_input], 1, then_graph, else_graph)
    return main_builder.build_and_reset()
def dump_if_graph(graph):
    """Dump ``graph`` via ``dump_to_file`` in ONNX format.

    Args:
        graph: Object exposing ``dump_to_file(format=..., suffix=...)``.
    """
    dump_kwargs = {"format": DumpFormat.kOnnx, "suffix": "make_if_graph"}
    graph.dump_to_file(**dump_kwargs)
def run_graph(graph) -> int:
    """Initialize GE, run ``graph`` in a new Session and print its outputs.

    Args:
        graph: Graph object that is registered with ``Session.add_graph``.

    Returns:
        0 on success; the non-zero code returned by ``ge_initialize`` or
        ``add_graph`` when either of them fails; -1 when an exception is
        raised while the session is being used.
    """
    config = {
        "ge.exec.deviceId": "0",
        "ge.graphRunMode": "0",
    }
    ge_api = GeApi()
    ret = ge_api.ge_initialize(config)
    if ret != 0:
        # Initialization failed, so ge_finalize is not called on this path.
        print(f"GE初始化失败,返回码: {ret}")
        return ret
    print("GE环境初始化成功 (Device ID: 0)")

    try:
        session = Session()
        graph_id = 1
        ret = session.add_graph(graph_id, graph)
        if ret != 0:
            print(f"添加图失败,返回码: {ret}")
            return ret
        print(f"图已添加到Session (Graph ID: {graph_id})")

        # One tensor per graph input: the float data tensor, then the
        # int32 condition tensor.
        tensor1 = Tensor([1.0, 2.0], None, DataType.DT_FLOAT, Format.FORMAT_ND, [2])
        tensor_const = Tensor([1], None, DataType.DT_INT32, Format.FORMAT_ND, [])
        inputs = [tensor1, tensor_const]

        # The result of run_graph is iterated as a sequence of output
        # tensors, so it is kept separate from the status code ``ret``.
        outputs = session.run_graph(graph_id, inputs)
        print("[Info] 图运行成功!")
        for idx, tensor in enumerate(outputs, start=1):
            print(f"Tensor{idx}详情:{tensor}")
        return 0
    except Exception as e:
        # Sample-level boundary: report the failure and convert it into a
        # status code instead of propagating it.
        print(f"[Error] 执行过程中出错: {e}")
        import traceback
        traceback.print_exc()
        return -1
    finally:
        # Runs on every path after a successful ge_initialize, including the
        # early returns above.
        print("[Info] 清理GE环境...")
        ge_api.ge_finalize()
        print("[Success] GE环境已清理")
# Script entry: build the If graph, dump it, then execute it.
graph = build_if_graph()
# The graph is dumped before it is executed.
dump_if_graph(graph)
# The status code returned by run_graph is not used here.
run_graph(graph)