import argparse
import torch
import torchvision
from torchvision import datasets, transforms
from PIL import Image
from lib.PraNet_Res2Net import PraNet
import torch.nn.functional as F
import numpy as np
class test_dataset:
    """Load a single image / ground-truth-mask pair from a dataset root.

    The image is read from ``<path>/images/<image_name>`` and the mask from
    ``<path>/masks/<image_name>``. The image is resized to
    ``testsize x testsize``, converted to a tensor and normalized. The mask
    is returned as a single-channel PIL image at its original resolution.

    Args:
        path: dataset root directory containing ``images/`` and ``masks/``.
        image_name: file name shared by the image and its mask. The default
            is the file that was previously hard-coded.
        testsize: side length the image is resized to (default 352, the
            previously hard-coded value).
    """

    def __init__(self, path, image_name='cju0u82z3cuma0835wlxrnrjv.png',
                 testsize=352):
        self.testsize = testsize
        # Paths are built with str.format (not os.path.join) to keep the exact
        # string the original produced, e.g. for roots with a trailing '/'.
        self.data_path = '{}/images/{}'.format(path, image_name)
        self.gt_path = '{}/masks/{}'.format(path, image_name)
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])])
        # Not used by load_data (the mask is returned as a PIL image); kept as
        # a public attribute in case other callers rely on it.
        self.gt_transform = transforms.ToTensor()

    def load_data(self):
        """Return ``(image, gt)`` for the configured file.

        ``image`` is the transformed RGB tensor with a leading batch dimension
        of 1, i.e. shape ``(1, 3, testsize, testsize)``. ``gt`` is the
        untransformed mask as a grayscale (mode ``'L'``) PIL image.
        """
        image = self.rgb_loader(self.data_path)
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gt_path)
        return image, gt

    def rgb_loader(self, path):
        """Open ``path`` and return it as a fully loaded RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces PIL to read the pixel data while the file is
            # still open; the returned copy is independent of ``f``.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open ``path`` and return it as a fully loaded grayscale PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
def test(path='./data/TestDataset/Kvasir',
         checkpoint='./snapshots/PraNet_Res2Net/PraNet-19.pth',
         loc='npu:0'):
    """Run PraNet on one test image and print the normalized prediction map.

    Loads ``checkpoint`` into a fresh ``PraNet``, moves the model to ``loc``,
    runs one forward pass on the image chosen by ``test_dataset(path)`` and
    prints the last model output. Before printing, that output is resized to
    the mask's resolution, passed through a sigmoid and min-max normalized
    to [0, 1].

    Args:
        path: dataset root containing ``images/`` and ``masks/``.
        checkpoint: path to the saved state dict (loaded on CPU first).
        loc: device string passed to ``torch.npu.set_device`` and used for
            the model and input.
    """
    torch.npu.set_device(loc)
    model = PraNet()
    pretrained_dict = torch.load(checkpoint, map_location="cpu")
    # Remove every 'module.' substring from the keys (presumably the prefix
    # added by an nn.DataParallel wrapper at training time).
    model.load_state_dict(
        {k.replace('module.', ''): v for k, v in pretrained_dict.items()})
    # NOTE: a second, non-strict reload that dropped 'fc.weight'/'fc.bias'
    # used to follow here. It could never change the weights: the strict load
    # above either already accepted every key or raised first.
    model = model.to(loc)
    model.eval()

    test_loader = test_dataset(path)
    image, gt = test_loader.load_data()
    gt = np.asarray(gt, np.float32)
    gt /= (gt.max() + 1e-8)  # only gt.shape is consumed below
    image = image.to(loc)

    # Pure inference: skip autograd bookkeeping, so no `.data` is needed.
    with torch.no_grad():
        _, _, _, res2 = model(image)
        res = F.interpolate(res2, size=gt.shape, mode='bilinear',
                            align_corners=False)
        res = res.sigmoid().cpu().numpy().squeeze()
    res = (res - res.min()) / (res.max() - res.min() + 1e-8)
    print('res is ', res)
def _main():
    """Script entry point: run the single-image inference demo."""
    test()


if __name__ == "__main__":
    _main()