import numpy as np
from ge.es import GraphBuilder
from ge.graph import Tensor, DumpFormat
from ge.graph.types import DataType, Format
from ge.ge_global import GeApi
from ge.session import Session
from ge.es.all import *
def build_transformer_graph():
    """Build the "MakeTransformerSubGraph" graph and return it.

    The graph takes three inputs and registers two outputs:
      * output 0: INT64 indices of a top-4 selection over the masked values;
      * output 1: FLOAT values gathered from the sigmoid result at those
        indices, passed through RealDiv and then offset by 2.5.

    Ops are created one per statement, in the order the statements appear;
    keep that order if node creation order matters to the builder.

    Returns:
        The object returned by ``GraphBuilder.build_and_reset()``.
    """
    gb = GraphBuilder("MakeTransformerSubGraph")
    first_in, second_in, third_in = gb.create_inputs(3)

    # sigmoid(reshape(first_in) x transpose(second_in)), operands cast to float.
    flattened = Reshape(first_in, [-1, 7168])
    lhs_f32 = Cast(flattened, dst_type=DataType.DT_FLOAT)
    rhs_f32 = Cast(second_in, dst_type=DataType.DT_FLOAT)
    swap_axes = gb.create_vector_int64([1, 0])
    rhs_t = Transpose(rhs_f32, swap_axes)
    product = MatMul(lhs_f32, rhs_t, transpose_x1=False, transpose_x2=False)
    activated = Sigmoid(product)

    # Add the third input (with a leading axis inserted) to the activations.
    activated_2d = Reshape(activated, [-1, 256])
    third_expanded = Unsqueeze(third_in, axes=[0])
    third_f32 = Cast(third_expanded, dst_type=DataType.DT_FLOAT)
    biased = activated_2d + third_f32

    # Sum the top-2 values along the last axis, then take a top-4 of the sums.
    top2 = TopKV2(biased, 2)
    top2_sum = ReduceSum(top2.values, [-1])
    top4_of_sum = TopKV2(top2_sum, 4, sorted=False)
    picked_idx = Cast(top4_of_sum.indices, dst_type=DataType.DT_INT64)

    # Scatter 1.0 into a zero tensor at the picked indices.
    zeros = ZerosLike(top2_sum)
    idx_shape = Shape(picked_idx)
    one = gb.create_scalar_float(1.0)
    one_f32 = Cast(one, dst_type=DataType.DT_FLOAT)
    ones = Fill(idx_shape, one_f32)
    selection = ScatterElements(zeros, picked_idx, ones)

    # Broadcast the selection to [256, 256].
    selection_col = Unsqueeze(selection, axes=[-1])
    selection_wide = BroadcastTo(selection_col, [256, 256])
    selection_copy = Identity(selection_wide)

    # Fill every position that was NOT selected with 0.0.
    mask_dims = gb.create_vector_int64([256, 256])
    selection_2d = Reshape(selection_copy, mask_dims)
    selected = Cast(selection_2d, dst_type=DataType.DT_BOOL)
    not_selected = LogicalNot(selected)
    zero = gb.create_scalar_float(0.0)
    masked = MaskedFill(biased, not_selected, zero)

    # Final top-4 indices over the masked values.
    final_top4 = TopKV2(masked, 4, sorted=False)
    final_idx = Cast(final_top4.indices, dst_type=DataType.DT_INT64)

    # Gather from the sigmoid output (not the biased tensor) at those indices.
    gathered = GatherElements(activated, final_idx, dim=1)
    # NOTE(review): the divisor is the bare literal 1e-20, i.e. values are
    # scaled by 1e20. Confirm this is intended and not an epsilon that was
    # meant to be added to a real denominator.
    scaled = RealDiv(gathered, 1e-20)

    gb.set_graph_output(final_idx, 0)
    offset = gb.create_scalar_float(2.5)
    shifted = Cast(scaled + offset, dst_type=DataType.DT_FLOAT)
    gb.set_graph_output(shifted, 1)
    return gb.build_and_reset()
def build_transformer_graph_and_dump(graph):
    """Dump ``graph`` to file using ``DumpFormat.kOnnx``.

    Despite its name, this function does not build anything: it only calls
    ``graph.dump_to_file`` with the suffix "make_transformer_graph".

    Args:
        graph: Graph object exposing ``dump_to_file``.
    """
    dump_kwargs = {
        "format": DumpFormat.kOnnx,
        "suffix": "make_transformer_graph",
    }
    graph.dump_to_file(**dump_kwargs)
def run_graph(graph) -> int:
    """Initialize GE, add ``graph`` to a session, run it once and print outputs.

    Three float32 inputs with shapes [32, 8, 7168], [256, 7168] and [256] are
    filled with random normal values and fed to the graph. Once GE has been
    initialized successfully, it is always finalized before returning.

    Args:
        graph: Graph object to add to the session and execute.

    Returns:
        int: 0 on success; the non-zero code returned by ``ge_initialize`` or
        ``add_graph`` if either fails; -1 if an exception is raised while
        setting up or running the graph.
    """
    ge_options = {"ge.exec.deviceId": "0", "ge.graphRunMode": "0"}
    api = GeApi()
    status = api.ge_initialize(ge_options)
    if status != 0:
        # Nothing to clean up: initialization itself failed.
        print(f"GE初始化失败,返回码: {status}")
        return status
    print("GE环境初始化成功 (Device ID: 0)")

    try:
        sess = Session()
        gid = 1
        status = sess.add_graph(gid, graph)
        if status != 0:
            print(f"添加图失败,返回码: {status}")
            return status
        print(f"图已添加到Session (Graph ID: {gid})")

        # Generate all random inputs first, then wrap them, in input order.
        shapes = ([32, 8, 7168], [256, 7168], [256])
        arrays = [np.random.randn(*dims).astype(np.float32) for dims in shapes]
        feeds = [
            Tensor(
                values.flatten().tolist(),
                None,
                DataType.DT_FLOAT,
                Format.FORMAT_ND,
                list(dims),
            )
            for values, dims in zip(arrays, shapes)
        ]

        outputs = sess.run_graph(gid, feeds)
        print("[Info] 图运行成功!")
        for position, out_tensor in enumerate(outputs, start=1):
            print(f"Tensor{position}详情:{out_tensor}")
    except Exception as e:
        # Top-level boundary: report the failure and map it to a status code.
        print(f"[Error] 执行过程中出错: {e}")
        import traceback
        traceback.print_exc()
        return -1
    else:
        return 0
    finally:
        # Runs on every exit path above, including the early returns.
        print("[Info] 清理GE环境...")
        api.ge_finalize()
        print("[Success] GE环境已清理")
# Script entry: build the graph, dump it to file, then compile and run it.
# NOTE(review): these statements execute at import time (there is no
# ``if __name__ == "__main__"`` guard) and the status code returned by
# run_graph() is discarded.
graph = build_transformer_graph()
build_transformer_graph_and_dump(graph)
run_graph(graph)