"""
Convert an RT-DETR PyTorch checkpoint to Ascend OM format.

Pipeline: PyTorch checkpoint -> ONNX (via tools/export_onnx.py) -> OM (via ATC).
"""
import os
import sys
import argparse
import subprocess


def export_onnx(config, checkpoint, onnx_file, input_size):
    """Export the PyTorch checkpoint to ONNX using the repo's export script.

    Runs ``tools/export_onnx.py`` (resolved relative to the current working
    directory) in a child process with ``--simplify`` enabled.

    Args:
        config: Model config file, forwarded as ``-c``.
        checkpoint: Trained checkpoint file, forwarded as ``-r``.
        onnx_file: Destination ONNX path, forwarded as ``-o``.
        input_size: Input image size, forwarded as ``-s``.

    Raises:
        RuntimeError: If the export script exits with a non-zero status.
    """
    # Run the export with the same interpreter that is running this script,
    # so it sees the same installed packages (virtualenv/conda). A bare
    # 'python' may resolve to a different interpreter, or to nothing at all.
    interpreter = sys.executable or 'python'
    cmd = [
        interpreter, 'tools/export_onnx.py',
        '-c', config,
        '-r', checkpoint,
        '-o', onnx_file,
        '-s', str(input_size),
        '--simplify'
    ]
    print(f"Exporting to ONNX: {' '.join(cmd)}")
    result = subprocess.run(cmd)
    if result.returncode != 0:
        raise RuntimeError(f"ONNX export failed (exit code {result.returncode})")
    print(f"ONNX export successful: {onnx_file}\n")


def export_om(onnx_file, om_file, input_size, soc_version, batch_size):
    """Compile an ONNX model into an Ascend offline model (OM) with ATC.

    Args:
        onnx_file: Path to the source ONNX model.
        om_file: Output path without the ``.om`` extension (the final file is
            reported as ``{om_file}.om``).
        input_size: Height and width of the square ``images`` input.
        soc_version: Target Ascend SoC, e.g. ``Ascend910B3``.
        batch_size: Static batch dimension for both model inputs.

    Raises:
        RuntimeError: If ``atc`` cannot be found on PATH or exits with a
            non-zero status.
    """
    # Static shapes for both graph inputs. NOTE(review): the names are
    # presumably the input names emitted by tools/export_onnx.py — confirm
    # they stay in sync with that script.
    input_shape = f"images:{batch_size},3,{input_size},{input_size};orig_target_sizes:{batch_size},2"

    cmd = [
        'atc',
        '--model', onnx_file,
        '--framework', '5',  # ATC framework id 5 = ONNX
        '--output', om_file,
        '--soc_version', soc_version,
        '--input_shape', input_shape,
        # Prefer accuracy over speed when ATC picks operator implementations.
        '--op_select_implmode', 'high_precision',
        '--log', 'info'
    ]

    print(f"Converting to OM: {' '.join(cmd)}")
    try:
        result = subprocess.run(cmd)
    except FileNotFoundError as err:
        # 'atc' is typically only on PATH after the Ascend CANN toolkit
        # environment script has been sourced.
        raise RuntimeError(
            "OM conversion failed: 'atc' executable not found on PATH "
            "(is the Ascend CANN toolkit environment set up?)"
        ) from err
    if result.returncode != 0:
        raise RuntimeError(f"OM conversion failed (exit code {result.returncode})")
    print(f"OM conversion successful: {om_file}.om\n")


def main():
    """Command-line entry point: PyTorch checkpoint -> ONNX -> Ascend OM.

    Parses CLI arguments, runs the ONNX export, then the ATC conversion, and
    deletes the intermediate ONNX file unless ``--keep-onnx`` is given.
    Errors from either step propagate as ``RuntimeError``.
    """
    parser = argparse.ArgumentParser(description='Convert RT-DETR to Ascend OM format')
    parser.add_argument('-c', '--config', required=True, help='Model config file')
    parser.add_argument('-r', '--checkpoint', required=True, help='Checkpoint file (best.pth)')
    parser.add_argument('-o', '--output', default='rtdetr', help='Output OM file name (without extension)')
    parser.add_argument('-s', '--input-size', type=int, default=640, help='Input image size')
    parser.add_argument('-b', '--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument('--soc-version', default='Ascend910B3', help='SOC version (Ascend910B3, Ascend310P3, etc.)')
    parser.add_argument('--keep-onnx', action='store_true', help='Keep intermediate ONNX file')
    args = parser.parse_args()

    # type=int accepts 0 and negative values; reject them here instead of
    # failing deep inside the export or ATC after minutes of work.
    if args.input_size <= 0:
        parser.error(f"--input-size must be a positive integer, got {args.input_size}")
    if args.batch_size <= 0:
        parser.error(f"--batch-size must be a positive integer, got {args.batch_size}")

    # The output name is documented as extension-less and '.om' is appended
    # below; drop a user-supplied '.om' so we don't produce 'x.om.om'.
    base_name = args.output
    if base_name.endswith('.om'):
        base_name = base_name[:-len('.om')]

    onnx_file = f"{base_name}.onnx"
    om_file = base_name

    # Step 1: Export to ONNX
    export_onnx(args.config, args.checkpoint, onnx_file, args.input_size)

    # Step 2: Convert ONNX to OM
    export_om(onnx_file, om_file, args.input_size, args.soc_version, args.batch_size)

    # Clean up ONNX file if not needed. Only reached when both steps succeed,
    # so a failed conversion leaves the ONNX file on disk.
    if not args.keep_onnx and os.path.exists(onnx_file):
        os.remove(onnx_file)
        print(f"Removed intermediate file: {onnx_file}")

    print(f"\nConversion complete! Output: {om_file}.om")
    print(f"SOC version: {args.soc_version}")
    print(f"Input shape: images:[{args.batch_size},3,{args.input_size},{args.input_size}], orig_target_sizes:[{args.batch_size},2]")


if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()