"""TinyBERT finetuning runner specifically for SST-2 dataset."""
from __future__ import absolute_import, division, print_function
import argparse
import random
import numpy as np
import torch
from transformer.modeling import TinyBertForSequenceClassification
def random_int_list(start, stop, length):
    """Return ``length`` random integers drawn from a closed integer range.

    Args:
        start: One bound of the range; converted with ``int()``.
        stop: The other bound; converted with ``int()``. The two bounds may
            be supplied in either order.
        length: Number of values to generate. Its absolute value is used,
            and any falsy value (``0``, ``None``) yields an empty list.

    Returns:
        A list of ``int`` values, each between the two bounds inclusive.
    """
    # Convert to int *before* ordering: comparing the raw arguments orders
    # numeric strings lexicographically ("5" > "10") and would hand
    # random.randint an inverted, empty range.
    low, high = sorted((int(start), int(stop)))
    count = int(abs(length)) if length else 0
    return [random.randint(low, high) for _ in range(count)]
def is_same(a, b, i):
    """Print whether the outputs of step ``i - 1`` and step ``i`` are identical.

    Args:
        a: Output of the previous step; must support elementwise ``==``
            against ``b`` with a result exposing ``.mean()`` (the caller in
            this file passes NumPy arrays).
        b: Output of the current step.
        i: Index of the current step, used only in the printed message.
    """
    # The mean of the elementwise equality mask is exactly 1 only when every
    # element matches.
    all_equal = (a == b).mean() == 1
    verdict = 'True' if all_equal else 'False'
    print("step {} = step {}: {}".format(i - 1, i, verdict))
def main():
    """Run one random input through the model repeatedly and compare outputs.

    Parses ``--model`` (directory handed to ``from_pretrained``) and
    ``--max_seq_length`` (length of the generated input), builds a single
    batch of random token ids, segment ids and attention-mask values, then
    evaluates the model 20 times on that same batch. After every pass from
    the second onward it prints whether the logits equal those of the
    previous pass.
    """
    parser = argparse.ArgumentParser()
    # Required: without a model directory there is nothing to load, so fail
    # at argument parsing instead of inside from_pretrained(None).
    parser.add_argument("--model",
                        required=True,
                        type=str,
                        help="The model dir.")
    parser.add_argument("--max_seq_length",
                        default=64,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    args = parser.parse_args()

    model = TinyBertForSequenceClassification.from_pretrained(args.model, num_labels=2)
    model.eval()

    # One fixed batch (shape 1 x max_seq_length) reused for every pass, so any
    # difference between passes comes from the model, not from the input.
    input_ids = torch.tensor(random_int_list(0, 9999, args.max_seq_length), dtype=torch.long).view(1, -1)
    print(input_ids)
    segment_ids = torch.tensor(random_int_list(0, 1, args.max_seq_length), dtype=torch.long).view(1, -1)
    input_mask = torch.tensor(random_int_list(0, 1, args.max_seq_length), dtype=torch.long).view(1, -1)

    repeat_time = 20
    previous = None
    # Inference only: skip autograd bookkeeping so no graph is built per pass.
    with torch.no_grad():
        for step in range(1, repeat_time + 1):
            logits, _, _ = model(input_ids, segment_ids, input_mask)
            logits = logits.squeeze()
            print("step {}, logits = {}".format(step, logits))
            current = logits.numpy()
            # The first pass has nothing to compare against.
            if previous is not None:
                is_same(previous, current, step)
            previous = current
# Run the repeatability check only when executed as a script, not on import.
if __name__ == "__main__":
    main()