import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module
import argparse
import torch.distributed as dist
import os
from apex import amp
import apex
if torch.__version__ >= '1.8':
import torch_npu
def _str2bool(value):
    """Convert a command-line spelling of a boolean into ``bool``.

    ``type=bool`` is not usable for this: ``bool("False")`` is ``True``
    because any non-empty string is truthy, so ``--word False`` would
    silently enable word mode.

    Args:
        value: The raw argument; a ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/1, ``False`` for false/f/no/n/0 or an
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what bool('') used to produce.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface of the training script. The parsed namespace is kept
# at module level because main() reads (and extends) it.
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
parser.add_argument('--word', default=False, type=_str2bool, help='True for word, False for char')
parser.add_argument('--dist_backend', default='hccl', type=str, help='hccl for npu, must!')
parser.add_argument('--world_size', default=1, type=int, help='ddp world size')
parser.add_argument('--local_rank', default=0, type=int, help='local rank')
parser.add_argument('--num_epochs', default=20, type=int, help='number of train epoch')
parser.add_argument('--distributed', action="store_true", help='distributed')
parser.add_argument('--data_path', default='THUCNews', type=str, help='data path')
args = parser.parse_args()
def main():
    """Run the text-classification training pipeline on an NPU device.

    Reads the module-level ``args`` namespace, selects the NPU device (and
    initialises the process group when ``--distributed`` is set), builds the
    ``Config`` and ``Model`` exposed by ``models.<args.model>``, loads the
    dataset iterators, wraps the model with apex AMP (plus
    ``DistributedDataParallel`` when distributed) and hands everything to
    ``train``.

    Side effects: sets ``args.rank`` and ``args.device``, may overwrite
    ``args.world_size``, and seeds the NumPy and torch RNGs with 666.
    """
    dataset = args.data_path
    print("args.world_size = ", args.world_size)
    # These two variables are only echoed here, so look them up with .get():
    # a single-device run must not crash just because they are unset.
    print("args.addr = ", os.environ.get("MASTER_ADDR"))
    print("args.Port = ", os.environ.get("MASTER_PORT"))

    args.rank = 0
    if args.distributed:
        args.device = 'npu:%d' % args.local_rank
        torch.npu.set_device(args.device)
        dist.init_process_group(backend=args.dist_backend, world_size=args.world_size, rank=args.local_rank)
        # Prefer the launcher-provided WORLD_SIZE, but fall back to the CLI
        # value the group was just initialised with instead of raising
        # KeyError when the variable is missing.
        args.world_size = int(os.environ.get('WORLD_SIZE', args.world_size))
        args.rank = torch.distributed.get_rank()
    else:
        args.device = f'npu:{args.local_rank}'
        torch.npu.set_device(args.device)
    print('local_rank {}'.format(args.local_rank))

    # Pre-trained embedding file is the default; '--embedding random' opts out.
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'

    model_name = args.model
    if model_name == 'FastText':
        # FastText has its own data utilities and always uses random embeddings.
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    # Each model module is expected to expose `Config` and `Model`.
    x = import_module('models.' + model_name)
    config = x.Config(dataset, embedding)
    config.device = args.device
    config.world_size = args.world_size
    config.local_rank = args.local_rank
    # Batch size is scaled by the world size.
    # NOTE(review): presumably the iterator/trainer splits it per rank -- confirm.
    config.batch_size = config.batch_size * config.world_size
    config.distributed = args.distributed
    config.num_epochs = args.num_epochs

    # Fixed seeds so repeated runs start from the same random state.
    np.random.seed(666)
    torch.manual_seed(666)
    torch.npu.manual_seed_all(666)
    torch.backends.cudnn.deterministic = True

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # The vocabulary size is only known after the dataset has been built.
    config.n_vocab = len(vocab)
    model = x.Model(config).to(config.device)
    optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=config.learning_rate)
    # apex AMP at opt level O2 with dynamic loss scaling; this must run before
    # the model is wrapped in DistributedDataParallel below.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O2', loss_scale="dynamic", combine_grad=True, master_weights=True)
    if config.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], broadcast_buffers=False)
    # The Transformer model is left with its own initialisation.
    if model_name != 'Transformer':
        init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter, test_iter, optimizer)
# Script entry point: start the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main()