GE-PY Python 模块类关系文档
概述
GE-PY 是 GraphEngine 的 Python 接口模块,提供了 Pythonic 的图相关接口。为用户提供了便捷的图构建和操作,编译执行等功能。该模块对外头文件位于 api/python/ge/ge/ 目录下。
目录结构
graph模块
├── __init__.py # 模块初始化文件
├── graph.py # Graph 类定义
├── node.py # Node 类定义
├── types.py # 数据类型定义
├── tensor.py # Tensor 类定义
├── tensor_desc.py # Shape / TensorDesc 类定义
├── _attr.py # 内部属性值类定义
└── _numeric.py # 内部数值转换类定义
注:下划线开头的为 Python 风格下的对内模块
graph核心类关系图
graph TB
subgraph "Python API Layer"
Graph[Graph<br/>图类]
Node[Node<br/>节点类]
Tensor[Tensor<br/>张量类]
Shape[Shape<br/>形状类]
TensorDesc[TensorDesc<br/>张量元信息类]
DataType[DataType<br/>数据类型枚举]
Format[Format<br/>格式枚举]
Placement[Placement<br/>数据存储位置枚举]
AttrValue[_AttrValue<br/>属性值类]
end
subgraph "C API Wrapper Layer"
GraphLib[graph<br/>C库包装器]
ESBLib[esb_lib<br/>基础库包装器]
PyGraphWrapper[pygraph_wrapper<br/>Python C API包装]
PyESWrapper[pyes_graph_builder_wrapper<br/>Python C API包装]
end
subgraph "C++ Backend"
CGraph[ge::Graph<br/>C++图对象]
CGNode[ge::GNode<br/>C++节点对象]
CAttrValue[ge::AttrValue<br/>C++属性值对象]
CTensor[ge::EsCTensor<br/>C++Tenor对象]
CTensorDesc[ge::TensorDesc<br/>C++张量元信息对象]
end
%% Python层关系
Graph -->|"包含多个"| Node
Graph -->|"使用"| DataType
Graph -->|"使用"| Format
Graph -->|"使用"| AttrValue
Tensor -->|"包含"| DataType
Tensor -->|"包含"| Format
Tensor -->|"包含"| Placement
Tensor -->|"获取"| TensorDesc
TensorDesc -->|"包含"| Shape
TensorDesc -->|"包含"| DataType
TensorDesc -->|"包含"| Format
Node -->|"使用"| AttrValue
Node -->|"获取/更新输入输出描述"| TensorDesc
%% Python到C API
Graph -.->|"通过"| GraphLib
Node -.->|"通过"| GraphLib
AttrValue -.->|"通过"| GraphLib
Tensor -.->|"通过"| GraphLib
Tensor -.->|"通过"| ESBLib
GraphLib -->|"调用"| PyGraphWrapper
ESBLib -->|"调用"| PyESWrapper
%% C API到C++
PyGraphWrapper -->|"转换为"| CGraph
PyGraphWrapper -->|"转换为"| CGNode
PyGraphWrapper -->|"转换为"| CAttrValue
PyGraphWrapper -->|"转换为"| CTensor
PyGraphWrapper -->|"转换为"| CTensorDesc
PyESWrapper -->|"转换为"| CTensor
类详细说明
1. Graph 类
文件位置: graph.py
功能: 图操作的主要接口类
主要方法:
__init__(name)- 初始化图get_all_nodes()- 获取所有节点get_direct_node()- 获取直接连接节点find_node_by_name(name)- 根据名称获取节点get_attr(key)- 获取图属性set_attr(key, value)- 设置图属性remove_node(node)- 移除节点remove_edge(src_node, src_port_index, dst_node, dst_port_index)- 移除边add_data_edge(src_node, src_port_index, dst_node, dst_port_index)- 添加数据边add_control_edge(src_node, dst_node)- 添加控制边save_to_air(file_path)- 将图保存成AIR文件load_from_air(file_path)- 从AIR文件加载图get_all_subgraphs()- 获取所有子图get_subgraph(name)- 根据名称获取子图add_subgraph(subgraph)- 添加子图,以子图的名称为key,不允许出现重复。若添加名称相同的子图,添加子图失败remove_subgraph(name)- 根据名称移除子图
属性:
_handle- 底层C图对象的句柄_owns_handle- 是否拥有句柄的所有权_owner- 句柄所有者_name- 图名称
关系:
- 通过
graph_lib调用底层C API - 管理多个
Node对象
2. Node 类
文件位置: node.py
功能: 图节点操作接口类
主要方法:
get_attr(key)- 获取节点属性(可返回 string / number / list /Tensor等 Python 值)set_attr(key, value)- 设置节点属性get_in_data_nodes_and_port_indexes(in_index)- 获取输入节点和端口get_out_data_nodes_and_port_indexes(out_index)- 获取输出节点和端口get_inputs_size()- 获取输入数量get_outputs_size()- 获取输出数量has_attr(key)- 是否含有节点属性get_input_desc(index)- 获取第index个输入的TensorDescupdate_input_desc(index, tensor_desc)- 更新第index个输入的TensorDescget_output_desc(index)- 获取第index个输出的TensorDescupdate_output_desc(index, tensor_desc)- 更新第index个输出的TensorDesc
属性:
_handle- 底层C节点对象的句柄_owns_handle- 是否拥有句柄的所有权name- 节点名称(只读属性)type- 节点类型(只读属性)
关系:
- 通过
graph_lib调用底层C API - 与
Graph对象关联
3. DataType 枚举
文件位置: types.py
功能: 定义支持的数据类型
关系:
- 与 C++ 中的
ge::DataType对应 - 在
Graph和Node操作中使用
4. Format 枚举
文件位置: types.py
功能: 定义张量格式
关系:
- 与 C++ 中的
ge::Format对应 - 用于张量形状和格式描述
5. Placement 枚举
文件位置: types.py
功能: 定义 Tensor 数据的存储位置
关系:
- 与 C++ 中的
ge::Placement对应 - 用于描述数据存放的存储位置
依赖关系
-
内部依赖:
- Graph库
ge._capi.pygraph_wrapper- C API包装器
-
外部依赖:
- ctypes库
6. Tensor 类
文件位置: tensor.py
功能: 张量数据类
主要方法:
set_format(format)- 设置格式get_format()- 获取格式set_data_type(data_type)- 设置数据类型get_data_type()- 获取数据类型get_tensor_desc()- 获取张量元信息描述get_shape()- 获取形状get_data()- 获取数据get_placement()- 获取数据所在存储位置to_device()- 将当前 Tensor 从 Host 移动到 Deviceto_host()- 将当前 Tensor 从 Device 移动到 Host
属性:
_handle- 底层C节点对象的句柄_owns_handle- 是否拥有句柄的所有权_owner- 句柄所有者 关系:- 通过
graph_lib和esb_lib调用底层C API - 与
Session对象关联
7. TensorDesc 类
文件位置: tensor_desc.py
功能: 张量元信息描述类,用于描述 shape、format、data type 以及 origin shape/origin format。
主要方法:
__init__(shape=None, format=Format.FORMAT_ND, data_type=DataType.DT_FLOAT)- 创建 TensorDesc;shape=None表示标量get_shape()/set_shape(shape)- 获取或设置 shapeget_origin_shape()/set_origin_shape(shape)- 获取或设置 origin shapeget_format()/set_format(format)- 获取或设置 formatget_origin_format()/set_origin_format(format)- 获取或设置 origin formatget_data_type()/set_data_type(data_type)- 获取或设置 data type
属性:
shape- 张量形状origin_shape- 原始张量形状format- 张量存储格式origin_format- 原始张量格式data_type- 张量数据类型
关系:
- 通过
graph_lib调用底层 C API - 与
Tensor和Node对象关联
8. Shape 类
文件位置: tensor_desc.py
功能: 张量形状类,继承自 Python list,保持普通列表的比较、遍历和索引行为,同时提供形状相关辅助方法。
主要方法:
get_shape_size()- 获取形状元素总数;空 shape 返回0,包含未知维度-1或-2时返回-1is_unknown_shape()- 判断是否包含未知维度
关系:
- 用于描述张量形状
utils 模块
目录结构
├── utils/
│ ├── __init__.py # 导出 GeUtils
│ └── ge_utils.py # GeUtils 公共工具接口
类详细说明
1. GeUtils 类
文件位置: utils/ge_utils.py
功能: GE 公共工具接口,面向 Graph / Node 对象提供 Shape 推导与节点 AICore 支持性校验能力。
主要方法:
infer_shape(graph, input_shapes)- 给定输入 shape, 对传入的 graph 做全图 shape 推导;本接口只做shape推导,不对图做任何其他优化(如常量折叠、死边消除等)check_node_support_on_aicore(node)- 校验指定 node 是否支持在 AICore 上执行
关系:
- 通过
ge_utils_lib调用底层 C API
allocator 模块
目录结构
allocator/
├── __init__.py # 模块初始化文件
└── allocator.py # Allocator、MemBlock 定义
类详细说明
1. MemBlock 类
文件位置: allocator.py
功能: 描述由 allocator 管理的一段 Device 内存。
主要属性:
addr- Device 侧地址size- 内存大小(字节)
2. Allocator 类
文件位置: allocator.py
功能: 内存分配器抽象基类
主要方法:
malloc(size)- 申请一段 Device 内存,返回MemBlockfree(block)- 释放malloc()返回的MemBlock
关系:
- 由
Session.register_external_allocator()注册到指定 stream,在Session.run_graph_with_stream_async()时使用该 allocator
ge_global模块
目录结构
├── __init__.py # 模块初始化文件
└── geapi.py # GeApi接口文件
类详细说明
1. Geapi 类
文件位置: geapi.py
功能:提供 GE 初始化和析构
主要方法:
-
ge_initialize(config)- GE初始化 -
ge_finalize()- GE析构关系:
-
通过
geapi_lib调用底层C API
使用示例:
from ge.ge_global import GeApi
ge_api = GeApi()
# 调用GE初始化函数
config = {"ge.exec.deviceId":"2", "ge.graphRunMode":"0"}
ge_api.ge_initialize(config)
# 调用GE资源释放函数
ge_api.ge_finalize()
offline_compile模块
目录结构
├── __init__.py # 模块初始化文件
└── offline_compile.py # 离线图编译接口文件
接口说明
1. offline_compile 模块
文件位置: offline_compile.py
功能:离线图编译接口
主要接口:
build_initialize(global_options)- 模型构建初始化,用于申请资源build_finalize()- 系统完成模型构建后,通过该接口释放资源build_model(graph, build_options)- 将输入的Graph编译为适配AI处理器的离线模型,并保存到内存缓冲区save_model(output_file, model)- 将离线模型序列化并保存到指定文件中bundle_build_model(graph_with_options)- 将输入的一组Graph编译为适配AI处理器的离线模型,并保存到内存缓冲区,该接口适用于权重更新场景bundle_save_model(output_file, model)- 将离线模型序列化并保存到指定文件中,该接口适用于权重更新场景
辅助类型:
ModelBuffer- 内存缓冲区中的序列化模型数据,持有底层C模型对象的句柄GraphWithOptions- bundle 编译时的图和编译选项对
关系:
- 通过
offline_compile_lib调用底层C API - 输入依赖
Graph对象
使用示例:
from ge.offline_compile import build_initialize, build_finalize, build_model, save_model
from ge.graph import Graph
# 创建Graph
graph = Graph("test_graph")
# 初始化模型构建
build_initialize({"ge.socVersion": "Ascend910B1"})
# 编译模型
model = build_model(graph, {"input_format": "ND"})
# 保存模型
save_model("sample", model)
# 释放模型构建资源
build_finalize()
Session 模块
目录结构
├── __init__.py # 模块初始化文件
└── session.py # session接口文件
类详细说明
1. Session 类
文件位置: session.py
功能: 图编译执行操作接口类
主要方法:
__init__()- 初始化sessionadd_graph(graph_id, add_graph, options)- 添加图remove_graph(graph_id)- 移除图run_graph(graph_id, inputs)- 运行图register_external_allocator(stream, allocator)- 为指定 stream 注册外置 allocatorunregister_external_allocator(stream)- 注销指定 stream 的外置 allocatorrun_graph_with_stream_async(graph_id, stream, inputs)- 在指定 stream 上异步执行图
属性:
-
_handle- 底层C节点对象的句柄 -
_owns_handle- 是否拥有句柄的所有权关系:
-
通过
session_lib调用底层C API 使用示例:
from ge.session import Session
from ge.ge_global import GeApi
from ge.graph import Graph
from ge.graph import Tensor
from ge.graph.types import DataType, Format
# 调用GE初始化函数
config = {"ge.exec.deviceId":"2", "ge.graphRunMode":"0"}
GeApi.ge_initialize(config)
# 创建session
session = Session()
# 创建Graph
graph = Graph("test_graph")
# 设置Graph_id
graph_id = 0
# 添加Graph
session.add_graph(graph_id,graph)
# 创建input_tensor_list
tensor = Tensor([1, 2, 3, 4, 5], None, [1,2,3], DataType.DT_INT8, Format.FORMAT_ND)
input_tensor_list = []
input_tensor_list.append(tensor)
# 运行graph
output_tensor_list = session.run_graph(graph_id,input_tensor_list)
# 调用GE资源释放函数
GeApi.ge_finalize()
passes 模块
目录结构
├── __init__.py # 模块初始化,导出公共 API
├── base.py # Pass 基类定义(FusionBasePass、PatternFusionPass、DecomposePass 等)
├── pattern.py # Pattern / NodeIo 等模式匹配辅助接口
├── replacement.py # replacement graph 构建辅助接口
├── registry.py # Pass 注册中心与装饰器
├── bootstrap.py # 插件发现与加载
└── _bridge.py # Bridge 运行时辅助(Pass 实例管理,供 C++ bridge .so 回调)
注:下划线开头的为 Python 风格下的对内模块
注:PassContext、MatchResult、Pattern、PatternMatcherConfig 等对象由 _ge_pass_native.so 提供 native-backed 实现,base.py / pattern.py 负责对外导出与少量 Python 辅助封装。
运行时 native artifact 选择
_ge_pass_native.so 与 libge_python_pass_bridge.so 作为同一套 artifact set 成套发布,目录固定为:
ge/passes/python_pass_artifacts/<python_tag>-<platform>/manifest.json
ge/passes/python_pass_artifacts/<python_tag>-<platform>/_ge_pass_native.so
ge/passes/python_pass_artifacts/<python_tag>-<platform>/libge_python_pass_bridge.so
主 wheel 保持一份纯 Python 接口,不再内置当前 Python 的默认 native artifact set。native 子 wheel 按 cp39 到 cp314 的 Python minor 版本矩阵分别承载预制 artifact set,并额外提供 ge/passes/_ge_pass_native.so 兼容副本,用于默认 bridge 路径下的 import ge.passes._ge_pass_native。native 子 wheel 通过标准 bdist_wheel 生成。仓内提供矩阵 builder 入口用于自动嗅探 PATH 中可用的 Python minor 版本并分别构建;如果某个 Python 可执行文件存在但开发头文件或 libpython 不完整,builder 会跳过该版本并继续构建其他可用版本。run 包可携带多个 ge_py_pass_bridge native 子 wheel,但安装脚本只应安装与当前执行安装脚本的 Python 解释器兼容的一个子 wheel;推荐使用 pip install --no-index --find-links <ge-compiler/lib64> <ge_py wheel> ge-py-pass-bridge,由 pip 按 wheel tag 自动选择。运行时优先选择与当前进程 Python tag、平台 tag、bridge ABI 匹配的预制 artifact;若没有命中,则回退到同目录 legacy bridge 路径。runtime fallback codegen 作为后续独立阶段接入。
类详细说明
1. PassStage 枚举
文件位置: base.py
功能: 定义 Pass 执行阶段
枚举值:
BEFORE_INFER_SHAPE- 在 InferShape 之前执行AFTER_INFER_SHAPE- 在 InferShape 之后执行AFTER_ASSIGN_LOGIC_STREAM- 在逻辑流分配之后执行AFTER_BUILTIN_FUSION_PASS- 在内置融合 Pass 之后执行AFTER_ORIGIN_GRAPH_OPTIMIZE- 在原始图优化之后执行
2. PassContext native-backed wrapper
文件位置: base.py
功能: Python 侧的 Pass 上下文视图
主要方法:
get_pass_name()- 获取 Pass 名称set_pass_name(pass_name)- 设置 Pass 名称get_option_value(option_key)- 获取编译选项get_error_message()- 获取错误信息set_error_message(error_message)- 设置错误信息
3. MatchResult native-backed wrapper
文件位置: base.py
功能: 模式匹配结果
主要方法:
get_matched_nodes()- 获取当前匹配命中的节点列表get_captured_tensor(capture_index)- 获取指定 capture 的NodeIoget_pattern_graph_name()- 获取 pattern graph 名称__str__()- 返回可读字符串表示
4. SubgraphRewriter native-backed wrappers
文件位置: graph_rewriter_binding.cc
功能: Python 侧的子图边界描述与子图替换接口,用于支持 graph base 类 pass 的“子图替换”能力。
主要类/方法:
SubgraphInput- 描述一个 subgraph 输入(一个输入可对应多个边界上的 node input)SubgraphInput() / SubgraphInput([(node, out_index), ...])- 构造 subgraph 输入add_input(node, out_index)- 追加一个输入锚点(node为ge.graph.Node,out_index为其输出 index)
SubgraphOutput- 描述一个 subgraph 输出SubgraphOutput() / SubgraphOutput(node, out_index)- 构造 subgraph 输出set_output(node, out_index)- 设置输出锚点
SubgraphBoundary- 描述待替换子图的输入/输出边界add_input(index, input)- 绑定第index个 boundary input 到SubgraphInputadd_output(index, output)- 绑定第index个 boundary output 到SubgraphOutput
SubgraphRewriter.replace(boundary, replacement)- 执行子图替换boundary:SubgraphBoundaryreplacement:ge.graph.Graph(replacement 图会在 C++ 侧拷贝并完成重连)
5. Pattern / NodeIo / PatternMatcherConfig
文件位置: pattern.py、base.py
功能:
Pattern- native-backed pattern wrapper,负责持有 pattern graph 与 capture 信息NodeIo- Python 侧描述节点输出位置的轻量 helperPatternMatcherConfig/PatternMatcherConfigBuilder- 模式匹配配置对象与 builder
主要接口:
Pattern(graph)- 从ge.graph.Graph构造 patternPattern.capture_tensor(source, index=0)- 记录 capture tensorPattern.get_captured_tensors()- 获取 capture 列表create_pattern(graph)- 显式构造PatternPatternMatcherConfigBuilder.enable_const_value_match()- 打开常量值匹配PatternMatcherConfigBuilder.enable_ir_attr_match()- 打开 IR 属性匹配PatternMatcherConfigBuilder.build()- 生成配置对象
6. FusionBasePass 类
文件位置: base.py
功能: 基础融合 Pass 基类,直接操作图结构
主要方法:
run(graph, context)- 执行 Pass,接收图对象和PassContext,返回None/bool/int状态值
关系:
PatternFusionPass和DecomposePass的父类- 通过
register_fusion_pass装饰器注册到全局 Pass 注册中心
7. PatternFusionPass 类
文件位置: base.py
功能: 基于模式匹配的融合 Pass 基类
主要方法:
patterns()- 定义匹配模式,返回模式列表meet_requirements(match_result)- 判断匹配结果是否满足融合条件,默认返回 Truereplacement(match_result)- 根据匹配结果生成替换子图,必须返回Graph
可选构造参数:
matcher_config-PatternMatcherConfig,用于控制常量值匹配、IR 属性匹配等 matcher 选项
设计约束:
- 不支持用户自定义
run()方法:PatternFusionPass复用 C++ 的Run()实现来执行标准的 pattern-match-replacement 流程。Python 侧只需实现patterns()、meet_requirements()和replacement()三个 hook 即可。 - 若子类覆写
run()会在类定义阶段直接抛出TypeError:避免用户误以为run()会在PatternFusionPass路径中被调用。 - 不支持在
replacement()中返回None表示跳过:若希望放弃当前匹配,需在meet_requirements()中返回False。 - 需要完全自定义
run()逻辑的场景:请直接使用FusionBasePass基类。
关系:
- 继承自
FusionBasePass - 通过
register_fusion_pass装饰器注册
8. DecomposePass 类
文件位置: base.py
功能: 算子分解 Pass 基类
类属性:
op_types- 需要分解的算子类型列表
主要方法:
meet_requirements(node)- 判断节点是否满足分解条件,默认返回 Truereplacement(node)- 将节点分解为多个子节点,必须返回Graph
设计约束:
- 不支持用户自定义
run()方法:DecomposePass复用 C++ 的Run()实现来执行标准的 node-filter-replacement 流程。Python 侧只需实现meet_requirements()和replacement()两个 hook 即可。 - 若子类覆写
run()会在类定义阶段直接抛出TypeError:避免用户误以为run()会在DecomposePass路径中被调用。 - 不支持在
replacement()中返回None表示跳过:若希望放弃当前节点,需在meet_requirements()中返回False。 op_types由register_decompose_pass(..., op_types=[...])声明并固化到 descriptor:Python 基类不再自行维护另一套构造参数。
关系:
- 继承自
FusionBasePass - 通过
register_decompose_pass装饰器注册
9. PassDescriptor 数据类
文件位置: registry.py
功能: 规范化的 Python Pass 描述符
属性:
descriptor_key- 描述符唯一键(格式:模块名:类名:Pass名)pass_name- Pass 名称module_name- 所属模块名class_name- 类名stage- 执行阶段(PassStage)kind- Pass 类型(fusion_base、pattern_fusion、decompose)cls- Pass 类引用op_types- 关联的算子类型列表
注册与发现
装饰器:
register_fusion_pass(name, stage, kind=None)- 注册 FusionBasePass 或 PatternFusionPassregister_decompose_pass(name, stage, op_types)- 注册 DecomposePass
发现机制:
- 通过环境变量
ASCEND_GE_PY_PASS_PATH指定 Pass 文件或目录路径 bootstrap.py负责扫描路径并动态加载 Python 模块- 支持单个
.py文件和包含__init__.py的 Python 包
使用示例:
from ge.passes import (
FusionBasePass, PatternFusionPass, DecomposePass,
PassStage, PassContext,
register_fusion_pass, register_decompose_pass
)
# 1. FusionBasePass 示例
@register_fusion_pass(name="MyFusionPass", stage=PassStage.AFTER_INFER_SHAPE)
class MyFusionPass(FusionBasePass):
def run(self, graph, context: PassContext):
# 实现图融合逻辑
return graph
# 2. PatternFusionPass 示例
@register_fusion_pass(name="MyPatternPass", stage=PassStage.BEFORE_INFER_SHAPE)
class MyPatternPass(PatternFusionPass):
def patterns(self):
return [...]
def meet_requirements(self, match_result):
return True
def replacement(self, match_result):
pass
# 3. DecomposePass 示例
@register_decompose_pass(
name="MyDecomposePass",
stage=PassStage.BEFORE_INFER_SHAPE,
op_types=["MyOp"]
)
class MyDecomposePass(DecomposePass):
def replacement(self, node):
pass
加载自定义 Pass:
export ASCEND_GE_PY_PASS_PATH=/path/to/my_pass.py:/path/to/pass_dir/
更多设计细节请参考 Python Pass 设计文档。
ES 模块
ES (Eager-Style) 模块提供了函数式风格的图构建接口,详细文档请参考:ES-PY Python 模块文档
使用示例
更多示例请参考 examples/es 目录下的 Python 用例。