from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
import torch
from flagai.data.tokenizer import Tokenizer
from flagai.model.bert_model import BertModel, BertForSeq2seq, \
BertForSequenceLabeling, \
BertForSequenceLabelingGP, \
BertForClsClassifier, \
BertForSequenceLabelingCRF
from flagai.data.tokenizer.bert.bert_tokenizer import BertTokenizer
import unittest
class BertTestCase(unittest.TestCase):
    """Smoke test for the BERT model variants.

    Every model class is built from the JSON config at ``self.bert_path``
    and driven through ``Predictor``. The outputs are printed rather than
    asserted on, so the test only fails when construction or prediction
    raises.
    """

    def setUp(self) -> None:
        """Set up the model classes, the config path and the tokenizer."""
        self.models = [
            BertForClsClassifier,
            BertForSeq2seq,
            BertForSequenceLabeling,
            BertForSequenceLabelingGP,
            BertForSequenceLabelingCRF,
        ]
        self.model_name = "RoBERTa-base-ch"
        self.bert_path = "./checkpoints/RoBERTa-base-ch/config.json"
        self.tokenizer = Tokenizer.from_pretrained(self.model_name)
        print("loading bert model successfully!")

    def test_model_predict(self):
        """Run the matching ``Predictor`` method for each model class.

        Each class runs inside its own ``subTest`` so that an error in one
        variant is reported without hiding the remaining ones.
        """
        # The device does not depend on the model, so resolve it once.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        text = "今天吃饭吃了肯德基"

        for model_cls in self.models:
            with self.subTest(model=model_cls.__name__):
                model = model_cls.init_from_json(
                    self.bert_path, class_num=3, inner_dim=64)
                # Dispatch on the lower-cased class name of the built model.
                class_name = type(model).__name__.lower()
                model.to(device)
                predictor = Predictor(model, self.tokenizer)

                if "seq2seq" in class_name:
                    # Exercise both decoding strategies and keep both
                    # results (beam search first, random sampling second)
                    # so neither is dropped from the printed report.
                    output = (
                        predictor.predict_generate_beamsearch(
                            text, out_max_length=20),
                        predictor.predict_generate_randomsample(
                            text, out_max_length=20),
                    )
                elif "cls" in class_name:
                    output = predictor.predict_cls_classifier(text)
                elif "sequencelabeling" in class_name:
                    output = predictor.predict_ner(
                        text, target=["0", "1", "2"])
                else:
                    # No known prediction entry point for this class name.
                    output = None

                print(f"model_name is {class_name}, output is {output}")
def suite():
    """Return a ``TestSuite`` holding only ``test_model_predict``."""
    return unittest.TestSuite([BertTestCase('test_model_predict')])
if __name__ == '__main__':
    # Run the hand-built suite with the plain-text console runner.
    unittest.TextTestRunner().run(suite())