import os
from pathlib import Path
import tqdm
import yaml
import numpy as np
import torch
import segm.utils.torch as ptu
from segm.data.factory import create_dataset
from segm.model.utils import resize, sliding_window
from segm import config
def preprocess(variant_path, save_path, gt_file_path):
with open(variant_path, "r") as f:
variant = yaml.load(f, Loader=yaml.FullLoader)
cfg = config.load_config()
dataset_cfg = cfg["dataset"]["cityscapes"]
normalization = variant["dataset_kwargs"]["normalization"]
im_size = dataset_cfg.get("im_size", variant["dataset_kwargs"]["image_size"])
window_size = variant["dataset_kwargs"]["crop_size"]
window_stride = variant["dataset_kwargs"]["crop_size"] - 32
C=19
dataset_kwargs = dict(
dataset="cityscapes",
image_size=im_size,
crop_size=im_size,
patch_size=16,
batch_size=1,
num_workers=10,
split="val",
normalization=normalization,
crop=False,
rep_aug=False,
)
db = create_dataset(dataset_kwargs)
seg_gt_maps = db.dataset.get_gt_seg_maps()
with open(gt_file_path, 'w', encoding='utf-8') as f:
for img_info in db.dataset.dataset.img_infos:
gt_path = Path(db.dataset.dataset.ann_dir) / img_info["ann"]["seg_map"]
line = f"{Path(img_info['filename']).name}\t{gt_path.__str__()}\n"
f.write(line)
im_size = dataset_kwargs["image_size"]
cat_names = db.base_dataset.names
n_cls = db.unwrapped.n_cls
for batch in tqdm.tqdm(db, desc='Processing'):
colors = batch["colors"]
ims = batch["im"]
ims_metas = batch["im_metas"]
ori_shape = ims_metas[0]["ori_shape"]
ori_shape = (ori_shape[0].item(), ori_shape[1].item())
filename = batch["im_metas"][0]["ori_filename"][0]
seg_map = torch.zeros((C, ori_shape[0], ori_shape[1]), device=ptu.device)
for im, im_metas in zip(ims, ims_metas):
im = im.to(ptu.device)
im = resize(im, window_size)
flip = im_metas["flip"]
windows = sliding_window(im, flip, window_size, window_stride)
crops = torch.stack(windows.pop("crop"))[:, 0]
B = len(crops)
WB = 1
seg_maps = torch.zeros((B, C, window_size, window_size), device=im.device)
cnt = 0
for i in range(0, B, WB):
input_tensor = crops[i : i + WB]
ori_image_name = im_metas['ori_filename'][0]
img = np.array(input_tensor).astype(np.float32)
filename = Path(ori_image_name).stem + "_" + str(cnt)
img.tofile(os.path.join(save_path, filename + ".bin"))
cnt += 1
def main():
import argparse
parser = argparse.ArgumentParser('image convert to binary file.')
parser.add_argument('--cfg-path', type=str, required=True,
help='path to model config file.')
parser.add_argument('--data-root', type=str, required=True,
help='path to parent directory of cityscapes dataset.')
parser.add_argument('--bin-dir', type=str, required=True,
help='directory to save binary file.')
parser.add_argument('--gt-path', type=str, required=True,
help='save the mapping of (srcImage, labelImage)')
args = parser.parse_args()
os.environ['DATASET'] = args.data_root
if not Path(args.bin_dir).is_dir():
Path(args.bin_dir).mkdir(parents=True, exist_ok=True)
preprocess(args.cfg_path, args.bin_dir, args.gt_path)
if __name__ == '__main__':
main()