import argparse
from typing import List
import auto_optimizer
import numpy as np
class FaAdapter:
    """Rewrite the attention layers of a SAM encoder ONNX graph to FlashAttention.

    For every ``Softmax`` node that sits in this pattern (``[i]`` = input index)::

        Mul[0]=q ----------> MatMul_1[0]
        Transpose[0]=k ----> MatMul_1[1]
        MatMul_1 -> Reshape_1 -> Add_1[0] (Add_1[1] = addend_1)
                 -> Add_2[0] (Add_2[1] = addend_2) -> Reshape_2 -> Softmax
                 -> MatMul_2[0] (MatMul_2[1] = v) -> Reshape_3

    the adapter:

    * Makes ``Add_1`` compute ``addend_1 + addend_2`` and reshapes that sum to
      ``[dim0(q), dim1(q), dim1(q)]``. It then unsqueezes it on axis 1 to form
      the mask input.
    * Inserts a ``FlashAttentionSoftmaxFp32`` node fed with
      ``(q, k, v, seq_len, seq_len, mask)``. Here ``seq_len`` is ``dim1(q)``
      cast to INT32 (ONNX ``to=6``).
    * Feeds the FlashAttention output into ``Reshape_3``. It then removes
      Mul, Transpose, MatMul_1, Add_2, Reshape_2, Softmax and MatMul_2.

    Softmax nodes that do not match the pattern are left unchanged, and a
    ``UserWarning`` is emitted for each one.
    """

    def __init__(self, origin_onnx: auto_optimizer.OnnxGraph):
        """Register the constant initializers shared by every rewritten layer.

        Args:
            origin_onnx: Graph to rewrite. ``adapt()`` mutates and returns this
                same object.
        """
        self.__graph = origin_onnx
        # Nodes detached while rewriting; removed in one pass at the end of adapt().
        self.__nodes_to_remove = []
        # Prefixes already handed out, so two layers never generate the same names.
        self.__used_prefixes = set()
        # Gather indices that pick dimension 0 / dimension 1 out of Shape(q).
        self.__indices_of_gather_batch = self.__graph.add_initializer(
            'indices_of_gather_batch',
            np.array([0], dtype=np.int64),
        )
        self.__indices_of_gather_seq_len = self.__graph.add_initializer(
            'indices_of_gather_seq_len',
            np.array([1], dtype=np.int64),
        )
        # Axis inserted by the Unsqueeze that turns the reshaped bias into the mask.
        self.__indices_of_axis_to_unsqueeze_for_mask = self.__graph.add_initializer(
            'indices_of_axis_to_unsqueeze_for_mask',
            np.array([1], dtype=np.int64),
        )

    def adapt(self) -> auto_optimizer.OnnxGraph:
        """Rewrite every matching attention layer and return the (same) graph.

        Returns:
            The graph passed to ``__init__``, modified in place and with shapes
            re-inferred.
        """
        import warnings

        for softmax in self.__graph.get_nodes('Softmax'):
            if not self.__adapt_layer(softmax):
                warnings.warn(
                    f'Softmax node {softmax.name!r} does not match the expected '
                    'attention pattern; leaving it unchanged.'
                )
        for node in self.__nodes_to_remove:
            # NOTE(review): assumes an empty mapping tells remove() not to
            # reconnect the node's inputs to its consumers; confirm against auto_optimizer.
            self.__graph.remove(node.name, {})
        # Forget the removed nodes so a later adapt() call does not remove them again.
        self.__nodes_to_remove = []
        self.__graph.infer_shape()
        return self.__graph

    def __next_node(self, node, op_type: str):
        """Return the first consumer of ``node.outputs[0]`` if its op is ``op_type``, else None."""
        if node is None:
            return None
        consumers = self.__graph.get_next_nodes(node.outputs[0])
        if not consumers or consumers[0].op_type != op_type:
            return None
        return consumers[0]

    def __prev_node(self, node, input_index: int, op_type: str):
        """Return the producer of ``node.inputs[input_index]`` if its op is ``op_type``, else None."""
        if node is None or input_index >= len(node.inputs):
            return None
        producer = self.__graph.get_prev_node(node.inputs[input_index])
        if producer is None or producer.op_type != op_type:
            return None
        return producer

    def __match_attention(self, softmax):
        """Locate the attention sub-graph around ``softmax`` without modifying the graph.

        Returns:
            ``(mul, transpose, matmul_1, reshape_1, add_1, add_2, reshape_2,
            matmul_2, reshape_3)``, or None if any node is missing, has an
            unexpected op type, or the softmax does not feed ``MatMul_2[0]``.
        """
        matmul_2 = self.__next_node(softmax, 'MatMul')
        reshape_3 = self.__next_node(matmul_2, 'Reshape')
        reshape_2 = self.__prev_node(softmax, 0, 'Reshape')
        add_2 = self.__prev_node(reshape_2, 0, 'Add')
        add_1 = self.__prev_node(add_2, 0, 'Add')
        reshape_1 = self.__prev_node(add_1, 0, 'Reshape')
        matmul_1 = self.__prev_node(reshape_1, 0, 'MatMul')
        mul = self.__prev_node(matmul_1, 0, 'Mul')
        transpose = self.__prev_node(matmul_1, 1, 'Transpose')
        matched = (mul, transpose, matmul_1, reshape_1, add_1, add_2, reshape_2, matmul_2, reshape_3)
        if any(node is None for node in matched):
            return None
        # v is read from MatMul_2[1], so the softmax output must be MatMul_2[0].
        if matmul_2.inputs[0] != softmax.outputs[0]:
            return None
        return matched

    def __make_layer_prefix(self, softmax) -> str:
        """Return a name prefix for the nodes and tensors added for this layer.

        Normally this is the softmax's scope, i.e. its name without the last
        ``/``-separated component. The generated names stay as before in that
        case. If the scope is empty (flat names such as ``Softmax_12``) or
        another layer already used it, ``softmax.name + '/'`` is used instead.
        This keeps generated names such as ``seq_len`` from colliding.
        """
        prefix = softmax.name.rpartition('/')[0]
        if not prefix or prefix in self.__used_prefixes:
            prefix = softmax.name + '/'
        self.__used_prefixes.add(prefix)
        return prefix

    def __adapt_layer(self, softmax: auto_optimizer.OnnxNode) -> bool:
        """Replace one attention layer with FlashAttention.

        Returns:
            True if the layer was rewritten. False if ``softmax`` does not match
            the expected pattern; the graph is untouched in that case.
        """
        matched = self.__match_attention(softmax)
        if matched is None:
            return False
        mul, transpose, matmul_1, reshape_1, add_1, add_2, reshape_2, matmul_2, reshape_3 = matched
        layer_name_prefix = self.__make_layer_prefix(softmax)

        q = mul.inputs[0]
        k = transpose.inputs[0]
        v = matmul_2.inputs[1]
        addend_1 = add_1.inputs[1]
        addend_2 = add_2.inputs[1]
        target_shape = self.__infer_target_shape(layer_name_prefix, q)

        # Sum the two addends first, then reshape the sum to [dim0(q), dim1(q), dim1(q)]
        # so it can serve as the mask.
        add_1.inputs = [addend_1, addend_2]
        reshape_1.inputs = [add_1.outputs[0], target_shape]
        # Bypass Reshape_2 so it can be dropped. Add_2 and Softmax are removed below as well.
        add_2.inputs = [matmul_1.outputs[0], reshape_1.outputs[0]]
        softmax.inputs = [add_2.outputs[0]]
        self.__nodes_to_remove.append(reshape_2)

        cast = self.__graph.add_node(
            layer_name_prefix + 'Cast_seq_len',
            'Cast',
            # Tensor produced by the Gather_seq_len node in __infer_target_shape.
            inputs=[layer_name_prefix + 'seq_len'],
            outputs=[layer_name_prefix + 'seq_len_int_32'],
            attrs={'to': 6},  # ONNX TensorProto.INT32
        )
        unsqueeze = self.__graph.add_node(
            layer_name_prefix + 'Unsqueeze_mask',
            'Unsqueeze',
            inputs=[reshape_1.outputs[0], self.__indices_of_axis_to_unsqueeze_for_mask.name],
            outputs=[layer_name_prefix + 'unsqueeze_mask_output'],
        )
        flash_attention = self.__graph.add_node(
            layer_name_prefix + 'flash_attention',
            'FlashAttentionSoftmaxFp32',
            # The same INT32 seq_len is passed for both seq_len inputs.
            inputs=[q, k, v, cast.outputs[0], cast.outputs[0], unsqueeze.outputs[0]],
            outputs=[layer_name_prefix + 'flash_attention_output'],
        )
        # Rewire through the setter, like every other input change in this method.
        reshape_3.inputs = [flash_attention.outputs[0], *reshape_3.inputs[1:]]
        self.__nodes_to_remove += [mul, transpose, matmul_1, add_2, softmax, matmul_2]
        return True

    def __infer_target_shape(self, layer_name_prefix: str, q: str) -> str:
        """Build the subgraph computing ``Concat(dim0(q), dim1(q), dim1(q))``.

        As a side effect this also creates the ``<prefix>seq_len`` tensor
        (``dim1(q)``), which ``__adapt_layer`` casts to INT32.

        Args:
            layer_name_prefix: Prefix for every node/tensor name created here.
            q: Name of the query tensor.

        Returns:
            Name of the 1-D INT64 shape tensor, usable as a Reshape target.
        """
        shape_q = self.__graph.add_node(
            layer_name_prefix + 'Shape_q',
            'Shape',
            inputs=[q],
            outputs=[layer_name_prefix + 'Shape_of_q'],
        )
        gather_batch = self.__graph.add_node(
            layer_name_prefix + 'Gather_batch',
            'Gather',
            inputs=[shape_q.outputs[0], self.__indices_of_gather_batch.name],
            outputs=[layer_name_prefix + 'batch'],
        )
        gather_seq_len = self.__graph.add_node(
            layer_name_prefix + 'Gather_seq_len',
            'Gather',
            inputs=[shape_q.outputs[0], self.__indices_of_gather_seq_len.name],
            outputs=[layer_name_prefix + 'seq_len'],
        )
        concat_for_mask_shape = self.__graph.add_node(
            layer_name_prefix + 'Concat_for_mask_shape',
            'Concat',
            inputs=[
                gather_batch.outputs[0],
                gather_seq_len.outputs[0],
                gather_seq_len.outputs[0],
            ],
            outputs=[layer_name_prefix + 'mask_shape'],
            attrs={'axis': 0},
        )
        return concat_for_mask_shape.outputs[0]
if __name__ == '__main__':
    # CLI entry point: load the SAM encoder model, rewrite its attention
    # layers for FlashAttention, and write the result to disk.
    cli = argparse.ArgumentParser(
        description='Modify the SAM encoder ONNX model to adapt Ascend chips.'
    )
    # (flag, help text) pairs. Both options are required string paths.
    cli_options = (
        ('--input', 'The path to the original SAM encoder ONNX model.'),
        ('--output', 'The path to save the adapted SAM encoder ONNX model to.'),
    )
    for option_flag, option_help in cli_options:
        cli.add_argument(option_flag, type=str, required=True, help=option_help)
    cli_args = cli.parse_args()

    source_graph = auto_optimizer.OnnxGraph.parse(cli_args.input)
    adapted_graph = FaAdapter(source_graph).adapt()
    adapted_graph.save(cli_args.output)