# Copyright 2022 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.

import sys

sys.path.append(r"./SpanBERT/code")

import argparse

import collections

import json

import logging

import re

from io import open

import numpy as np

import torch

import os

from tqdm import tqdm

from torch.utils.data import DataLoader, TensorDataset

from run_squad import *

from pytorch_pretrained_bert.tokenization import (BertTokenizer,

                                                  whitespace_tokenize)

# Root logging setup: timestamped records at INFO level and above.
# (basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
_LOG_DATE_FORMAT = '%m/%d/%Y %H:%M:%S'

logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)







def main(args):
    """Convert a SQuAD evaluation file into per-batch binary model inputs.

    The dev file is read and tokenized into features, the features are
    grouped into batches of ``args.batch_size``, and for every batch three
    raw binary files are written (via ``numpy.ndarray.tofile``, int64 data):

    * ``./input_ids/<i>.bin``
    * ``./segment_ids/<i>.bin``
    * ``./input_mask/<i>.bin``

    where ``<i>`` is the zero-based batch index. The output directories are
    relative to the current working directory and are created if missing.

    Args:
        args: Parsed command-line namespace providing ``model``,
            ``do_lower_case``, ``dev_file``, ``version_2_with_negative``,
            ``max_seq_length``, ``doc_stride``, ``max_query_length`` and
            ``batch_size``.

    Note:
        ``read_squad_examples`` and ``convert_examples_to_features`` are not
        defined in this file; presumably they come from the star import of
        ``run_squad`` -- verify against that module.
    """
    tokenizer = BertTokenizer.from_pretrained(
        args.model, do_lower_case=args.do_lower_case)
    eval_examples = read_squad_examples(
        input_file=args.dev_file, is_training=False,
        version_2_with_negative=args.version_2_with_negative)
    eval_features = convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    # One row per feature; torch.long means the dumped files hold int64 values.
    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    # No shuffling is requested, so batch i holds features
    # [i * batch_size, (i + 1) * batch_size). The final batch may be shorter
    # because DataLoader does not drop the last incomplete batch by default.
    eval_dataloader = DataLoader(eval_data, batch_size=args.batch_size)

    input_ids_path = "./input_ids"
    input_mask_path = "./input_mask"
    segment_ids_path = "./segment_ids"
    for out_dir in (input_ids_path, input_mask_path, segment_ids_path):
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(out_dir, exist_ok=True)

    batches = tqdm(eval_dataloader, leave=False)
    for i, (input_ids, input_mask, segment_ids) in enumerate(batches):
        file_name = str(i) + '.bin'
        input_ids.numpy().tofile(os.path.join(input_ids_path, file_name))
        segment_ids.numpy().tofile(os.path.join(segment_ids_path, file_name))
        input_mask.numpy().tofile(os.path.join(input_mask_path, file_name))





def build_arg_parser():
    """Build the command-line parser for the preprocessing script.

    Returns:
        argparse.ArgumentParser: Parser exposing ``--model``, ``--dev_file``,
        ``--max_seq_length``, ``--doc_stride``, ``--max_query_length``,
        ``--do_lower_case``, ``--version_2_with_negative`` and
        ``--batch_size``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default='spanbert-base-cased', type=str)
    parser.add_argument("--dev_file", default='dev-v1.1.json', type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, "
                             "how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--version_2_with_negative', action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    # The previous help text was a copy of the --doc_stride description.
    parser.add_argument("--batch_size", default=1, type=int,
                        help="Number of features (rows) packed into each output .bin file.")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)