import argparse
import sys
sys.path.append('./fairseq/')
import fairseq
import torch
def run_pth2onnx(args):
    """Load a fairseq checkpoint and export its first model to ONNX.

    The model is traced on CPU with an all-zeros float32 dummy input of
    shape ``(1, seq_len)`` and written to ``args.onnx_path``. No
    ``dynamic_axes`` are passed to ``torch.onnx.export``, so the exported
    graph's input is fixed to the traced shape.

    Args:
        args: Namespace providing ``model_path`` (checkpoint to load) and
            ``onnx_path`` (output file). The following attributes are
            optional and read with defaults, so older callers keep working:
            ``seq_len`` (dummy input length, default 580000),
            ``opset_version`` (ONNX opset, default 11) and
            ``verbose`` (forwarded to ``torch.onnx.export``, default True).

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    seq_len = getattr(args, "seq_len", 580000)
    opset_version = getattr(args, "opset_version", 11)
    verbose = getattr(args, "verbose", True)
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")

    # Only the model list is needed; the loaded config and task are unused.
    models, _cfg, _task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [args.model_path]
    )
    # A single checkpoint path is passed, so take the first ensemble member.
    model = models[0].to("cpu")
    # Switch off training-only behavior (e.g. dropout) before tracing.
    model.eval()

    # Tracing-only dummy input; its shape becomes the graph's fixed input shape.
    source = torch.zeros([1, seq_len], dtype=torch.float32, device="cpu")
    torch.onnx.export(
        model,
        source,
        args.onnx_path,
        input_names=["source"],
        output_names=["result"],
        opset_version=opset_version,
        verbose=verbose,
    )
def build_parser():
    """Return the command-line parser for the checkpoint-to-ONNX export."""
    parser = argparse.ArgumentParser(
        description="Convert a fairseq checkpoint (.pt) to an ONNX model."
    )
    parser.add_argument(
        '--model_path', type=str,
        default='./data/pt/hubert_large_ll60k_finetune_ls960.pt',
        help='path of the fairseq checkpoint to load',
    )
    parser.add_argument(
        '--onnx_path', type=str, default='./hubert.onnx',
        help='where to write the exported ONNX model',
    )
    parser.add_argument(
        '--seq_len', type=int, default=580000,
        help='length of the all-zeros dummy input used for tracing; '
             'the exported graph is fixed to this input length',
    )
    parser.add_argument(
        '--opset_version', type=int, default=11,
        help='ONNX opset version passed to torch.onnx.export',
    )
    # Stored as `verbose` (default True) to keep the original verbose export.
    parser.add_argument(
        '--quiet', dest='verbose', action='store_false',
        help='do not print the exported graph',
    )
    return parser


def main():
    """Parse command-line arguments and run the ONNX export."""
    args = build_parser().parse_args()
    run_pth2onnx(args)
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()