import onnx
from onnx import helper, TensorProto, shape_inference
import onnx_graphsurgeon as gs
import numpy as np
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json
# JSON file read by load_replacement_config(); its entries are keyed by ONNX
# op type and are queried below for "<device>_supported" flags and a "strategy" name.
CONFIG_PATH = Path(__file__).parent.parent / "config" / "operator_replacements.json"
def _topological_sort_onnx(model) -> "onnx.ModelProto":
    """Reorder ``model.graph.node`` in place so producers precede consumers.

    A node is considered ready once every one of its inputs is a graph input,
    an initializer, an output of an already-placed node, or the empty string
    (which ONNX uses to mark an omitted optional input).  Among the ready
    nodes, the one that appears first in the current node list is placed
    next, so nodes that are already correctly ordered keep their order.

    If no remaining node can be placed (a cycle, or inputs that nothing in
    this graph produces), the remaining nodes are appended in their current
    order instead of raising.

    Args:
        model: Object exposing ``graph.node``, ``graph.initializer`` and
            ``graph.input`` the way ``onnx.ModelProto`` does.

    Returns:
        The same ``model`` object, with its node list reordered.
    """
    graph = model.graph

    # Tensors that exist before any node has run.
    available = {init.name for init in graph.initializer}
    available.update(inp.name for inp in graph.input)
    # An omitted optional input is spelled "" and is never produced by any
    # node; without this entry such nodes could never become ready.
    available.add("")

    pending = list(graph.node)
    ordered = []
    while pending:
        ready_index = next(
            (
                idx
                for idx, node in enumerate(pending)
                if all(name in available for name in node.input)
            ),
            None,
        )
        if ready_index is None:
            # Nothing can make progress: keep the rest as-is.
            ordered.extend(pending)
            break
        # Pop by position so the exact node found above is the one removed.
        node = pending.pop(ready_index)
        ordered.append(node)
        available.update(node.output)

    del graph.node[:]
    graph.node.extend(ordered)
    return model
def load_replacement_config() -> Dict:
    """Read the operator replacement table stored at ``CONFIG_PATH``.

    Returns:
        The parsed JSON content, or an empty dict when the file does not exist.
    """
    if not CONFIG_PATH.exists():
        return {}
    with open(CONFIG_PATH, mode="r", encoding="utf-8") as config_file:
        parsed = json.load(config_file)
    return parsed
def get_unsupported_operators(model_path: str, device: str = "NPU") -> List[str]:
    """List the op types in a model that the config marks unsupported on ``device``.

    An op type is reported when the replacement config has an entry for it
    and that entry's ``"<device lowercased>_supported"`` flag is falsy.  Op
    types without a config entry, or without that flag, count as supported.

    Args:
        model_path: Path of the ONNX model to inspect.
        device: Device name; lowercased to build the config key.

    Returns:
        The unsupported op types without duplicates, in the order they first
        appear in the graph, so repeated runs give the same result.
    """
    config = load_replacement_config()
    model = onnx.load(model_path)
    support_key = f"{device.lower()}_supported"

    # A dict is used as an ordered set: insertion order gives a stable result.
    unsupported = {}
    for node in model.graph.node:
        op_type = node.op_type
        if op_type in config and not config[op_type].get(support_key, True):
            unsupported[op_type] = None
    return list(unsupported)
def _split_static_shape(graph, tensor_name):
    """Return the recorded shape of ``tensor_name`` as a list of ints, or None.

    Graph inputs and ``value_info`` entries are searched first, then
    initializers.  A dimension without a concrete ``dim_value`` is reported
    as 0, so callers must treat non-positive sizes as unknown.
    """
    for collection in (graph.input, graph.value_info):
        for value_info in collection:
            if value_info.name == tensor_name:
                dims = value_info.type.tensor_type.shape.dim
                if dims:
                    return [d.dim_value for d in dims]
                break
    for init in graph.initializer:
        if init.name == tensor_name:
            return list(init.dims)
    return None


def _split_sizes_from_input(graph, split_node):
    """Read split sizes from the node's optional second input, or return None.

    Only an initializer-backed second input can be read; the tensor is
    decoded according to its declared data type.
    """
    if len(split_node.input) < 2 or not split_node.input[1]:
        return None
    from onnx import numpy_helper

    for init in graph.initializer:
        if init.name == split_node.input[1]:
            sizes = [int(v) for v in numpy_helper.to_array(init).flatten().tolist()]
            return sizes or None
    return None


def _even_split_sizes(total, parts):
    """Split ``total`` into ``parts`` chunks the way ONNX Split does by default.

    Evenly divisible sizes give equal chunks; otherwise chunks are the
    quotient rounded up and the last chunk is smaller.  Returns None when
    the size is unknown (non-positive) or no valid chunking exists.
    """
    if total <= 0 or parts <= 0:
        return None
    chunk, remainder = divmod(total, parts)
    if remainder == 0:
        return [chunk] * parts
    chunk += 1
    last = total - chunk * (parts - 1)
    if last <= 0:
        return None
    return [chunk] * (parts - 1) + [last]


def replace_split_with_slices(model_path: str, output_path: str,
                              node_name: Optional[str] = None,
                              axis: Optional[int] = None,
                              split_sizes: Optional[List[int]] = None) -> Dict:
    """Replace ``Split`` nodes by one ``Slice`` node per output and save the model.

    The split axis and sizes are resolved per node in this order: the node's
    own ``axis`` / ``split`` attributes, then the ``axis`` / ``split_sizes``
    arguments, then the node's second input when it is an initializer, and
    finally an even split derived from the statically known input shape.
    A node whose sizes cannot be resolved, or do not match its number of
    outputs, is left untouched and a warning is printed.

    Args:
        model_path: ONNX model to read.
        output_path: Where the modified model is written.
        node_name: If given, only the Split node with this name is replaced.
        axis: Fallback axis for nodes without an ``axis`` attribute
            (0 is used when this is also None).
        split_sizes: Fallback sizes for nodes without a ``split`` attribute.

    Returns:
        A result dict with ``success``, ``operators_replaced``,
        ``nodes_count`` and, except for the "named node not found" error,
        ``output_path`` and ``message``.  When the model has no Split node
        (or the named one is missing) nothing is written to ``output_path``.
    """
    model = onnx.load(model_path)
    graph = model.graph

    # Track targets by position: node names are optional in ONNX, so matching
    # on names would also catch every other unnamed node in the graph.
    split_indices = [i for i, n in enumerate(graph.node) if n.op_type == "Split"]
    if not split_indices:
        print("[WARNING] 未找到Split节点")
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Split节点"
        }
    if node_name:
        split_indices = [i for i in split_indices if graph.node[i].name == node_name]
        if not split_indices:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }
    target_indices = set(split_indices)

    open_end = 9223372036854775807  # INT64 max: "slice to the end of the axis"
    replaced_count = 0
    new_node_list = []
    for index, node in enumerate(graph.node):
        if index not in target_indices:
            new_node_list.append(node)
            continue

        outputs_count = len(node.output)
        node_axis = axis
        node_split_sizes = split_sizes
        # Attributes stored on the node take precedence over the arguments.
        for attr in node.attribute:
            if attr.name == "axis":
                node_axis = attr.i
            elif attr.name == "split":
                node_split_sizes = list(attr.ints)
        if node_axis is None:
            node_axis = 0

        if node_split_sizes is None:
            node_split_sizes = _split_sizes_from_input(graph, node)
        if node_split_sizes is None and outputs_count:
            shape = _split_static_shape(graph, node.input[0])
            if shape is not None and -len(shape) <= node_axis < len(shape):
                node_split_sizes = _even_split_sizes(shape[node_axis], outputs_count)

        if node_split_sizes is None or len(node_split_sizes) != outputs_count:
            print(f"[WARNING] 无法确定Split节点 {node.name} 的split sizes,保留原节点")
            new_node_list.append(node)
            continue

        # Unnamed nodes still need distinct names for the generated tensors.
        base_name = node.name or f"unnamed_split_{index}"
        start_pos = 0
        last_index = outputs_count - 1
        for i, size in enumerate(node_split_sizes):
            # The last slice is open-ended so it always reaches the axis end.
            end_value = open_end if i == last_index else start_pos + size
            slice_output_name = node.output[i]
            if slice_output_name:
                constants = []
                slice_inputs = [node.input[0]]
                for suffix, value in (("starts", start_pos), ("ends", end_value),
                                      ("axes", node_axis), ("steps", 1)):
                    tensor_name = f"{base_name}_slice_{i}_{suffix}"
                    constants.append(helper.make_tensor(
                        name=tensor_name,
                        data_type=TensorProto.INT64,
                        dims=[1],
                        vals=[int(value)]
                    ))
                    slice_inputs.append(tensor_name)
                graph.initializer.extend(constants)
                new_node_list.append(helper.make_node(
                    'Slice',
                    inputs=slice_inputs,
                    outputs=[slice_output_name],
                    name=f"{base_name}_Slice_{i}"
                ))
            start_pos += size
        replaced_count += 1

    del graph.node[:]
    graph.node.extend(new_node_list)

    # Shape inference is best-effort: the model is saved even if it fails.
    try:
        model = shape_inference.infer_shapes(model)
    except Exception as e:
        print(f"[WARNING] Shape inference failed: {e}")

    onnx.save(model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Split节点为Slice节点")
    print(f"[OK] 已保存至: {output_path}")
    return {
        "success": True,
        "operators_replaced": ["Split"] if replaced_count else [],
        "nodes_count": {"Split": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Split节点"
    }
def _mod_is_integer_dtype(dtype) -> bool:
    """Return True when ``dtype`` can be interpreted as a numpy integer type."""
    try:
        return bool(np.issubdtype(np.dtype(dtype), np.integer))
    except TypeError:
        # Not something numpy can interpret as a dtype.
        return False


def replace_mod_equivalent(model_path: str, output_path: str,
                           node_name: Optional[str] = None) -> Dict:
    """Replace ``Mod`` nodes by ``a - q * b`` built from Div/Floor/Mul/Sub.

    For floating-point tensors ``q = Floor(Div(a, b))``.  For integer tensors
    ``q = Div(a, b)`` with no Floor, because ONNX defines Floor only for
    floating-point types.  The replacement writes to the original output
    tensor, so downstream nodes and graph output names are unchanged.

    Args:
        model_path: ONNX model to read.
        output_path: Where the modified model is written.
        node_name: If given, only the Mod node with this name is replaced.

    Returns:
        A result dict with ``success``, ``operators_replaced``,
        ``nodes_count`` and, except for the "named node not found" error,
        ``output_path`` and ``message``.  When the model has no Mod node
        (or the named one is missing) nothing is written to ``output_path``.

    Raises:
        Whatever ``onnx.checker.check_model`` raises for an invalid result.
    """
    model = onnx.load(model_path)
    graph = gs.import_onnx(model)
    mod_nodes = [n for n in graph.nodes if n.op == "Mod"]
    if not mod_nodes:
        print("[WARNING] 未找到Mod节点")
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Mod节点"
        }
    if node_name:
        mod_nodes = [n for n in mod_nodes if n.name == node_name]
        if not mod_nodes:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }

    replaced_count = 0
    for mod_node in mod_nodes:
        a = mod_node.inputs[0]
        b = mod_node.inputs[1]
        out = mod_node.outputs[0]

        # Compare against None explicitly: plain numpy dtypes define
        # __len__ == 0 and can therefore evaluate as falsy.
        known_dtypes = [t.dtype for t in (out, a, b) if t.dtype is not None]
        dtype = known_dtypes[0] if known_dtypes else np.float32

        # Tensor names must be unique; node names may be empty.
        base = mod_node.name or out.name
        t_div = gs.Variable(name=f"t_div_{base}", dtype=dtype)
        t_mul = gs.Variable(name=f"t_mul_{base}", dtype=dtype)
        new_nodes = [gs.Node(op="Div", inputs=[a, b], outputs=[t_div])]

        quotient = t_div
        if not _mod_is_integer_dtype(dtype):
            # NOTE(review): with fmod=1 ONNX Mod follows the dividend's sign,
            # while the Floor form follows the divisor's sign; results differ
            # when the operands have opposite signs.
            t_floor = gs.Variable(name=f"t_floor_{base}", dtype=dtype)
            new_nodes.append(gs.Node(op="Floor", inputs=[t_div], outputs=[t_floor]))
            quotient = t_floor
        # NOTE(review): for integer tensors this assumes Div already yields an
        # integral quotient; with the default fmod=0 the result can differ
        # from Mod when the operands have opposite signs.
        new_nodes.append(gs.Node(op="Mul", inputs=[quotient, b], outputs=[t_mul]))

        # Detach the Mod node and let Sub produce the original tensor; the
        # dangling Mod node is dropped by graph.cleanup() below.
        mod_node.outputs.clear()
        new_nodes.append(gs.Node(op="Sub", inputs=[a, t_mul], outputs=[out]))

        graph.nodes.extend(new_nodes)
        replaced_count += 1

    graph.cleanup()
    graph.toposort()
    new_model = gs.export_onnx(graph)
    onnx.checker.check_model(new_model)
    onnx.save(new_model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Mod节点")
    print(f"[OK] 已保存至: {output_path}")
    return {
        "success": True,
        "operators_replaced": ["Mod"],
        "nodes_count": {"Mod": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Mod节点"
    }
def _expand_tile_plan(input_shape, target_shape):
    """Work out how to express ``Expand(input, target_shape)`` as a Tile.

    Both shapes are right-aligned as in broadcasting.  For every axis the
    repeat count is the target size when the input size is 1, and 1 when the
    target size is 1 or equals the input size.

    Returns:
        ``(padded_input_shape, repeats)`` where ``padded_input_shape`` is the
        input shape with leading 1s added up to the common rank, or None when
        the repeats cannot be decided from static information.
    """
    if input_shape is None:
        return None
    in_dims = list(input_shape)
    try:
        target = [int(d) for d in target_shape]
    except (TypeError, ValueError):
        return None
    rank = max(len(in_dims), len(target))
    if rank == 0 or any(d <= 0 for d in target):
        return None
    if len(in_dims) < rank:
        # Rank extension needs a Reshape, which needs fully static input dims.
        if not all(isinstance(d, (int, np.integer)) and d > 0 for d in in_dims):
            return None
        in_dims = [1] * (rank - len(in_dims)) + in_dims
    target = [1] * (rank - len(target)) + target

    repeats = []
    for in_dim, tgt_dim in zip(in_dims, target):
        if tgt_dim == 1:
            repeats.append(1)  # output keeps the input size on this axis
        elif isinstance(in_dim, (int, np.integer)) and in_dim == 1:
            repeats.append(tgt_dim)
        elif isinstance(in_dim, (int, np.integer)) and in_dim == tgt_dim:
            repeats.append(1)
        else:
            return None  # dynamic or non-broadcastable axis
    return [int(d) for d in in_dims], repeats


def replace_expand_with_tile(model_path: str, output_path: str,
                             node_name: Optional[str] = None) -> Dict:
    """Replace ``Expand`` nodes by ``Tile`` (plus ``Reshape`` when rank grows).

    Expand takes a target shape while Tile takes per-axis repeat counts, so
    the repeats are computed from the constant target shape and the known
    input shape.  An Expand whose target shape is not a constant, or whose
    input shape is not known well enough, is left in place with a warning.
    The replacement writes to the original output tensor, so downstream
    nodes and graph output names are unchanged.

    Args:
        model_path: ONNX model to read.
        output_path: Where the modified model is written.
        node_name: If given, only the Expand node with this name is replaced.

    Returns:
        A result dict with ``success``, ``operators_replaced``,
        ``nodes_count`` and, except for the "named node not found" error,
        ``output_path`` and ``message``.  When the model has no Expand node
        (or the named one is missing) nothing is written to ``output_path``.

    Raises:
        Whatever ``onnx.checker.check_model`` raises for an invalid result.
    """
    model = onnx.load(model_path)
    graph = gs.import_onnx(model)
    expand_nodes = [n for n in graph.nodes if n.op == "Expand"]
    if not expand_nodes:
        print("[WARNING] 未找到Expand节点")
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Expand节点"
        }
    if node_name:
        expand_nodes = [n for n in expand_nodes if n.name == node_name]
        if not expand_nodes:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }

    replaced_count = 0
    for expand_node in expand_nodes:
        inp = expand_node.inputs[0]
        shape = expand_node.inputs[1]
        out = expand_node.outputs[0]
        # Tensor names must be unique; node names may be empty.
        base = expand_node.name or out.name

        plan = None
        if isinstance(shape, gs.Constant):
            target_shape = np.asarray(shape.values).reshape(-1).tolist()
            plan = _expand_tile_plan(inp.shape, target_shape)
        if plan is None:
            print(f"[WARNING] 无法静态确定Expand节点 {base} 的重复次数,保留原节点")
            continue
        padded_shape, repeats = plan

        new_nodes = []
        tile_input = inp
        if len(padded_shape) != len(inp.shape):
            # Tile needs one repeat per input axis, so add the leading axes
            # that Expand would have introduced through broadcasting.
            reshape_shape = gs.Constant(
                name=f"{base}_tile_reshape_shape",
                values=np.array(padded_shape, dtype=np.int64),
            )
            tile_input = gs.Variable(
                name=f"{base}_tile_reshaped", dtype=inp.dtype, shape=padded_shape
            )
            new_nodes.append(
                gs.Node(op="Reshape", inputs=[inp, reshape_shape], outputs=[tile_input])
            )
        repeats_const = gs.Constant(
            name=f"{base}_tile_repeats",
            values=np.array(repeats, dtype=np.int64),
        )

        # Detach the Expand node and let Tile produce the original tensor;
        # the dangling Expand node is dropped by graph.cleanup() below.
        expand_node.outputs.clear()
        new_nodes.append(
            gs.Node(op="Tile", inputs=[tile_input, repeats_const], outputs=[out])
        )
        graph.nodes.extend(new_nodes)
        replaced_count += 1

    graph.cleanup()
    graph.toposort()
    new_model = gs.export_onnx(graph)
    onnx.checker.check_model(new_model)
    onnx.save(new_model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Expand节点为Tile节点")
    print(f"[OK] 已保存至: {output_path}")
    return {
        "success": True,
        "operators_replaced": ["Expand"] if replaced_count else [],
        "nodes_count": {"Expand": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Expand节点"
    }
# Dispatch table from a config entry's "strategy" name to the function that
# implements it; fix_runtime_operators() calls each as func(input_path, output_path).
RUNTIME_OP_REPLACEMENTS = {
"split_to_slices": replace_split_with_slices,
"mod_to_sub_mul_div": replace_mod_equivalent,
"expand_to_tile": replace_expand_with_tile,
}
def fix_runtime_operators(input_path: str, output_path: str,
                          device: str = "NPU",
                          operators: Optional[List[str]] = None) -> Dict:
    """Replace every operator the config marks unsupported on ``device``.

    Each unsupported op type that has a ``"strategy"`` in the replacement
    config is handled by the matching function in
    ``RUNTIME_OP_REPLACEMENTS``.  The replacements are chained: each step
    reads the previous step's output.  Intermediate models live in a
    temporary directory that is removed afterwards; only the final model is
    copied to ``output_path``, and only if at least one op type was replaced.

    Args:
        input_path: ONNX model to read.
        output_path: Where the final model is written.
        device: Device name used to look up support flags in the config.
        operators: Optional allow-list of op types to replace.

    Returns:
        A result dict with ``success``, ``operators_replaced``,
        ``nodes_per_operator`` and, except for the missing-file error,
        ``output_path`` and ``message``.
    """
    import shutil
    import tempfile

    config = load_replacement_config()
    if not Path(input_path).exists():
        return {
            "success": False,
            "error": f"File not found: {input_path}",
            "operators_replaced": [],
            "nodes_per_operator": {}
        }
    model = onnx.load(input_path)
    unsupported_in_model = get_unsupported_operators(input_path, device)
    if not unsupported_in_model:
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_per_operator": {},
            "output_path": output_path,
            "message": f"模型中没有在{device}上不支持的算子"
        }

    # Count nodes per op type in one pass over the graph.
    op_counts = {}
    for node in model.graph.node:
        op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1

    to_replace = {}
    for op in unsupported_in_model:
        if op in config and config[op].get("strategy"):
            count = op_counts.get(op, 0)
            if count > 0:
                to_replace[op] = {
                    "count": count,
                    "strategy": config[op]["strategy"]
                }
    if operators:
        to_replace = {k: v for k, v in to_replace.items() if k in operators}
    if not to_replace:
        unsupported_ops = [op for op in unsupported_in_model if op not in config or not config[op].get("strategy")]
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_per_operator": {},
            "output_path": output_path,
            "message": f"检测到不支持的算子但无替换策略: {unsupported_ops}",
            "unsupported_without_strategy": unsupported_ops
        }

    print(f"[INFO] 检测到需要替换的算子 ({device}不支持):")
    for op, info in to_replace.items():
        print(f"  - {op}: {info['count']} 个节点 (策略: {info['strategy']})")

    final_output = output_path
    replaced_ops = []
    nodes_per_op = {}
    # Intermediate models go to a scratch directory so nothing is left
    # behind next to the input file.
    with tempfile.TemporaryDirectory(prefix="onnx_op_replace_") as work_dir:
        current_path = input_path
        for step, (op, info) in enumerate(to_replace.items()):
            strategy = info["strategy"]
            replace_func = RUNTIME_OP_REPLACEMENTS.get(strategy)
            if replace_func is None:
                print(f"[WARNING] {op} 的替换策略未实现: {strategy}")
                continue
            target_output = str(Path(work_dir) / f"step{step}_{op}_replaced.onnx")
            result = replace_func(current_path, target_output)
            if not isinstance(result, dict):
                continue
            if not result.get("success"):
                print(f"[WARNING] {op} 替换失败: {result.get('error', 'unknown error')}")
                continue
            # A replacer may report success without writing a file; only
            # chain onto outputs that actually exist.
            if not Path(target_output).exists():
                continue
            current_path = target_output
            # Prefer the count the replacer reports over the pre-scan count.
            replaced_nodes = (result.get("nodes_count") or {}).get(op, info["count"])
            if replaced_nodes:
                replaced_ops.append(op)
                nodes_per_op[op] = replaced_nodes

        if current_path != final_output and replaced_ops:
            shutil.copy(current_path, final_output)

    print(f"[INFO] 最终输出: {final_output}")
    return {
        "success": True,
        "operators_replaced": replaced_ops,
        "nodes_per_operator": nodes_per_op,
        "output_path": final_output,
        "device": device,
        "message": f"成功替换 {len(replaced_ops)} 种算子"
    }
if __name__ == "__main__":
    # Command-line entry point: <input.onnx> <output.onnx> [device] [operators...]
    import sys

    cli_args = sys.argv[1:]
    if len(cli_args) < 2:
        usage_lines = (
            "Usage: python runtime_op_replace.py <input.onnx> <output.onnx> [device] [operators]",
            "\nExample:",
            "  python runtime_op_replace.py model.onnx model_fixed.onnx NPU Split",
            "  python runtime_op_replace.py model.onnx model_fixed.onnx CPU",
        )
        for usage_line in usage_lines:
            print(usage_line)
        sys.exit(1)

    input_path, output_path = cli_args[0], cli_args[1]
    # Device defaults to "NPU"; any further arguments are the op types to replace.
    device = cli_args[2] if len(cli_args) > 2 else "NPU"
    operators = cli_args[3:] or None

    result = fix_runtime_operators(input_path, output_path, device, operators)

    print("\n" + "=" * 60)
    print("Result:")
    print(json.dumps(result, indent=2, ensure_ascii=False))