import sys
sys.path.append('./PraNet')
import numpy as np
import torch.nn.functional as F
import torch
import os
from glob import glob
from scipy import misc
from utils.dataloader import test_dataset
# Size argument handed to test_dataset in test(); the raw prediction buffers
# are likewise treated as square maps with this side length (352).
testsize = 352
def test(pred_dir, save_path, data_path):
    """Convert raw ``*_3.bin`` prediction dumps into normalised mask images.

    Every ``*_3.bin`` file found in ``pred_dir`` is read as a flat float32
    buffer, reshaped to a single-channel ``testsize x testsize`` map,
    bilinearly resized to the shape of the ground-truth mask returned by the
    dataset loader, passed through a sigmoid, min-max scaled to [0, 1] and
    written into ``save_path`` under the file name the loader reports.

    Args:
        pred_dir: Directory containing the ``*_3.bin`` prediction files.
        save_path: Output directory; created if it does not exist.
        data_path: Dataset root holding the ``images/`` and ``masks/``
            sub-directories.

    Note:
        Predictions (sorted by path) are paired with loader items purely by
        position, so the loader is assumed to yield samples in the same
        order as the sorted ``.bin`` files -- TODO confirm against
        ``test_dataset``.
    """
    bin_paths = sorted(glob(os.path.join(pred_dir, '*_3.bin')))

    image_root = '{}/images/'.format(data_path)
    gt_root = '{}/masks/'.format(data_path)
    test_loader = test_dataset(image_root, gt_root, testsize)

    os.makedirs(save_path, exist_ok=True)

    for bin_path in bin_paths:
        # Only the ground-truth shape and the file name are needed here; the
        # image itself and the mask values do not influence the output.
        _, gt, name = test_loader.load_data()
        target_size = np.asarray(gt).shape

        raw = np.fromfile(bin_path, np.float32)
        # The dump is one single-channel map at the network input resolution.
        pred = torch.from_numpy(raw.reshape(1, 1, testsize, testsize))

        # F.upsample is deprecated; F.interpolate is its drop-in replacement.
        pred = F.interpolate(pred, size=target_size, mode='bilinear',
                             align_corners=False)
        pred = pred.sigmoid().numpy().squeeze()

        # Min-max normalise; the epsilon guards against a constant map.
        pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

        # os.path.join keeps the file inside save_path even when the caller
        # omits the trailing path separator.
        # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2; this
        # requires an older SciPy (or a switch to imageio.imwrite).
        misc.imsave(os.path.join(save_path, name), pred)

    print("Postprocess Done !")
if __name__ == "__main__":
    # Expected invocation: <script> <data_path> <pred_dir> <save_path>
    # Fail with a readable usage message instead of a bare IndexError when
    # arguments are missing; surplus arguments are ignored as before.
    if len(sys.argv) < 4:
        sys.exit("Usage: python {} <data_path> <pred_dir> <save_path>".format(
            sys.argv[0]))
    data_path, pred_dir, save_path = sys.argv[1:4]
    test(pred_dir, save_path, data_path)