import os
import sys
import ast
import cv2
import argparse
import numpy as np
from ais_bench.infer.interface import InferSession
from sam_preprocessing_pytorch import encoder_preprocessing, decoder_preprocessing
from sam_postprocessing_pytorch import sam_postprocessing
def check_device_range_valid(value):
min_value = 0
max_value = 255
if ',' in value:
ilist = [ int(v) for v in value.split(',') ]
for ivalue in ilist[:2]:
if ivalue < min_value or ivalue > max_value:
raise argparse.ArgumentTypeError("{} of device:{} is invalid. valid value range is [{}, {}]".format(
ivalue, value, min_value, max_value))
return ilist[:2]
else:
ivalue = int(value)
if ivalue < min_value or ivalue > max_value:
raise argparse.ArgumentTypeError("device:{} is invalid. valid value range is [{}, {}]".format(
ivalue, min_value, max_value))
return ivalue
def save_mask(mask, image, src_path, save_path, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([0.1176, 0.5647, 1, 0.6])
h, w = mask.shape[-2: ]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
mask_image = mask_image[:, :, 0:3]
mask_img = mask_image * 200 + image
image_name = src_path.split('/')[-1].split('.')[0]
cv2.imwrite(save_path + "/" + image_name + "_result.jpg", mask_img)
def encoder_infer(session_encoder, x):
encoder_outputs = session_encoder.infer([x])
image_embedding = encoder_outputs[0]
return image_embedding
def decoder_infer(session_decoder, decoder_inputs):
decoder_outputs = session_decoder.infer(decoder_inputs, mode="dymdims", custom_sizes=[1000, 1000000])
low_res_masks = decoder_outputs[1]
return low_res_masks
def sam_infer(src_path, session_encoder, session_decoder, input_point=None, box=None, save_path="./"):
image = cv2.imread(src_path)
x = encoder_preprocessing(image)
image_embedding = encoder_infer(session_encoder, x)
decoder_inputs = decoder_preprocessing(image_embedding, input_point=input_point, box=box, image=image)
low_res_masks = decoder_infer(session_decoder, decoder_inputs)
masks = sam_postprocessing(low_res_masks, image)
save_mask(masks, image, src_path, save_path, random_color=True)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--src-path', type=str, default='./datasets/demo.jpg', help='input path to image')
parser.add_argument('--save-path', type=str, default='./outputs', help='output path to image')
parser.add_argument('--encoder-model-path', type=str, default='./models/encoder.om', help='path to encoder model')
parser.add_argument('--decoder-model-path', type=str, default='./models/decoder.om', help='path to decoder model')
parser.add_argument('--input-point', type=ast.literal_eval, default=[[500, 375], [1125, 625], [1520, 625]], help='input points')
parser.add_argument('--device-id', type=check_device_range_valid, default=0, help='NPU device id. Give 2 ids to enable parallel inferencing.')
args = parser.parse_args()
if not os.path.exists(args.save_path):
os.makedirs(os.path.realpath(args.save_path), mode=0o744)
session_encoder = InferSession(args.device_id, args.encoder_model_path)
session_decoder = InferSession(args.device_id, args.decoder_model_path)
sam_infer(args.src_path, session_encoder, session_decoder, input_point=args.input_point, save_path=args.save_path)
if __name__ == '__main__':
main()