import sys
import numpy as np
from copy import deepcopy
from auto_optimizer import OnnxGraph
FP16_MAX_VALUE = 65504
FP16_MIN_VALUE = -65504
def delete_domain(graph):
for node in graph.nodes:
if node.domain != '':
node.domain = ''
while len(graph.opset_imports) > 1:
graph.opset_imports.pop(1)
def fix_mul(graph):
initializer_list = graph.initializers
mul_input = []
for mul_node in graph.get_nodes('Mul'):
for input_name in mul_node.inputs:
mul_input.append(input_name)
for initializer in initializer_list:
if initializer.name in mul_input:
fixed_value = deepcopy(initializer.value)
value_mask_pos = (fixed_value > FP16_MAX_VALUE)
value_mask_neg = (fixed_value < FP16_MIN_VALUE)
if np.sum(value_mask_pos) > 0 or np.sum(value_mask_neg) > 0:
print(f"Fix value node: {initializer}")
fixed_value[value_mask_pos] = FP16_MAX_VALUE
fixed_value[value_mask_neg] = FP16_MIN_VALUE
initializer.value = fixed_value
if __name__ == '__main__':
input_path = sys.argv[1]
save_path = sys.argv[2]
onnx_graph = OnnxGraph.parse(input_path)
fix_mul(onnx_graph)
delete_domain(onnx_graph)
onnx_graph.save(save_path)