from __future__ import print_function, division
from apex import amp
from apex.optimizers import NpuFusedSGD
import torch
if torch.__version__ >= "1.8":
import torch_npu
else:
import torch.npu
import argparse
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision
from torchvision import transforms, datasets, models
import os
import cv2
import time
import torch.utils.data.distributed
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel
from model.residual_attention_network import ResidualAttentionModel_92_32input_update as ResidualAttentionModel
device = None
args = None
log_file = None
def parse_args():
parser = argparse.ArgumentParser(description='3D Attention Net')
parser.add_argument('--device_type', type=str)
parser.add_argument('--device_id', type=int)
parser.add_argument('--device_num', type=int)
parser.add_argument('--total_epochs', type=int)
parser.add_argument('--is_train', type=str)
parser.add_argument('--is_pretrain', type=str)
parser.add_argument('--num_classes', type=int)
parser.add_argument('--dist_url', type=str)
parser.add_argument('--train_batch_size', type=int)
parser.add_argument('--test_batch_size', type=int)
args = parser.parse_args()
return args
def test(model, test_loader):
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
model.eval()
correct = 0
total = 0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for images, labels in test_loader:
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels.data).sum()
c = (predicted == labels.data).squeeze()
for i in range(len(labels.data)):
label = labels.data[i]
class_correct[label] += c[i]
class_total[label] += 1
if args.device_id == 0:
print(f"Accuracy of the model(on device: {args.device_id}) on the test images: {100 * float(correct) / total} %")
write_log('Accuracy of the model on the test images: %d %%\n' % (100 * float(correct) / total))
write_log(f'Accuracy of the model on the test images: {float(correct)/total} \n')
return float(correct) / total
def write_log(output):
if log_file is not None:
log_file.write(output)
def main():
global args
global device
args = parse_args()
print(args)
model_file = 'model_92_sgd.pkl'
train_batch_size = args.train_batch_size
test_batch_size = args.test_batch_size
lr = 0.1
is_train = args.is_train == "True"
is_pretrain = args.is_pretrain == "True"
acc_best = 0
total_epoch = args.total_epochs
distribute = args.device_num > 1
if(args.device_type == "GPU"):
device = torch.device("cuda", args.device_id)
if distribute:
torch.cuda.set_device(args.device_id)
torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url, world_size=args.device_num, rank=args.device_id)
else:
device = f"npu:{args.device_id}"
if distribute:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '49876'
torch.npu.set_device(device)
print("rank:",args.device_id)
torch.distributed.init_process_group(backend="hccl", world_size=args.device_num, rank=args.device_id)
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop((32, 32), padding=4),
transforms.ToTensor()
])
test_transform = transforms.Compose([
transforms.ToTensor()
])
train_dataset = datasets.CIFAR10(root='./data/', train=True, transform=transform, download=False)
test_dataset = datasets.CIFAR10(root='./data/', train=False, transform=test_transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distribute else None
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, \
batch_size=train_batch_size, \
shuffle=(train_sampler is None), \
num_workers=4, \
pin_memory=True, \
sampler = train_sampler if is_train else None, \
drop_last = True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=test_batch_size, shuffle=False)
model = ResidualAttentionModel(args.num_classes).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = None
if args.device_type == "GPU":
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=0.0001)
else:
optimizer = NpuFusedSGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=0.0001)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=128)
if distribute:
if args.device_type == "GPU":
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device_id])
else:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device_id], broadcast_buffers=False)
global log_file
if args.device_id == 0:
log_file = open("train_log_file" if is_train else "test_log_file", mode="w+")
if is_train is True:
if is_pretrain == True:
base_weights = torch.load(model_file, map_location="cpu")
print('Loading base network...')
new_state_dict = OrderedDict()
for k, v in base_weights.items():
if(k[0: 7] == "module."):
name = k[7:]
else:
name = k[0:]
new_state_dict[name] = v
if "fc.weight" in new_state_dict:
print("pop fc layer weight")
new_state_dict.pop("fc.weight")
new_state_dict.pop("fc.bias")
model.load_state_dict(new_state_dict, strict=False)
total_tims = 0
total_samples = 0
for epoch in range(total_epoch):
model.train()
tims = time.time()
epoch_samples = 0
if train_sampler is not None:
train_sampler.set_epoch(epoch)
for i, (images, labels) in enumerate(train_loader):
epoch_samples += images.shape[0]
if i == 5:
tims = time.time()
images = images.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
if (i+1) % 20 == 0 and args.device_id == 0:
print("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f" %(epoch+1, total_epoch, i+1, len(train_loader), loss.item()))
write_log("Epoch [%d/%d], Iter [%d/%d] Loss: %.4f \n" %(epoch+1, total_epoch, i+1, len(train_loader), loss.item()))
total_tims += time.time() - tims
total_samples += epoch_samples
if args.device_id == 0:
print(f'the epoch {epoch+1} takes time:',time.time()-tims)
print(f"epoch {epoch+1} FPS: {(epoch_samples - 5 * train_batch_size)* args.device_num / (time.time()-tims)}")
print('evaluate test set:')
write_log(f'the epoch {epoch+1} takes time: {time.time()-tims} \n')
write_log(f"epoch {epoch+1} FPS: {(epoch_samples - 5 * train_batch_size)* args.device_num / (time.time()-tims)} \n")
acc = test(model, test_loader)
if acc > acc_best:
acc_best = acc
print('current best acc,', acc_best)
if args.device_id == 0:
torch.save(model.state_dict(), model_file)
if (epoch+1) / float(total_epoch) == 0.3 or (epoch+1) / float(total_epoch) == 0.6 or (epoch+1) / float(total_epoch) == 0.9:
lr /= 10
print('reset learning rate to:', lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
print(param_group['lr'])
if args.device_id == 0:
torch.save(model.state_dict(), 'last_model_92_sgd.pkl')
elif args.device_id == 0:
base_weights = torch.load(model_file, map_location="cpu")
print('Loading base network...')
new_state_dict = OrderedDict()
for k, v in base_weights.items():
if(k[0: 7] == "module."):
name = k[7:]
else:
name = k[0:]
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
test(model, test_loader)
if __name__ == "__main__":
main()