from flagai.trainer import Trainer
from flagai.model.glm_model import GLMForSeq2Seq
from flagai.data.tokenizer import Tokenizer
from flagai.data.dataset import Seq2SeqDataset
from flagai.data.dataset.superglue.control import DEFAULT_METRICS, CH_TASKS
from flagai.data.dataset import ConstructSeq2seqStrategy
from flagai.test_utils import Seq2SeqCollateArguments
import unittest
class TrainerTestCase(unittest.TestCase):
    """Smoke test that runs a GLM seq2seq fine-tuning loop on a tiny data slice."""

    def test_init_trainer_pytorch(self):
        """Build a PyTorch ``Trainer`` and call ``train`` on 10 train / 10 dev examples.

        The checkpoint name is 'GLM-large-ch' when the task is listed in
        ``CH_TASKS`` and 'GLM-large-en' otherwise. Data is read from
        './data/'. The model is created with ``only_download_config=True``
        (presumably config-only, no pretrained weights -- TODO confirm).
        """
        task_name = "cmrc"
        collate_args = Seq2SeqCollateArguments()

        # NOTE(review): fp16 on a 'cuda' device is hard-coded, so this test
        # presumably cannot run on a CPU-only machine -- confirm.
        trainer = Trainer(
            env_type='pytorch',
            epochs=1,
            batch_size=1,
            eval_interval=100,
            log_interval=50,
            experiment_name='glm_large',
            fp16=True,
            pytorch_device='cuda',
            lr=1e-4,
        )
        print("downloading...")

        model_name = 'GLM-large-ch' if task_name in CH_TASKS else 'GLM-large-en'
        tokenizer = Tokenizer.from_pretrained(model_name)

        def load_split(split):
            # One Seq2SeqDataset per split; only dataset_type differs.
            return Seq2SeqDataset(task_name=task_name,
                                  data_dir='./data/',
                                  dataset_type=split,
                                  tokenizer=tokenizer)

        train_dataset = load_split('train')
        valid_dataset = load_split('dev')

        collate_fn = ConstructSeq2seqStrategy(collate_args,
                                              tokenizer,
                                              task_name=task_name)

        # Keep the run short: only the first 10 examples of each split.
        for dataset in (train_dataset, valid_dataset):
            dataset.example_list = dataset.example_list[:10]

        model = GLMForSeq2Seq.from_pretrain(model_name=model_name,
                                            only_download_config=True)
        trainer.train(model,
                      collate_fn=collate_fn,
                      train_dataset=train_dataset,
                      valid_dataset=valid_dataset,
                      metric_methods=DEFAULT_METRICS[task_name])
def suite():
    """Return a ``unittest.TestSuite`` holding the single trainer test case."""
    test_suite = unittest.TestSuite(
        [TrainerTestCase('test_init_trainer_pytorch')])
    return test_suite
if __name__ == '__main__':
    # Run the suite with the default text runner when executed as a script.
    unittest.TextTestRunner().run(suite())