import torch
import argparse
from config import get_config
from classification import build_model
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp
from torchvision import transforms
def parse_option():
    """Parse the command-line options and derive the run configuration.

    Arguments that the parser does not recognize are tolerated and
    discarded instead of raising an error.

    Returns:
        A ``(args, config)`` tuple: the parsed argument namespace and the
        value returned by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser(
        'Focal Transformer training and evaluation script',
        add_help=False,
    )
    add = parser.add_argument

    # Config file plus free-form overrides.
    add(
        '--cfg',
        default='./configs/focalv2_small_useconv_patch4_window7_224.yaml',
        type=str,
        metavar="FILE",
        help='path to config file',
    )
    add(
        "--opts",
        default=None,
        nargs='+',
        help="Modify config options by adding 'KEY VALUE' pairs. ",
    )

    # Data loading options.
    add('--batch-size', type=int, default=64, help="batch size for single GPU")
    add('--dataset', type=str, default='imagenet', help='dataset name')
    add('--data-path', type=str, help='path to dataset')
    add('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    add(
        '--cache-mode',
        type=str,
        default='part',
        choices=['no', 'full', 'part'],
        help='no: no cache, '
             'full: cache all data, '
             'part: sharding the dataset into nonoverlapping pieces and only cache one piece',
    )

    # Training / checkpointing options.
    add('--resume', help='resume from checkpoint')
    add('--accumulation-steps', type=int, help="gradient accumulation steps")
    add(
        '--use-checkpoint',
        action='store_true',
        help="whether to use gradient checkpointing to save memory",
    )
    add(
        '--amp-opt-level',
        type=str,
        default='O1',
        choices=['O0', 'O1', 'O2'],
        help='mixed precision opt level, if O0, no amp is used',
    )
    add(
        '--output',
        default='output',
        type=str,
        metavar='PATH',
        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)',
    )
    add('--tag', help='tag of experiment')

    # Run-mode switches.
    add('--eval', action='store_true', help='Perform evaluation only')
    add('--throughput', action='store_true', help='Test throughput only')
    add('--debug', action='store_true', help='Perform debug only')

    add("--local_rank", type=int, default=1, help='local rank for DistributedDataParallel')

    # parse_known_args also returns the unrecognized leftovers; they are not used.
    args, _ = parser.parse_known_args()
    return args, get_config(args)
def build_transform():
    """Build the preprocessing pipeline applied to the input image.

    The composed steps are, in order: a bicubic resize to
    ``int((256 / 224) * 224)``, a 224 center crop, conversion to a tensor,
    and normalization with ``IMAGENET_DEFAULT_MEAN`` / ``IMAGENET_DEFAULT_STD``.

    Returns:
        A ``transforms.Compose`` wrapping the four steps above.
    """
    crop_size = 224
    resize_to = int((256 / 224) * crop_size)
    return transforms.Compose([
        transforms.Resize(resize_to, interpolation=_pil_interp('bicubic')),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
def get_raw_data():
    """Download the sample image and return it converted to RGB.

    The image is fetched from a fixed URL and written to the relative path
    ``tmp.jpg`` before being opened with PIL.

    Returns:
        The downloaded image after ``convert('RGB')``.
    """
    from urllib.request import urlretrieve
    from PIL import Image

    sample_url = 'https://bbs-img.huaweicloud.com/blogs/img/thumb/1591951315139_8989_1363.png'
    local_path = 'tmp.jpg'

    urlretrieve(sample_url, local_path)
    return Image.open(local_path).convert('RGB')
def test():
    """Run a single-image inference demo on device ``npu:0`` and print the result.

    Loads the checkpoint ``./focalv2-small-useconv-is224-ws7.pth``, builds the
    model from the parsed configuration, preprocesses the downloaded sample
    image, runs one forward pass and prints the top-1 class index followed by
    the ``argmax`` tensor.
    """
    loc = 'npu:0'
    loc_cpu = 'cpu'
    torch.npu.set_device(loc)

    checkpoint = torch.load("./focalv2-small-useconv-is224-ws7.pth", map_location=loc)
    _, config = parse_option()
    model = build_model(config)
    # strict=True: fail if checkpoint keys and model parameters do not match.
    model.load_state_dict(checkpoint['model'], strict=True)
    model = model.to(loc)
    model.eval()

    rd = get_raw_data()
    preprocess = build_transform()
    # unsqueeze(0) adds the leading batch dimension before moving to the device.
    inputs = preprocess(rd).unsqueeze(0).to(loc)

    # Inference only: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        output = model(inputs)
    output = output.to(loc_cpu)

    _, pred = output.topk(1, 1, True, True)
    result = torch.argmax(output, 1)
    print("class: ", pred[0][0].item())
    print(result)
# Script entry point: run the inference demo only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    test()