import os
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data as data
from torch.optim import lr_scheduler
from util.shedule import FixLR
from dataset.total_text import TotalText
from dataset.synth_text import SynthText
from network.loss import TextLoss
from network.textnet import TextNet
from util.augmentation import BaseTransform, Augmentation
from util.config import config as cfg, update_config, print_config
from util.misc import AverageMeter
from util.misc import mkdirs, to_device,to_device_parrall
from util.option import BaseOptions
from util.visualize import visualize_network_output
from util.summary import LogSummary
import torch.multiprocessing as mp
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
import apex
import random
# Module-level state accessed through ``global`` statements in this file.
# ``lr`` is assigned from ``cfg.lr`` inside ``train_pre``.
lr = None
# Number of training batches processed so far; ``train`` increments it once per batch.
train_step = 0
def save_model(model, epoch, lr, optimzer, cfg):
    """Write a training checkpoint for ``epoch`` to disk.

    The checkpoint is stored at
    ``<cfg.save_dir>/<cfg.exp_name>/textsnake_<epoch>.pth`` and contains the
    keys ``'lr'``, ``'epoch'``, ``'model'`` and ``'optimizer'``.

    Args:
        model: Network whose weights are saved. It may be wrapped in
            ``nn.DataParallel`` / ``DistributedDataParallel``; the wrapper is
            removed so the stored keys carry no ``module.`` prefix.
        epoch: Epoch number, stored in the checkpoint and used in the file name.
        lr: Learning-rate value stored verbatim under ``'lr'``.
        optimzer: Optimizer whose ``state_dict()`` is stored (the parameter
            keeps its original spelling so keyword callers keep working).
        cfg: Configuration object providing ``save_dir``, ``exp_name`` and
            ``mgpu``.
    """
    save_dir = os.path.join(cfg.save_dir, cfg.exp_name)
    if not os.path.exists(save_dir):
        mkdirs(save_dir)

    save_path = os.path.join(save_dir, 'textsnake_{}.pth'.format(epoch))
    print('Saving to {}.'.format(save_path))

    # Decide whether to unwrap from the wrapper that is actually present
    # instead of trusting ``cfg.mgpu`` alone: the flag can disagree with how
    # the model was wrapped, which previously produced either prefixed keys
    # or an AttributeError on ``model.module``.
    is_wrapped = isinstance(
        model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
    if is_wrapped or (cfg.mgpu and hasattr(model, 'module')):
        model = model.module

    state_dict = {
        'lr': lr,
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimzer.state_dict(),
    }
    torch.save(state_dict, save_path)
def load_model(model, model_path):
    """Load the weights stored in a checkpoint file into ``model`` in place.

    Args:
        model: Module that receives the weights via ``load_state_dict``.
        model_path: Path of a checkpoint whose ``'model'`` entry is a state
            dict; keys may carry a leading ``module.`` prefix left by a
            (Distributed)DataParallel wrapper.
    """
    print('Loading from {}'.format(model_path))
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    # The map_location callback keeps every tensor on CPU storage.
    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)

    # Strip only a *leading* 'module.'; a blanket str.replace would also
    # corrupt keys that merely contain the text (e.g. 'submodule.weight').
    prefix = 'module.'
    weights = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict['model'].items()
    }
    model.load_state_dict(weights)
def train(model, train_loader, criterion, scheduler, optimizer, epoch, gpu, cfg):
    """Run one training epoch and checkpoint the model when due.

    For every batch the network output and all target maps are moved to the
    CPU, the loss is evaluated there with a ``TextLoss`` instance, and the
    summed loss is moved back to ``npu:<gpu>`` for the AMP-scaled backward
    pass. After the epoch, rank 0 saves a checkpoint if
    ``epoch % cfg.save_freq == 0``.

    Args:
        model: Network to train.
        train_loader: Iterable yielding
            ``(img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
            cos_map, meta)`` batches.
        criterion: Kept for interface compatibility; it is not used because
            the loss is evaluated by a CPU-side ``TextLoss`` built here.
        scheduler: LR scheduler; stepped once at the start of the epoch.
        optimizer: Optimizer registered with ``amp``.
        epoch: Current epoch index.
        gpu: Device index of this process; also used as the rank for the
            rank-0-only visualisation and checkpointing.
        cfg: Configuration providing ``batch_size``, ``world_size``, ``viz``,
            ``viz_freq``, ``display_freq`` and ``save_freq``.
    """
    global train_step

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    loc = 'npu:{}'.format(gpu)
    # Build the CPU-side loss once per epoch rather than once per batch.
    # NOTE(review): assumes TextLoss keeps no per-instance state -- confirm.
    cpu_criterion = TextLoss().cpu()

    model.train()
    end = time.time()

    scheduler.step()
    print('Epoch: {} : LR = {}'.format(epoch, scheduler.get_lr()))

    for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta) in enumerate(train_loader):
        data_time.update(time.time() - end)
        train_step += 1

        img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device_parrall(
            gpu, img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

        output = model(img)

        # The loss is computed on the CPU, so bring the prediction and every
        # target map over before calling the criterion.
        output = output.cpu()
        tr_mask = tr_mask.cpu()
        tcl_mask = tcl_mask.cpu()
        sin_map = sin_map.cpu()
        cos_map = cos_map.cpu()
        radius_map = radius_map.cpu()
        train_mask = train_mask.cpu()

        tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
            cpu_criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
        loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss
        # Back to the NPU so amp can scale the loss and run backward there.
        loss = loss.to(loc)

        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        losses.update(loss.item())
        batch_time.update(time.time() - end)
        end = time.time()

        if cfg.viz and i % cfg.viz_freq == 0 and gpu == 0:
            visualize_network_output(output, tr_mask, tcl_mask, mode='train', cfg=cfg)

        if i % cfg.display_freq == 0:
            print('({:d} / {:d}) - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'.format(
                i, len(train_loader), loss.item(), tr_loss.item(), tcl_loss.item(), sin_loss.item(), cos_loss.item(), radii_loss.item())
            )

    # Only rank 0 writes the checkpoint.
    if epoch % cfg.save_freq == 0 and gpu == 0:
        save_model(model, epoch, scheduler.get_lr(), optimizer, cfg)

    # Samples per second across all processes; avoid dividing by zero when
    # no batch was timed (e.g. an empty loader).
    avg_batch_time = batch_time.avg
    if avg_batch_time > 0:
        FPS = cfg.batch_size * cfg.world_size / avg_batch_time
    else:
        FPS = 0.0
    print('Training Loss: {}'.format(losses.avg))
    print('FPS: {:.3f}'.format(FPS))
def validation(model, valid_loader, criterion, epoch, gpu, cfg):
    """Evaluate ``model`` on ``valid_loader`` without gradients and print losses.

    Args:
        model: Network to evaluate; switched to eval mode.
        valid_loader: Iterable yielding
            ``(img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
            cos_map, meta)`` batches.
        criterion: Callable returning the five loss terms
            ``(tr, tcl, sin, cos, radii)``.
        epoch: Current epoch index (not used in the body).
        gpu: Device index handed to ``to_device_parrall``.
        cfg: Configuration providing ``viz``, ``viz_freq`` and
            ``display_freq``.
    """
    model.eval()

    total_meter = AverageMeter()
    # One meter per loss term, in the order tr, tcl, sin, cos, radii.
    part_meters = [AverageMeter() for _ in range(5)]

    with torch.no_grad():
        for step, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, _meta) in enumerate(valid_loader):
            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device_parrall(
                gpu, img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map)

            prediction = model(img)

            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = criterion(
                prediction, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
            total = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss

            # Pull the scalars out once; they feed both the meters and the log line.
            total_value = total.item()
            part_values = [
                term.item()
                for term in (tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss)
            ]

            total_meter.update(total_value)
            for meter, value in zip(part_meters, part_values):
                meter.update(value)

            if cfg.viz and step % cfg.viz_freq == 0:
                visualize_network_output(prediction, tr_mask, tcl_mask, mode='val', cfg=cfg)

            if step % cfg.display_freq == 0:
                print(
                    'Validation: - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'.format(
                        total_value, *part_values)
                )

    print('Validation Loss: {}'.format(total_meter.avg))
def device_id_to_process_device_map(device_list):
    """Map process indices to device ids taken from a comma-separated string.

    The ids are parsed as integers and sorted ascending; process ``0`` gets
    the smallest id, process ``1`` the next one, and so on.

    Args:
        device_list: Comma-separated device ids, e.g. ``"3,1,2"``.

    Returns:
        dict mapping process index to device id.
    """
    ordered_ids = sorted(int(token) for token in device_list.split(","))
    return dict(enumerate(ordered_ids))
def main():
    """Configure the rendezvous environment and launch training.

    Sets ``MASTER_ADDR`` / ``MASTER_PORT``, derives ``cfg.world_size`` and
    ``cfg.distributed``, then either spawns one ``train_pre`` process per
    device or runs ``train_pre`` directly in this process.
    """
    os.environ.update({
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': str(random.randrange(1001, 49999)),
    })

    world_size_from_env = cfg.dist_url == "env://" and cfg.world_size == -1
    if world_size_from_env:
        cfg.world_size = int(os.environ["WORLD_SIZE"])

    cfg.distributed = cfg.world_size > 1 or cfg.multiprocessing_distributed

    # Single-process path: run on cfg.gpu with a world size of one.
    if not cfg.multiprocessing_distributed:
        cfg.world_size = 1
        train_pre(cfg.gpu, 1, cfg)
        return

    # Multi-process path: one worker per device; mp.spawn passes the process
    # index as the first argument of train_pre.
    cfg.world_size = cfg.gpus * cfg.nodes
    print(cfg.world_size)
    mp.spawn(train_pre, nprocs=cfg.gpus, args=(cfg.gpus, cfg))
def _build_datasets(cfg):
    """Return ``(trainset, valset)`` for ``cfg.dataset``; ``valset`` may be None."""
    if cfg.dataset == 'total-text':
        trainset = TotalText(
            data_root='data/total-text',
            ignore_list=None,
            is_training=True,
            transform=Augmentation(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
        )
        valset = TotalText(
            data_root='data/total-text',
            ignore_list=None,
            is_training=False,
            transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
        )
        return trainset, valset

    if cfg.dataset == 'synth-text':
        trainset = SynthText(
            data_root='data/SynthText',
            is_training=True,
            transform=Augmentation(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
        )
        return trainset, None

    # Fail early with a clear message instead of a NameError on ``trainset``
    # further down.
    raise ValueError(
        "Unsupported dataset {!r}; expected 'total-text' or 'synth-text'".format(cfg.dataset))


def _build_loader(dataset, rank, cfg):
    """Return ``(loader, sampler)`` for ``dataset``; ``sampler`` is None when not distributed."""
    sampler = None
    if cfg.distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, num_replicas=cfg.world_size, rank=rank)
    loader = data.DataLoader(dataset=dataset,
                             pin_memory=False,
                             # DataLoader rejects shuffle=True together with a sampler.
                             shuffle=(sampler is None),
                             batch_size=cfg.batch_size,
                             num_workers=cfg.num_workers,
                             sampler=sampler)
    return loader, sampler


def train_pre(gpu, ngpus_per_node, cfg):
    """Per-process entry point: set up device, data, model and run all epochs.

    Args:
        gpu: Device index for this process; also used as the distributed rank.
        ngpus_per_node: Number of devices per node (not used in the body).
        cfg: Configuration object (dataset, batch size, learning rate,
            distributed flags, epoch range, ...).

    Raises:
        ValueError: If ``cfg.dataset`` is neither ``'total-text'`` nor
            ``'synth-text'``.
    """
    global lr

    rank = gpu
    loc = 'npu:{}'.format(rank)
    torch.npu.set_device(loc)

    if cfg.distributed:
        dist.init_process_group(backend='hccl', init_method="env://",
                                world_size=cfg.world_size, rank=rank)

    trainset, valset = _build_datasets(cfg)

    train_loader, train_sampler = _build_loader(trainset, rank, cfg)
    val_loader = None
    if valset:
        val_loader, _ = _build_loader(valset, rank, cfg)

    model = TextNet(is_training=True, backbone=cfg.net)
    model.to(loc)
    if cfg.resume:
        load_model(model, cfg.resume)

    lr = cfg.lr
    optimizer = apex.optimizers.NpuFusedAdam(model.parameters(), lr=cfg.lr)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=None, combine_grad=True)
    if cfg.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], broadcast_buffers=False)

    criterion = TextLoss().to(loc)

    if cfg.dataset == 'synth-text':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        if cfg.distributed:
            # Gives the DistributedSampler a different shuffle each epoch.
            train_sampler.set_epoch(epoch)
        train(model, train_loader, criterion, scheduler, optimizer, epoch, rank, cfg)
        if val_loader is not None:
            validation(model, val_loader, criterion, epoch, rank, cfg)
    print('End.')
class ProgressMeter(object):
    """Print a tab-separated progress line made of a batch counter and meters."""

    def __init__(self, num_batches, meters, prefix=""):
        """Store the meters and pre-build the ``[ cur/total]`` format string."""
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and ``str()`` of every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return e.g. ``'[{:3d}/100]'`` so the counter is padded to the total's width."""
        width = len(str(num_batches // 1))
        counter = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(counter, counter.format(num_batches))
if __name__ == "__main__":
    # Script entry point: obtain the run options, merge them into the shared
    # ``cfg`` object, print the resulting configuration and start training.
    option = BaseOptions()
    # NOTE(review): initialize() presumably parses the command line -- confirm in util.option.
    args = option.initialize()
    update_config(cfg, args)
    print_config(cfg)
    main()