import time
import torch
if torch.__version__>= '1.8':
import torch_npu
import torch.utils.data
import torch.optim as optim
from torch import distributed as dist
from model_VGG import advancedEAST
from losses import quad_loss
from dataset import RawDataset, data_collate
from utils import Averager
import cfg
device = torch.device(cfg.device)
def train():
if cfg.distributed:
torch.npu.set_device(cfg.local_rank)
dist.init_process_group(backend='hccl', world_size=cfg.world_size, rank=cfg.local_rank)
""" dataset preparation """
train_dataset = RawDataset(is_val=False)
val_dataset = RawDataset(is_val=True)
if cfg.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
else:
train_sampler = None
val_sampler = None
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=cfg.batch_size,
collate_fn=data_collate,
shuffle=(train_sampler is None),
num_workers=int(cfg.workers),
pin_memory=True,
sampler=train_sampler)
valid_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=cfg.batch_size,
collate_fn=data_collate,
shuffle=False,
num_workers=int(cfg.workers),
pin_memory=True,
sampler=val_sampler)
model = advancedEAST()
if cfg.pth_path:
model.load_state_dict(torch.load(cfg.pth_path, map_location='cpu'))
if cfg.is_master_node:
print('Load {}'.format(cfg.pth_path))
elif int(cfg.train_task_id[-3:]) != 256:
id_num = cfg.train_task_id[-3:]
idx_dic = {'384': 256, '512': 384, '640': 512, '736': 640}
state_dict = {k.replace('module.', ''): v for k, v in torch.load(
'./saved_model/3T{}_latest.pth'.format(idx_dic[id_num]), map_location='cpu').items()}
model.load_state_dict(state_dict)
if cfg.is_master_node:
print('Load ./saved_model/3T{}_latest.pth'.format(idx_dic[id_num]))
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
if cfg.amp:
import apex
optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1", loss_scale="dynamic", combine_grad=True)
if cfg.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.local_rank], broadcast_buffers=False)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfg.epoch_num)
loss_func = quad_loss
'''start training'''
start_iter = 0
start_time = time.time()
i = start_iter
step_num = 0
loss_avg = Averager()
val_loss_avg = Averager()
total_train_img = int(cfg.total_img * (1 - cfg.validation_split_ratio))
while(True):
model.train()
if cfg.distributed:
train_sampler.set_epoch(i)
epoch_start_time = time.time()
for image_tensors, labels, gt_xy_list in train_loader:
step_num += 1
batch_x = image_tensors.float().to(device)
batch_y = labels.float().to(device)
out = model(batch_x)
loss = loss_func(batch_y, out)
optimizer.zero_grad()
if cfg.amp:
with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
loss_avg.add(loss)
loss = loss_avg.val()
if cfg.distributed:
dist.all_reduce(loss)
loss = loss / cfg.world_size
loss = loss.item()
if cfg.is_master_node:
print('Epoch:[{}/{}] Training loss:{:.3f} FPS:{:.3f} LR:{:.3e}'.format(i + 1, cfg.epoch_num,
loss, total_train_img / (time.time() - epoch_start_time), optimizer.param_groups[0]['lr']))
loss_avg.reset()
scheduler.step()
if (i + 1) % cfg.val_interval == 0:
elapsed_time = time.time() - start_time
if cfg.is_master_node:
print('Elapsed time:{}s'.format(round(elapsed_time)))
model.eval()
for image_tensors, labels, gt_xy_list in valid_loader:
batch_x = image_tensors.float().to(device)
batch_y = labels.float().to(device)
out = model(batch_x)
loss = loss_func(batch_y, out)
val_loss_avg.add(loss)
loss = val_loss_avg.val()
if cfg.distributed:
dist.all_reduce(loss)
loss = loss / cfg.world_size
loss = loss.item()
if cfg.is_master_node:
print('Validation loss:{:.3f}'.format(loss))
val_loss_avg.reset()
if i + 1 == cfg.epoch_num:
if cfg.is_master_node:
torch.save(model.state_dict(), './saved_model/{}_latest.pth'.format(cfg.train_task_id))
print('End the training')
break
i += 1
if __name__ == '__main__':
train()