import os
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision import transforms
import cfg
from label import point_inside_of_quad
from model_VGG import advancedEAST
from preprocess import resize_image
from nms import nms
def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Args:
        x: A NumPy array or scalar of logits.

    Returns:
        The logistic of ``x``, with the same shape as ``x``.
    """
    # exp(-x) overflows to inf for strongly negative x. The quotient is then
    # exactly 0.0, which is the correct limit, so only the spurious overflow
    # warning is silenced; the returned values are unchanged.
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))
def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
    """Crop one quad out of the original image and save it as a JPEG.

    The quad is rescaled to original-image coordinates, its axis-aligned
    bounding box is cut from ``im_array``, every pixel of that box that
    ``point_inside_of_quad`` reports as outside the quad is painted white,
    and the result is written to ``img_path + '_subim%d.jpg' % s``.

    Args:
        geo: Array of ``(x, y)`` vertex rows in resized-image coordinates.
            It is not modified.
        scale_ratio_w: Resized width divided by original width.
        scale_ratio_h: Resized height divided by original height.
        im_array: Original image as an array indexed ``[row, column, channel]``.
        img_path: Path prefix used to build the output file name.
        s: Integer index of the quad, used in the output file name.

    Returns:
        None. Nothing is written when the quad's bounding box does not
        overlap the image.
    """
    # Rescale into a new array: an in-place ``geo /= ...`` would change the
    # caller's quad, and predict() rescales that same quad again when it
    # builds the txt output (which would divide the coordinates twice).
    geo = np.asarray(geo) / [scale_ratio_w, scale_ratio_h]
    p_min = np.amin(geo, axis=0)
    p_max = np.amax(geo, axis=0)

    # Clamp the bounding box to the image so that a quad touching or crossing
    # the border cannot index outside ``im_array``.
    height, width = im_array.shape[:2]
    bounds = [width, height]
    min_xy = np.clip(p_min.astype(int), 0, bounds)
    max_xy = np.clip(p_max.astype(int) + 2, 0, bounds)
    x0, y0 = min_xy
    x1, y1 = max_xy
    if x1 <= x0 or y1 <= y0:
        # The box is empty after clamping: there are no pixels to save.
        return

    sub_im_arr = im_array[y0:y1, x0:x1, :].copy()
    for m in range(y0, y1):
        for n in range(x0, x1):
            if not point_inside_of_quad(n, m, geo, p_min, p_max):
                sub_im_arr[m - y0, n - x0, :] = 255
    sub_im = Image.fromarray(sub_im_arr.astype('uint8')).convert('RGB')
    sub_im.save(img_path + '_subim%d.jpg' % s)
def predict(east_detect, img_path, pixel_threshold, quiet=False):
    """Run the detector on one image and write visualisations beside it.

    Files written (all derived from ``img_path``):
      * ``img_path + '_act.jpg'``: resized image with every activated pixel
        outlined.
      * ``img_path + '_predict.jpg'``: resized image with the quads returned
        by ``nms`` drawn in red.
      * ``img_path + '_subim<k>.jpg'``: one crop per valid quad, only when
        ``cfg.predict_cut_text_line`` is set.
      * ``<img_path without extension>.txt``: one line of 8 comma-separated
        coordinates per valid quad, in original-image scale, only when
        ``cfg.predict_write2txt`` is set and at least one quad is valid.

    Args:
        east_detect: Callable model; it is given a float tensor with a leading
            batch dimension of 1 built by ``transforms.ToTensor``.
        img_path: Path of the image to process.
        pixel_threshold: Minimum value of channel 0 (after sigmoid) for a
            pixel to count as activated.
        quiet: When true, suppress the diagnostic prints.

    Returns:
        None.
    """
    # Open the file once and close it as soon as the pixel data is extracted.
    with Image.open(img_path) as src:
        im_array = np.array(src.convert('RGB'))
        d_wight, d_height = resize_image(src, cfg.max_predict_img_size)
        scale_ratio_w = d_wight / src.width
        scale_ratio_h = d_height / src.height
        im = src.resize((d_wight, d_height), Image.NEAREST).convert('RGB')

    x = torch.unsqueeze(transforms.ToTensor()(im), 0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y = torch.squeeze(east_detect(x), 0)
    if not quiet:
        print(y.shape)
    y = y.detach().cpu().numpy()
    if y.shape[0] == 7:
        # Channels-first output: move the 7 channels to the last axis.
        y = y.transpose((1, 2, 0))
    y[:, :, :3] = sigmoid(y[:, :, :3])
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
    activation_pixels = np.where(cond)
    quad_scores, quad_after_nms = nms(y, activation_pixels)

    # Keep an untouched copy for the quad overlay before drawing on ``im``.
    quad_im = im.copy()
    draw = ImageDraw.Draw(im)
    half = 0.5 * cfg.pixel_size
    for i, j in zip(activation_pixels[0], activation_pixels[1]):
        # Centre of the output cell (row i, column j) in resized-image pixels.
        px = (j + 0.5) * cfg.pixel_size
        py = (i + 0.5) * cfg.pixel_size
        line_width, line_color = 1, 'red'
        if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:
            # NOTE(review): channel 2 presumably separates the two ends of a
            # text line (yellow vs. green) - confirm against the label code.
            if y[i, j, 2] < cfg.trunc_threshold:
                line_width, line_color = 2, 'yellow'
            elif y[i, j, 2] >= 1 - cfg.trunc_threshold:
                line_width, line_color = 2, 'green'
        draw.line([(px - half, py - half),
                   (px + half, py - half),
                   (px + half, py + half),
                   (px - half, py + half),
                   (px - half, py - half)],
                  width=line_width, fill=line_color)
    im.save(img_path + '_act.jpg')

    quad_draw = ImageDraw.Draw(quad_im)
    txt_items = []
    for s, (score, geo) in enumerate(zip(quad_scores, quad_after_nms)):
        if np.amin(score) > 0:
            quad_draw.line([tuple(geo[k]) for k in (0, 1, 2, 3, 0)],
                           width=2, fill='red')
            # Rescale before cropping and hand cut_text_line its own copy, so
            # the txt coordinates are divided by the scale ratios exactly once
            # even if the callee modifies its argument.
            rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
            if cfg.predict_cut_text_line:
                cut_text_line(geo.copy(), scale_ratio_w, scale_ratio_h,
                              im_array, img_path, s)
            rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
            txt_items.append(','.join(map(str, rescaled_geo_list)) + '\n')
        elif not quiet:
            print('quad invalid with vertex num less then 4.')
    quad_im.save(img_path + '_predict.jpg')

    if cfg.predict_write2txt and txt_items:
        # splitext copes with extensions of any length (e.g. '.jpeg'), unlike
        # slicing off a fixed four characters.
        with open(os.path.splitext(img_path)[0] + '.txt', 'w') as f_txt:
            f_txt.writelines(txt_items)
def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
    """Run the detector on one image and write only the quad coordinates.

    When ``cfg.predict_write2txt`` is set and at least one quad is valid,
    ``txt_path`` receives one line per valid quad: 8 comma-separated
    coordinates in original-image scale. No images are written.

    Args:
        east_detect: Callable model; it is given a float tensor with a leading
            batch dimension of 1 built by ``transforms.ToTensor``.
        img_path: Path of the image to process.
        txt_path: Path of the text file to write.
        pixel_threshold: Minimum value of channel 0 (after sigmoid) for a
            pixel to count as activated.
        quiet: When true, suppress the diagnostic prints.

    Returns:
        None.
    """
    with Image.open(img_path) as img:
        d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
        scale_ratio_w = d_wight / img.width
        scale_ratio_h = d_height / img.height
        # PIL's resize takes (width, height). torchvision's transforms.Resize
        # reads a tuple as (height, width), so passing it (d_wight, d_height)
        # transposes non-square images and contradicts the ratios above.
        # convert('RGB') gives the 3-channel input that predict() also uses.
        # NOTE(review): bilinear is kept from the earlier behaviour, while
        # predict() resizes with NEAREST - confirm which one is intended.
        resized = img.resize((d_wight, d_height), Image.BILINEAR).convert('RGB')

    x = torch.unsqueeze(transforms.ToTensor()(resized), 0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y = torch.squeeze(east_detect(x), 0)
    if not quiet:
        print(y.shape)
    y = y.detach().cpu().numpy()
    if y.shape[0] == 7:
        # Channels-first output: move the 7 channels to the last axis.
        y = y.transpose((1, 2, 0))
    y[:, :, :3] = sigmoid(y[:, :, :3])
    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
    activation_pixels = np.where(cond)
    quad_scores, quad_after_nms = nms(y, activation_pixels)

    txt_items = []
    for score, geo in zip(quad_scores, quad_after_nms):
        if np.amin(score) > 0:
            # Map the quad back to original-image coordinates.
            rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
            rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
            txt_items.append(','.join(map(str, rescaled_geo_list)) + '\n')
        elif not quiet:
            print('quad invalid with vertex num less then 4.')

    if cfg.predict_write2txt and txt_items:
        with open(txt_path, 'w') as f_txt:
            f_txt.writelines(txt_items)
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint and run predict() on the
    # image configured in cfg.
    if not os.path.exists('demo'):
        os.makedirs('./demo', exist_ok=True)
    img_path = cfg.img_path
    threshold = float(cfg.predict_threshold)
    pth_path = cfg.pth_path if cfg.pth_path else 'saved_model/3T736_latest.pth'
    print(img_path, threshold)
    east = advancedEAST()
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(pth_path, map_location='cpu')
    # Drop the 'module.' substring that appears in keys of checkpoints saved
    # from a wrapped model, so they match the bare model's parameter names.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
    east.load_state_dict(state_dict)
    # Switch layers such as dropout / batch-norm to inference behaviour.
    east.eval()
    predict(east, img_path, threshold)