# 05360171 — created on 2022-03-18 (history commits)
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import torch
import torch.nn.functional as F

# Lib files
import numpy as np
from torch.utils.data import dataset
import lib.utils as utils
import lib.medloaders as medical_loaders
import lib.medzoo as medzoo
from lib.visual3D_temp import non_overlap_padding,test_padding
from lib.losses3D import DiceLoss
from lib.utils.general import prepare_input
from lib.medloaders.brats2018 import MICCAIBraTS2018





def _to_numpy(value):
    """Return *value* as a NumPy array, copying torch tensors to host memory first."""
    if torch.is_tensor(value):
        # .numpy() only works on CPU tensors that do not require grad.
        return value.detach().cpu().numpy()
    return np.asarray(value)


def main():
    """Export the validation split as raw ``.bin`` inputs and ``.pth`` labels.

    Rebuilds the datasets and model from the command-line arguments and
    restores the checkpoint given by ``--pretrained``. It then iterates over
    the validation generator and, for every batch index ``s``, writes:

    * ``<output_bin>/<s>.bin``: the prepared input tensor as raw float32
      bytes (``ndarray.tofile``, no header or shape information);
    * ``<output_label>/<s>.pth``: the target converted to a float32 NumPy
      array and serialized with ``torch.save``.

    Both output directories are created if they do not exist yet.
    """
    args = get_arguments()
    seed = 1777777
    utils.reproducibility(args, seed)

    # Only the validation generator is needed for the export.
    _, val_generator, _, _ = medical_loaders.generate_datasets(args, path='./datasets')

    # NOTE(review): the model is not used by the export loop below. It is kept
    # so that the configuration and checkpoint given on the command line are
    # still validated exactly as before.
    model, optimizer = medzoo.create_model(args)
    model.eval()
    model.restore_checkpoint(args.pretrained)

    # Neither ndarray.tofile nor torch.save creates missing parent directories.
    os.makedirs(args.output_bin, exist_ok=True)
    os.makedirs(args.output_label, exist_ok=True)

    for s, input_tuple in enumerate(val_generator):
        # Insert an axis at dim 1 for every element except the last one
        # (presumably the label), which is passed through unchanged.
        input_tuple = [i.unsqueeze(1) for i in input_tuple[:-1]] + [input_tuple[-1]]
        input_tensor, target = prepare_input(input_tuple=input_tuple, args=args)

        # Move to host memory first, in case prepare_input placed the tensors
        # on an accelerator (e.g. with --cuda).
        img = _to_numpy(input_tensor).astype(np.float32)
        img.tofile(os.path.join(args.output_bin, str(s) + ".bin"))

        target = _to_numpy(target).astype(np.float32)
        torch.save(target, os.path.join(args.output_label, str(s) + ".pth"))
        
def get_arguments():
    """Parse the command-line options for the validation-set export script.

    The options are consumed by ``main`` and passed on to the project's
    dataset and model factories. After parsing, two derived attributes are
    attached:

    * ``args.save``: a checkpoint path prefix built from the model name,
      the current date string and the dataset name;
    * ``args.tb_log_dir``: fixed to ``'../runs/'``.

    Returns:
        argparse.Namespace: the parsed arguments.
    """

    def _str2bool(value):
        """Interpret the usual textual spellings of a boolean option value."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--batchSz', type=int, default=4)
    parser.add_argument('--dataset_name', type=str, default="brats2018")
    parser.add_argument('--dim', nargs="+", type=int, default=(64, 64, 64))
    parser.add_argument('--nEpochs', type=int, default=100)
    parser.add_argument('--classes', type=int, default=4)
    parser.add_argument('--samples_train', type=int, default=1024)
    parser.add_argument('--samples_val', type=int, default=128)
    parser.add_argument('--inChannels', type=int, default=4)
    parser.add_argument('--inModalities', type=int, default=4)
    parser.add_argument('--threshold', default=0.00000000001, type=float)
    # Without an explicit type, a value given on the command line would stay a str.
    parser.add_argument('--terminal_show_freq', default=50, type=int)
    # '--augmentation' is on by default (kept for backward compatibility);
    # '--no_augmentation' is the only way to switch it off.
    parser.add_argument('--augmentation', action='store_true', default=True)
    parser.add_argument('--no_augmentation', dest='augmentation', action='store_false',
                        help='disable data augmentation (enabled by default)')
    parser.add_argument('--normalization', default='full_volume_mean', type=str,
                        help='Tensor normalization: options ,max_min,',
                        choices=('max_min', 'full_volume_mean', 'brats', 'max', 'mean'))
    parser.add_argument('--split', default=0.8, type=float, help='Select percentage of training data(default: 0.8)')
    parser.add_argument('--lr', default=5e-3, type=float,
                        help='learning rate (default: 5e-3)')
    parser.add_argument('--cuda', action='store_true', default=False)

    # Parse as a real boolean, so that '--loadData False' is not the truthy string 'False'.
    parser.add_argument('--loadData', default=True, type=_str2bool)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--model', type=str, default='UNET3D',
                        choices=('VNET', 'VNET2', 'UNET3D', 'DENSENET1', 'DENSENET2', 'DENSENET3', 'HYPERDENSENET'))
    parser.add_argument('--opt', type=str, default='sgd',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--log_dir', type=str,
                        default='./runs/')
    parser.add_argument('--prof', default=False, action='store_true',
                        help='use profiling to evaluate the performance of model')

    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--rank', type=int, default=0)

    parser.add_argument('--amp', action='store_true', default=False)
    parser.add_argument('--workers', type=int, default=8)

    parser.add_argument('--device', default='npu', type=str, help='npu or gpu')
    parser.add_argument('--pretrained',
                        default="none",
                        type=str, metavar='PATH',
                        help='path to pretrained model')

    # Destination directories for the exported inputs (.bin) and labels (.pth).
    parser.add_argument('--output_bin', default='none', type=str)
    parser.add_argument('--output_label', default='none', type=str)

    args = parser.parse_args()
    args.save = '../inference_checkpoints/' + args.model + '_checkpoints/' + args.model + '_{}_{}_'.format(
        utils.datestr(), args.dataset_name)
    args.tb_log_dir = '../runs/'
    return args

if __name__ == "__main__":
    # Run the export only when this file is executed as a script, not when imported.
    main()