import argparse
import os
import torch
import torch.nn.functional as F
import lib.utils as utils
import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
from lib.visual3D_temp import non_overlap_padding,test_padding
from lib.losses3D import DiceLoss
def main():
    """Run patch-wise inference of a pretrained model on one full volume.

    Steps: parse CLI options, seed the run via ``utils.reproducibility``,
    build the datasets (only ``full_volume`` and ``affine`` are used here),
    create the model, restore the checkpoint given by ``--pretrained``,
    optionally move model and volume to the GPU, and hand everything to
    ``non_overlap_padding`` with a 32x32x32 kernel.

    Returns:
        Whatever ``non_overlap_padding`` returns (previously discarded).
    """
    args = get_arguments()
    seed = 1777777
    utils.reproducibility(args, seed)

    # generate_datasets returns four values; the train/val loaders are not
    # needed for this inference run, hence the underscore names.
    _training_generator, _val_generator, full_volume, affine = (
        medical_loaders.generate_datasets(args, path='./datasets'))
    # create_model also returns an optimizer, which inference does not use.
    model, _optimizer = medzoo.create_model(args)
    criterion = DiceLoss(classes=args.classes)
    print(affine.shape)

    # Restore weights before moving to the device so the checkpoint tensors
    # follow the model through .cuda().
    model.restore_checkpoint(args.pretrained)
    if args.cuda:
        model = model.cuda()
        full_volume = full_volume.cuda()
        print("Model transferred in GPU.....")

    # The former unused `torch.randn(3, 156, 240, 240).cuda()` allocation was
    # removed: it wasted GPU memory and called .cuda() regardless of args.cuda.
    print(full_volume.shape)
    output = non_overlap_padding(args, full_volume, model, criterion,
                                 kernel_dim=(32, 32, 32))
    return output
def get_arguments(argv=None):
    """Parse command-line options for the inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows the parser to be driven programmatically.

    Returns:
        argparse.Namespace holding every option below plus two derived
        fields: ``save`` (output path prefix built from the model name,
        ``utils.datestr()`` and the dataset name) and ``tb_log_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=1)
    parser.add_argument('--dataset_name', type=str, default="iseg2017")
    parser.add_argument('--dim', nargs="+", type=int, default=(64, 64, 64))
    parser.add_argument('--nEpochs', type=int, default=250)
    parser.add_argument('--classes', type=int, default=4)
    parser.add_argument('--samples_train', type=int, default=1)
    parser.add_argument('--samples_val', type=int, default=1)
    parser.add_argument('--split', type=float, default=0.8)
    parser.add_argument('--inChannels', type=int, default=2)
    parser.add_argument('--inModalities', type=int, default=2)
    parser.add_argument('--fold_id', default='1', type=str, help='Select subject for fold validation')
    # Help text previously claimed a 1e-3 default while the real default is 1e-2.
    parser.add_argument('--lr', default=1e-2, type=float,
                        help='learning rate (default: 1e-2)')
    # `--cuda` is a store_true flag with default=True, so on its own it can
    # never switch CUDA off. It is kept for backward compatibility, and
    # `--no_cuda` writes False into the same `cuda` destination.
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--no_cuda', dest='cuda', action='store_false', default=True,
                        help='run on CPU (overrides the --cuda default of True)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='UNET3D',
                        choices=('VNET', 'VNET2', 'UNET3D', 'DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
    parser.add_argument('--opt', type=str, default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--pretrained',
                        default='../saved_models/UNET3D_checkpoints/UNET3D_25_05___15_15_iseg2017_/UNET3D_25_05___15_15_iseg2017__last_epoch.pth',
                        type=str, metavar='PATH',
                        help='path to pretrained model')
    args = parser.parse_args(argv)
    args.save = '../inference_checkpoints/' + args.model + '_checkpoints/' + args.model + '_{}_{}_'.format(
        utils.datestr(), args.dataset_name)
    args.tb_log_dir = '../runs/'
    return args
if __name__ == "__main__":
    # Script entry point: run inference only when executed directly, never on import.
    main()
# NOTE(review): the triple-quoted string below is a bare expression, not a
# definition, so `overlap_3d_image` inside it is never defined or executed.
# It is a prototype that splits a (B, C, D, H, W) random tensor into
# overlapping kernel_size**3 patches with Tensor.unfold (kernel 32, stride 16)
# and tries to stitch slices back together with F.fold. If it is ever
# revived, note that:
#   - the reshape `view(...)` line is duplicated verbatim;
#   - the comment on the first print describes a patch shape, not the image;
#   - the final F.fold uses output_size=(D, 1) with a 32x32 kernel, which
#     looks unable to succeed.
# TODO confirm the intent, then turn this into a tested function or delete it.
'''
def overlap_3d_image():
B, C, D, H, W = 2, 1, 144, 192, 256
#B, C, D, H, W = 1, 1, 4, 4, 4
x = torch.randn(B, C, D, H, W)
print('IMAGE shape ', x.shape) # [B, C, D, num_of_patches_H,num_of_patches_W, kernel_size,kernel_size]
kernel_size = 32
stride = 16
patches = x.unfold(4, kernel_size, stride)
print('patches shape ', patches.shape) # [B, C, D, H, num_of_patches_W, kernel_size]
patches = patches.unfold(3, kernel_size, stride)
print('patches shape ', patches.shape) # [B, C, D, num_of_patches_H,num_of_patches_W, kernel_size,kernel_size]
patches = patches.unfold(2, kernel_size, stride)
print('patches shape ', patches.shape) # [B, C, num_of_patches_D, num_of_patches_H,num_of_patches_W, kernel_size ,kernel_size,kernel_size]
# patches = patches.unfold()
# perform the operations on each patchff
# ...
B, C, num_of_patches_D, num_of_patches_H,num_of_patches_W, kernel_size ,kernel_size,kernel_size = patches.shape
# # reshape output to match F.fold input
patches = patches.contiguous().view(B, C,num_of_patches_D* kernel_size, -1, kernel_size * kernel_size)
print(patches.shape)
patches = patches.contiguous().view(B, C,num_of_patches_D* kernel_size, -1, kernel_size * kernel_size)
print(patches.shape)
print('slice shape ',patches[:,:,0,:,:].shape)
slices = []
for i in range(num_of_patches_D * kernel_size):
output = F.fold(
patches[:,:,i,:,:].contiguous().view(B, C * kernel_size * kernel_size,-1), output_size=(H, W), kernel_size=kernel_size, stride=stride)
#print(output.shape) # [B, C, H, W]
slices.append(output)
image = torch.stack(slices)
print(image.shape)
print(image.is_contiguous())
image = image.permute(1,2,0,3,4).contiguous().view(B,C,-1,H*W)
print(image.shape)
output = F.fold(
image.contiguous().view(B*H*W, C*kernel_size, -1), output_size=(D, 1), kernel_size=kernel_size, stride=stride)
print(output.shape) # [B, C, H, W]
'''