# baishanyang: init project
# commit 5f1c8c3b, created 4 days ago
import onnx
from onnx import helper, TensorProto, shape_inference
import onnx_graphsurgeon as gs
import numpy as np
from typing import List, Dict, Optional, Tuple
from pathlib import Path
import json

# JSON table of per-operator device support flags and replacement strategies,
# stored in ``config/`` one directory above this file's directory.
CONFIG_PATH = Path(__file__).parent.parent.joinpath("config", "operator_replacements.json")


def _topological_sort_onnx(model) -> "onnx.ModelProto":
    """Reorder ``model.graph.node`` in place into a topological order.

    Uses Kahn's algorithm. Among the nodes that are ready, the one that
    appears earliest in the original node list is always emitted first, so
    an already-sorted graph keeps its exact order.

    Names in ``graph.initializer`` and ``graph.input`` are available from the
    start. So is the empty string, which ONNX uses to mark an omitted
    optional input. Nodes whose inputs can never become available (cycles
    or dangling references) are not dropped: they are appended at the end
    in their original relative order.

    Args:
        model: ONNX model whose top-level graph is sorted in place.

    Returns:
        The same ``model`` object.
    """
    import heapq
    from collections import defaultdict

    graph = model.graph
    nodes = list(graph.node)

    available = {""}
    available.update(init.name for init in graph.initializer)
    available.update(inp.name for inp in graph.input)

    # missing_count[i]: how many distinct input names node i still waits on.
    # waiters[name]: indices of the nodes waiting on that name.
    missing_count = []
    waiters = defaultdict(list)
    ready = []
    for idx, node in enumerate(nodes):
        missing = {name for name in node.input if name not in available}
        missing_count.append(len(missing))
        for name in missing:
            waiters[name].append(idx)
        if not missing:
            ready.append(idx)
    # A min-heap on original index reproduces "earliest ready node first".
    heapq.heapify(ready)

    emitted = [False] * len(nodes)
    sorted_nodes = []
    while ready:
        idx = heapq.heappop(ready)
        emitted[idx] = True
        sorted_nodes.append(nodes[idx])
        for out in nodes[idx].output:
            if out in available:
                continue
            available.add(out)
            for waiter in waiters.pop(out, ()):
                missing_count[waiter] -= 1
                if missing_count[waiter] == 0:
                    heapq.heappush(ready, waiter)

    # Whatever is left sits in a cycle or depends on an undefined tensor.
    sorted_nodes.extend(node for idx, node in enumerate(nodes) if not emitted[idx])

    del graph.node[:]
    graph.node.extend(sorted_nodes)

    return model


def load_replacement_config() -> Dict:
    """Read the operator-replacement table from ``CONFIG_PATH``.

    Returns:
        The parsed JSON content, or an empty dict when the file is absent.
    """
    if not CONFIG_PATH.exists():
        return {}
    return json.loads(CONFIG_PATH.read_text(encoding="utf-8"))


def get_unsupported_operators(model_path: str, device: str = "NPU") -> List[str]:
    """List the op types in a model that the config marks as unsupported on *device*.

    An op type counts as unsupported when it appears in the replacement
    config and its ``"<device>_supported"`` flag (device lower-cased) is
    falsy. Op types that are missing from the config, or that lack the flag,
    count as supported.

    Args:
        model_path: Path of the ONNX model to inspect.
        device: Device name, e.g. ``"NPU"`` or ``"CPU"``.

    Returns:
        Unique op types in order of first appearance in ``graph.node``. The
        order is deterministic across runs, unlike iterating a ``set`` of
        strings.
    """
    config = load_replacement_config()
    # Only node op types are read, so skip loading external weight files.
    model = onnx.load(model_path, load_external_data=False)
    support_key = f"{device.lower()}_supported"

    unsupported = dict.fromkeys(
        node.op_type
        for node in model.graph.node
        if node.op_type in config and not config[node.op_type].get(support_key, True)
    )
    return list(unsupported)


def _split_static_shape(graph, tensor_name):
    """Return the shape of *tensor_name* as a list, or None if it is not recorded.

    Looks in graph inputs, value_info and outputs first, then in the
    initializers. Symbolic or unset dims are reported as None; an explicit
    0 stays 0.
    """
    for value_info in (*graph.input, *graph.value_info, *graph.output):
        if value_info.name != tensor_name:
            continue
        tensor_type = value_info.type.tensor_type
        if tensor_type.HasField("shape"):
            return [dim.dim_value if dim.HasField("dim_value") else None
                    for dim in tensor_type.shape.dim]
    for init in graph.initializer:
        if init.name == tensor_name:
            return list(init.dims)
    return None


def _resolve_split_sizes(graph, shape_graph, split_node, axis):
    """Return Split sizes from the node's ``split`` input, else from an equal split of a static shape.

    Returns None when neither source yields sizes. A named ``split`` input
    that is neither an initializer nor a ``Constant`` node output is a
    runtime value, so no equal split is guessed for it.
    """
    from onnx import numpy_helper

    split_name = split_node.input[1] if len(split_node.input) > 1 else ""
    if split_name:
        for init in graph.initializer:
            if init.name == split_name:
                return [int(v) for v in numpy_helper.to_array(init).reshape(-1)]
        for producer in graph.node:
            if producer.op_type != "Constant" or split_name not in producer.output:
                continue
            for attr in producer.attribute:
                if attr.name == "value":
                    return [int(v) for v in numpy_helper.to_array(attr.t).reshape(-1)]
                if attr.name == "value_ints":
                    return [int(v) for v in attr.ints]
        return None

    dims = _split_static_shape(shape_graph, split_node.input[0])
    outputs_count = len(split_node.output)
    if not dims or outputs_count == 0 or not -len(dims) <= axis < len(dims):
        return None
    total = dims[axis]
    if total is None:
        return None
    # ONNX Split (num_outputs) uses chunks of ceil(total / n); the last chunk
    # takes whatever remains.
    chunk = -(-total // outputs_count)
    sizes = [chunk] * (outputs_count - 1) + [total - chunk * (outputs_count - 1)]
    return sizes if sizes[-1] >= 0 else None


def _build_split_slices(data_name, output_names, sizes, axis, prefix):
    """Build the Slice nodes, and their INT64 constant initializers, that reproduce one Split."""
    int64_max = 9223372036854775807  # ends value meaning "to the end of the axis"
    slice_nodes = []
    initializers = []
    start = 0
    for i, (out_name, size) in enumerate(zip(output_names, sizes)):
        end = int64_max if i == len(sizes) - 1 else start + int(size)
        operand_names = []
        for role, value in (("starts", start), ("ends", end), ("axes", axis), ("steps", 1)):
            tensor_name = f"{prefix}_slice_{i}_{role}"
            initializers.append(helper.make_tensor(
                name=tensor_name,
                data_type=TensorProto.INT64,
                dims=[1],
                vals=[int(value)]
            ))
            operand_names.append(tensor_name)
        slice_nodes.append(helper.make_node(
            'Slice',
            inputs=[data_name, *operand_names],
            outputs=[out_name],
            name=f"{prefix}_Slice_{i}"
        ))
        start += int(size)
    return slice_nodes, initializers


def replace_split_with_slices(model_path: str, output_path: str,
                              node_name: str = None,
                              axis: int = None,
                              split_sizes: List[int] = None) -> Dict:
    """Replace ``Split`` nodes with one ``Slice`` node per output.

    Output ``i`` of a Split becomes ``Slice(x, [start_i], [end_i], [axis], [1])``.
    The Slice output keeps the original tensor name, so downstream consumers
    are untouched. The last slice ends at INT64_MAX, so it runs to the end of
    the axis.

    Split sizes are resolved in this order:

    1. the node's ``split`` attribute;
    2. the ``split_sizes`` argument;
    3. the node's second input, if it is an initializer or a ``Constant``
       node output;
    4. an equal split computed from a static input shape.

    A node whose sizes cannot be resolved, or whose size count does not
    match its output count, is kept unchanged.

    Args:
        model_path: Path of the ONNX model to read.
        output_path: Path the result is written to. The model is written even
            when no Split node exists, so the returned ``output_path`` is
            always present on success.
        node_name: If given, only the Split node with this name is replaced.
        axis: Axis used for nodes without an ``axis`` attribute
            (default 0, the ONNX default).
        split_sizes: Sizes used for nodes without a ``split`` attribute.

    Returns:
        Dict with ``success``, ``operators_replaced``, ``nodes_count``,
        ``output_path`` and ``message``. On failure it carries ``error``
        instead of ``output_path``/``message``.
    """
    model = onnx.load(model_path)
    graph = model.graph
    nodes = list(graph.node)

    # Select by position, not by name. Node names are optional in ONNX, so a
    # set of names would also match every other unnamed node in the graph.
    targets = [i for i, n in enumerate(nodes) if n.op_type == "Split"]

    if not targets:
        print("[WARNING] 未找到Split节点")
        onnx.save(model, output_path)
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Split节点"
        }

    if node_name:
        targets = [i for i in targets if nodes[i].name == node_name]
        if not targets:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }

    # Slice takes starts/ends/axes/steps as tensor inputs only from opset 10 on.
    opset = next((o.version for o in model.opset_import if o.domain in ("", "ai.onnx")), None)
    if opset is not None and opset < 10:
        return {
            "success": False,
            "error": f"Slice with tensor inputs requires opset >= 10, model uses opset {opset}",
            "operators_replaced": [],
            "nodes_count": {}
        }

    # Best effort: inferred value_info lets the equal-split fallback see the
    # shapes of intermediate tensors. The inferred copy is only read.
    try:
        shape_graph = shape_inference.infer_shapes(model).graph
    except Exception as e:
        print(f"[WARNING] Shape inference failed: {e}")
        shape_graph = graph

    target_set = set(targets)
    new_node_list = []
    new_initializers = []
    replaced_count = 0

    for idx, node in enumerate(nodes):
        if idx not in target_set:
            new_node_list.append(node)
            continue

        node_axis = axis
        node_split_sizes = split_sizes
        for attr in node.attribute:
            if attr.name == "axis":
                node_axis = attr.i
            elif attr.name == "split":
                node_split_sizes = list(attr.ints)
        if node_axis is None:
            node_axis = 0

        if node_split_sizes is None:
            node_split_sizes = _resolve_split_sizes(graph, shape_graph, node, node_axis)
        if node_split_sizes is None:
            print(f"[WARNING] 无法确定Split节点 {node.name} 的split sizes,保留原节点")
            new_node_list.append(node)
            continue
        if len(node_split_sizes) != len(node.output):
            print(f"[WARNING] Split节点 {node.name} 的split sizes数量({len(node_split_sizes)})"
                  f"与输出数量({len(node.output)})不一致,保留原节点")
            new_node_list.append(node)
            continue

        # Unnamed Splits get a positional prefix so their constants don't collide.
        prefix = node.name or f"Split_{idx}"
        slice_nodes, slice_inits = _build_split_slices(
            node.input[0], list(node.output), node_split_sizes, node_axis, prefix)
        new_node_list.extend(slice_nodes)
        new_initializers.extend(slice_inits)
        replaced_count += 1

    del graph.node[:]
    graph.node.extend(new_node_list)
    graph.initializer.extend(new_initializers)

    try:
        model = shape_inference.infer_shapes(model)
    except Exception as e:
        print(f"[WARNING] Shape inference failed: {e}")

    onnx.save(model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Split节点为Slice节点")
    print(f"[OK] 已保存至: {output_path}")

    return {
        "success": True,
        "operators_replaced": ["Split"] if replaced_count else [],
        "nodes_count": {"Split": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Split节点"
    }


def _is_integer_dtype(dtype) -> bool:
    """Return True if *dtype* is a NumPy integer type; False for None, floats, or types NumPy cannot interpret."""
    if dtype is None:
        return False
    try:
        return bool(np.issubdtype(np.dtype(dtype), np.integer))
    except (TypeError, ValueError):
        return False


def _emit_mod_decomposition(graph, a, b, out, fmod, is_integer, dtype, tag) -> None:
    """Append nodes to *graph* that compute ``Mod(a, b, fmod)`` into the existing tensor *out*."""
    import itertools

    counter = itertools.count()

    def emit(op, *inputs, result=None):
        # Each call appends one node. New intermediates take the Mod's element
        # type, or stay untyped when that type is unknown.
        k = next(counter)
        if result is None:
            result = gs.Variable(name=f"{tag}_{op.lower()}_{k}", dtype=dtype)
        graph.nodes.append(gs.Node(op=op, name=f"{tag}_{op}_{k}",
                                   inputs=list(inputs), outputs=[result]))
        return result

    if is_integer:
        # a - (a / b) * b is the truncated remainder (sign of a), i.e. C fmod.
        # NOTE(review): relies on integer Div truncating toward zero, as
        # onnxruntime does; TODO confirm for the target runtime.
        quotient = emit("Div", a, b)
        if fmod:
            emit("Sub", a, emit("Mul", quotient, b), result=out)
            return
        remainder = emit("Sub", a, emit("Mul", quotient, b))
        # Taking (r + b) rem b moves a non-zero remainder to the divisor's
        # sign (floor-mod).
        shifted = emit("Add", remainder, b)
        emit("Sub", shifted, emit("Mul", emit("Div", shifted, b), b), result=out)
    elif fmod:
        # C fmod: a - trunc(a / b) * b. ONNX has no Trunc op, so use
        # trunc(x) = Sign(x) * Floor(Abs(x)).
        ratio = emit("Div", a, b)
        truncated = emit("Mul", emit("Sign", ratio), emit("Floor", emit("Abs", ratio)))
        emit("Sub", a, emit("Mul", truncated, b), result=out)
    else:
        emit("Sub", a, emit("Mul", emit("Floor", emit("Div", a, b)), b), result=out)


def replace_mod_equivalent(model_path: str, output_path: str,
                           node_name: str = None) -> Dict:
    """Replace ``Mod`` nodes with an equivalent chain of elementwise ops.

    The decomposition follows the element type and the ``fmod`` attribute:

    * float, ``fmod=1`` (C fmod; result has the dividend's sign):
      ``a - trunc(a / b) * b``.
    * float, ``fmod=0`` (result has the divisor's sign):
      ``a - Floor(a / b) * b``.
    * integer, ``fmod=1``: ``a - (a / b) * b``.
    * integer, ``fmod=0``: the truncated remainder ``r`` shifted with
      ``(r + b) - ((r + b) / b) * b``.

    If the element type cannot be determined, the float forms are used,
    matching the previous float32 assumption, and a warning is printed.

    The new chain writes into the Mod's original output tensor. Tensor
    names, consumers and graph outputs therefore stay the same.

    Args:
        model_path: Path of the ONNX model to read.
        output_path: Path the result is written to. The model is also written
            when no Mod node exists.
        node_name: If given, only the Mod node with this name is replaced.

    Returns:
        Dict with ``success``, ``operators_replaced``, ``nodes_count``,
        ``output_path`` and ``message`` (``error`` on failure).
    """
    model = onnx.load(model_path)
    # Best effort: inferred value_info gives gs the element types needed to
    # choose the decomposition.
    try:
        model = shape_inference.infer_shapes(model)
    except Exception as e:
        print(f"[WARNING] Shape inference failed: {e}")
    graph = gs.import_onnx(model)

    mod_nodes = [n for n in graph.nodes if n.op == "Mod"]

    if not mod_nodes:
        print("[WARNING] 未找到Mod节点")
        onnx.save(model, output_path)
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Mod节点"
        }

    if node_name:
        mod_nodes = [n for n in mod_nodes if n.name == node_name]
        if not mod_nodes:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }

    replaced_count = 0

    for index, mod_node in enumerate(mod_nodes):
        a, b = mod_node.inputs[0], mod_node.inputs[1]
        out = mod_node.outputs[0]
        fmod = int(mod_node.attrs.get("fmod", 0))
        dtype = next((t.dtype for t in (a, b, out) if t.dtype is not None), None)
        if dtype is None:
            print(f"[WARNING] Mod节点 {mod_node.name} 的数据类型未知,按浮点数处理")

        # Index suffix keeps generated names unique even for unnamed nodes.
        tag = f"{mod_node.name or 'Mod'}_{index}"
        _emit_mod_decomposition(graph, a, b, out, fmod, _is_integer_dtype(dtype), dtype, tag)

        # Detach the Mod node; cleanup() below removes it.
        mod_node.outputs.clear()
        replaced_count += 1

    graph.cleanup()
    graph.toposort()

    new_model = gs.export_onnx(graph)
    onnx.checker.check_model(new_model)
    onnx.save(new_model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Mod节点")
    print(f"[OK] 已保存至: {output_path}")

    return {
        "success": True,
        "operators_replaced": ["Mod"] if replaced_count else [],
        "nodes_count": {"Mod": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Mod节点"
    }


def _static_int_dim(dim):
    """Return *dim* as an int if it is a concrete integer, else None (symbolic or unknown)."""
    if isinstance(dim, (int, np.integer)) and not isinstance(dim, bool):
        return int(dim)
    return None


def _constant_int_values(tensor):
    """Return *tensor*'s values as a flat list of ints if it is constant, else None.

    Constant here means an initializer or the output of a ``Constant`` node.
    """
    if isinstance(tensor, gs.Constant):
        values = tensor.values
    elif (len(tensor.inputs) == 1 and tensor.inputs[0].op == "Constant"
          and isinstance(tensor.inputs[0].attrs.get("value"), gs.Constant)):
        values = tensor.inputs[0].attrs["value"].values
    else:
        return None
    return [int(v) for v in np.asarray(values).reshape(-1)]


def _emit_expand_as_tile(graph, data, shape, out, data_rank, target_len, target, tag) -> bool:
    """Append nodes computing ``Expand(data, shape)`` via ``Tile`` into the existing tensor *out*.

    Returns False, without emitting anything, when a constant *shape* is not
    broadcast-compatible with *data*'s static dims.
    """
    import itertools

    counter = itertools.count()
    out_rank = max(data_rank, target_len)
    lead = out_rank - data_rank    # leading 1-dims prepended to data
    pad = out_rank - target_len    # leading 1-dims prepended to the target shape

    def const(values):
        return gs.Constant(name=f"{tag}_const_{next(counter)}",
                           values=np.array(values, dtype=np.int64))

    def emit(op, *inputs, result=None, dtype=None, attrs=None):
        k = next(counter)
        if result is None:
            result = gs.Variable(name=f"{tag}_{op.lower()}_{k}", dtype=dtype)
        graph.nodes.append(gs.Node(op=op, name=f"{tag}_{op}_{k}", attrs=attrs,
                                   inputs=list(inputs), outputs=[result]))
        return result

    # Expand broadcasts: a data dim of 1 grows to the target dim, and any
    # other data dim is kept. So repeats[i] = target[i] where data has a 1,
    # else 1.
    dims = [_static_int_dim(d) for d in data.shape]
    if target is not None and all(d is not None and d > 0 for d in dims):
        padded = [1] * lead + dims
        wanted = [1] * pad + list(target)
        if any(d != 1 and w not in (1, d) for d, w in zip(padded, wanted)):
            return False
        repeats = [w if d == 1 else 1 for d, w in zip(padded, wanted)]
        source = emit("Reshape", data, const(padded), dtype=data.dtype) if lead else data
        emit("Tile", source, const(repeats), result=out)
        return True

    # Dynamic form (needs opset >= 9 for Where).
    # NOTE(review): Reshape treats a 0 in its shape input as "copy the input
    # dim", so zero-sized data dims combined with rank padding are not handled.
    data_dims = emit("Shape", data, dtype=np.int64)
    if lead:
        data_dims = emit("Concat", const([1] * lead), data_dims, dtype=np.int64, attrs={"axis": 0})
        source = emit("Reshape", data, data_dims, dtype=data.dtype)
    else:
        source = data
    if pad:
        wanted = emit("Concat", const([1] * pad), shape, dtype=np.int64, attrs={"axis": 0})
    else:
        wanted = shape
    one = const([1])
    is_one = emit("Equal", data_dims, one, dtype=np.bool_)
    repeats = emit("Where", is_one, wanted, one, dtype=np.int64)
    emit("Tile", source, repeats, result=out)
    return True


def replace_expand_with_tile(model_path: str, output_path: str,
                             node_name: str = None) -> Dict:
    """Replace ``Expand`` nodes with equivalent ``Tile``-based subgraphs.

    ``Tile`` takes per-axis repeat counts, not a target shape. So
    ``Expand(x, shape)`` becomes ``Tile(x', repeats)``, where:

    * ``x'`` is ``x`` reshaped to the output rank, with leading 1-dims added;
    * ``repeats[i]`` is the broadcast target dim where ``x'`` has a 1, and 1
      elsewhere.

    When ``shape`` is constant and ``x`` has fully static positive dims, the
    repeats are computed here and only Reshape/Tile are emitted. Otherwise
    they are computed in-graph with Shape/Concat/Equal/Where.

    Nodes are kept unchanged when the input rank or the shape length is
    unknown, or when a constant shape is not broadcast-compatible. The new
    subgraph writes into the Expand's original output tensor, so names,
    consumers and graph outputs stay the same.

    Args:
        model_path: Path of the ONNX model to read.
        output_path: Path the result is written to. The model is also written
            when no Expand node exists.
        node_name: If given, only the Expand node with this name is replaced.

    Returns:
        Dict with ``success``, ``operators_replaced``, ``nodes_count``,
        ``output_path`` and ``message`` (``error`` on failure).
    """
    model = onnx.load(model_path)
    # Best effort: inferred value_info gives gs the input ranks this
    # rewrite needs.
    try:
        model = shape_inference.infer_shapes(model)
    except Exception as e:
        print(f"[WARNING] Shape inference failed: {e}")
    graph = gs.import_onnx(model)

    expand_nodes = [n for n in graph.nodes if n.op == "Expand"]

    if not expand_nodes:
        print("[WARNING] 未找到Expand节点")
        onnx.save(model, output_path)
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_count": {},
            "output_path": output_path,
            "message": "未找到Expand节点"
        }

    if node_name:
        expand_nodes = [n for n in expand_nodes if n.name == node_name]
        if not expand_nodes:
            return {
                "success": False,
                "error": f"未找到指定节点: {node_name}",
                "operators_replaced": [],
                "nodes_count": {}
            }

    replaced_count = 0

    for index, expand_node in enumerate(expand_nodes):
        data, shape = expand_node.inputs[0], expand_node.inputs[1]
        out = expand_node.outputs[0]

        target = _constant_int_values(shape)
        if target is not None:
            target_len = len(target)
        elif shape.shape is not None and len(shape.shape) == 1:
            target_len = _static_int_dim(shape.shape[0])
        else:
            target_len = None
        data_rank = len(data.shape) if data.shape is not None else None

        if data_rank is None or target_len is None:
            print(f"[WARNING] 无法确定Expand节点 {expand_node.name} 的输入秩或目标形状长度,保留原节点")
            continue

        # Index suffix keeps generated names unique even for unnamed nodes.
        tag = f"{expand_node.name or 'Expand'}_{index}"
        if not _emit_expand_as_tile(graph, data, shape, out, data_rank, target_len, target, tag):
            print(f"[WARNING] Expand节点 {expand_node.name} 的目标形状与输入不可广播,保留原节点")
            continue

        # Detach the Expand node; cleanup() below removes it.
        expand_node.outputs.clear()
        replaced_count += 1

    graph.cleanup()
    graph.toposort()

    new_model = gs.export_onnx(graph)
    onnx.checker.check_model(new_model)
    onnx.save(new_model, output_path)
    print(f"[OK] 成功替换 {replaced_count} 个Expand节点为Tile节点")
    print(f"[OK] 已保存至: {output_path}")

    return {
        "success": True,
        "operators_replaced": ["Expand"] if replaced_count else [],
        "nodes_count": {"Expand": replaced_count},
        "output_path": output_path,
        "message": f"成功替换 {replaced_count} 个Expand节点"
    }


# Maps a config ``strategy`` name to the function implementing it. Each
# function takes (model_path, output_path) and returns a result dict.
RUNTIME_OP_REPLACEMENTS = dict(
    split_to_slices=replace_split_with_slices,
    mod_to_sub_mul_div=replace_mod_equivalent,
    expand_to_tile=replace_expand_with_tile,
)


def fix_runtime_operators(input_path: str, output_path: str,
                          device: str = "NPU",
                          operators: List[str] = None) -> Dict:
    """Replace operators that *device* does not support, one replacement pass per operator.

    The unsupported ops come from the replacement config. Each op that has a
    ``strategy`` is rewritten by the matching function in
    ``RUNTIME_OP_REPLACEMENTS``, and each pass feeds the next. Intermediate
    models go to a temporary directory that is removed afterwards. A pass
    that raises or reports failure is skipped: the pipeline continues from
    the last good model and the op is listed in ``failed_operators``.

    Args:
        input_path: Path of the ONNX model to fix.
        output_path: Path for the final model. On the replacement path it is
            always written, as a plain copy of the input if nothing was
            replaced.
        device: Device name used to look up ``"<device>_supported"`` flags.
        operators: Optional allow-list of op types to replace.

    Returns:
        Dict with ``success``, ``operators_replaced``, ``nodes_per_operator``,
        ``output_path``, ``device``, ``message`` and ``failed_operators``.
        The early exits (missing file, nothing unsupported, no strategy)
        return without writing ``output_path``.
    """
    import shutil
    import tempfile
    from collections import Counter

    config = load_replacement_config()

    if not Path(input_path).exists():
        return {
            "success": False,
            "error": f"File not found: {input_path}",
            "operators_replaced": [],
            "nodes_per_operator": {}
        }

    unsupported_in_model = get_unsupported_operators(input_path, device)

    if not unsupported_in_model:
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_per_operator": {},
            "output_path": output_path,
            "message": f"模型中没有在{device}上不支持的算子"
        }

    # Only op types are counted here, so skip loading external weight files.
    model = onnx.load(input_path, load_external_data=False)
    op_counts = Counter(node.op_type for node in model.graph.node)

    to_replace = {}
    for op in unsupported_in_model:
        strategy = config.get(op, {}).get("strategy")
        if strategy and op_counts[op] > 0:
            to_replace[op] = {"count": op_counts[op], "strategy": strategy}

    if operators:
        to_replace = {k: v for k, v in to_replace.items() if k in operators}

    if not to_replace:
        unsupported_ops = [op for op in unsupported_in_model if op not in config or not config[op].get("strategy")]
        return {
            "success": True,
            "operators_replaced": [],
            "nodes_per_operator": {},
            "output_path": output_path,
            "message": f"检测到不支持的算子但无替换策略: {unsupported_ops}",
            "unsupported_without_strategy": unsupported_ops
        }

    print(f"[INFO] 检测到需要替换的算子 ({device}不支持):")
    for op, info in to_replace.items():
        print(f"  - {op}: {info['count']} 个节点 (策略: {info['strategy']})")

    replaced_ops = []
    nodes_per_op = {}
    failed_ops = []
    stem = Path(input_path).stem

    with tempfile.TemporaryDirectory() as work_dir:
        current_path = input_path
        for op, info in to_replace.items():
            replace_func = RUNTIME_OP_REPLACEMENTS.get(info["strategy"])
            if replace_func is None:
                print(f"[WARNING] {op} 的替换策略 {info['strategy']} 未实现,跳过")
                failed_ops.append(op)
                continue

            target_output = str(Path(work_dir) / f"{stem}_{op}_replaced.onnx")
            try:
                result = replace_func(current_path, target_output)
            except Exception as e:
                # One failing pass should not discard the passes that succeeded.
                print(f"[WARNING] {op} 替换失败: {e}")
                failed_ops.append(op)
                continue

            if isinstance(result, dict) and result.get("success"):
                replaced_ops.append(op)
                counts = result.get("nodes_count") or {}
                nodes_per_op[op] = counts.get(op, info["count"])
                current_path = target_output
            else:
                error = result.get("error", "unknown error") if isinstance(result, dict) else "unknown error"
                print(f"[WARNING] {op} 替换失败: {error}")
                failed_ops.append(op)

        # Copy before the temporary directory is deleted. Skip when the
        # source already is the destination (shutil.copy raises SameFileError).
        if Path(current_path).resolve() != Path(output_path).resolve():
            shutil.copy(current_path, output_path)
            print(f"[INFO] 最终输出: {output_path}")

    return {
        "success": True,
        "operators_replaced": replaced_ops,
        "nodes_per_operator": nodes_per_op,
        "output_path": output_path,
        "device": device,
        "message": f"成功替换 {len(replaced_ops)} 种算子",
        "failed_operators": failed_ops
    }


if __name__ == "__main__":
    # CLI: python runtime_op_replace.py <input.onnx> <output.onnx> [device] [operators...]
    import sys

    if len(sys.argv) < 3:
        print("Usage: python runtime_op_replace.py <input.onnx> <output.onnx> [device] [operators]")
        print("\nExample:")
        print("  python runtime_op_replace.py model.onnx model_fixed.onnx NPU Split")
        print("  python runtime_op_replace.py model.onnx model_fixed.onnx CPU")
        sys.exit(1)

    input_path, output_path = sys.argv[1], sys.argv[2]
    device = sys.argv[3] if len(sys.argv) > 3 else "NPU"
    # Every argument after the device is an op type to restrict replacement to.
    operators = sys.argv[4:] if len(sys.argv) > 4 else None

    result = fix_runtime_operators(input_path, output_path, device, operators)
    print("\n" + "=" * 60)
    print("Result:")
    print(json.dumps(result, indent=2, ensure_ascii=False))
    # Exit non-zero on failure so shell scripts and CI can detect it.
    sys.exit(0 if result.get("success") else 1)