import numpy as np
import cv2
import torch
from util.config import config as cfg
from util.misc import fill_hole, regularize_sin_cos
from util.misc import norm2, vector_cos, vector_sin
from util.misc import disjoint_merge, merge_polygons
class TextDetector(object):
    """Turn the dense maps predicted by a model into text polygons.

    `detect` reads seven channels from the model output for the first image
    of the batch: two text-region (TR) channels, two text-center-line (TCL)
    channels, a sin map, a cos map and a radii map. The remaining methods
    trace center lines inside the TCL mask and rebuild polygons from the
    (x, y, radius) disks found along them.
    """

    def __init__(self, model, tr_thresh=0.4, tcl_thresh=0.6):
        """
        :param model: callable network used by `detect`; switched to eval mode here
        :param tr_thresh: threshold applied to the text-region probability map
        :param tcl_thresh: threshold applied to the text-center-line probability map
        """
        self.model = model
        self.tr_thresh = tr_thresh
        self.tcl_thresh = tcl_thresh

        # Inference only: make sure the model is not left in training mode.
        model.eval()

    @staticmethod
    def _out_of_bounds(x, y, H, W):
        """Return True when the truncated pixel (int(x), int(y)) lies outside an H x W map."""
        return int(x) >= W or int(x) < 0 or int(y) >= H or int(y) < 0

    def find_innerpoint(self, cont):
        """
        Generate an inner point of the input polygon using the mean x coordinate:
        1. calculate the mean of the x coordinates (xmean)
        2. calculate the minimum and maximum of the y coordinates (ymin, ymax)
        3. scan y over [ymin - 1, ymax + 1) in steps of 0.5 along x = xmean and
           collect the first run of y values strictly inside the polygon
        4. return (xmean, mean of that run)
        If the scan finds nothing, the 3x3 neighbourhood of every contour vertex
        is probed instead.
        :param cont: input polygon, indexed as cont[:, 0, 0] (x) and cont[:, 0, 1] (y)
        :return: (x, y) of a point strictly inside the polygon, or None if no such
            point was found
        """
        xmean = cont[:, 0, 0].mean()
        ymin, ymax = cont[:, 0, 1].min(), cont[:, 0, 1].max()

        found_y = []
        for y in np.arange(ymin - 1, ymax + 1, 0.5):
            in_poly = cv2.pointPolygonTest(cont, (float(xmean), float(y)), False)
            if in_poly > 0:
                found_y.append(y)
            elif in_poly < 0 and found_y:
                # Left the first inside segment; later segments are ignored.
                break

        if found_y:
            return (xmean, np.array(found_y).mean())

        # Fallback: probe the integer neighbours of each contour vertex.
        for point in cont[:, 0]:
            for dx in (-1, 0, 1):
                for dy in (-1, 0, 1):
                    test_pt = point + [dx, dy]
                    if cv2.pointPolygonTest(cont, (int(test_pt[0]), int(test_pt[1])), False) > 0:
                        return test_pt

        # No strictly-interior point exists; callers must handle None.
        return None

    def in_contour(self, cont, point):
        """
        Utility function for judging whether `point` is strictly inside `cont`.
        The coordinates are truncated to integers before the test.
        :param cont: contour as returned by cv2.findContours
        :param point: 2d coordinate (x, y)
        :return: True if the truncated point is strictly inside the contour
        """
        x, y = point
        return cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0

    def centerlize(self, x, y, H, W, tangent_cos, tangent_sin, tcl_contour, stride=1.):
        """
        Centralize (x, y) using the normal of the given tangent direction:
        walk from (x, y) along the normal in both directions until the contour
        or the H x W map is left, then return the midpoint of the two end points.
        :param x: start x coordinate
        :param y: start y coordinate
        :param H: map height used for the bounds check
        :param W: map width used for the bounds check
        :param tangent_cos: cos of the tangent direction at (x, y)
        :param tangent_sin: sin of the tangent direction at (x, y)
        :param tcl_contour: contour the walk is confined to
        :param stride: step length of the walk
        :return: coordinate after centralizing, np.array([x, y])
        """
        # The normal is the tangent rotated by 90 degrees.
        normal_cos = -tangent_sin
        normal_sin = tangent_cos

        def walk(step_x, step_y):
            """Step from (x, y) until outside the contour or the map; return the end point."""
            _x, _y = x, y
            while self.in_contour(tcl_contour, (_x, _y)):
                _x = _x + step_x
                _y = _y + step_y
                if self._out_of_bounds(_x, _y, H, W):
                    break
            return np.array([_x, _y])

        step_x = normal_cos * stride
        step_y = normal_sin * stride
        end1 = walk(step_x, step_y)
        end2 = walk(-step_x, -step_y)
        return (end1 + end2) / 2

    def mask_to_tcl(self, pred_sin, pred_cos, pred_radii, tcl_contour, init_xy, direct=1):
        """
        Iteratively find the center line in a tcl mask starting from (x, y).
        :param pred_sin: predicted sin map, 2d
        :param pred_cos: predicted cos map, 2d
        :param pred_radii: predicted radii map, 2d
        :param tcl_contour: predicted tcl contour
        :param init_xy: initial (x, y)
        :param direct: direction [-1|1] of the first step
        :return: np.ndarray of shape (n, 3), each row (x, y, radii); n is 0 when
            the initial point is not inside the contour
        """
        H, W = pred_sin.shape
        x_shift, y_shift = init_xy

        result = []
        max_attempt = 200
        attempt = 0

        while self.in_contour(tcl_contour, (x_shift, y_shift)):
            attempt += 1

            sin = pred_sin[int(y_shift), int(x_shift)]
            cos = pred_cos[int(y_shift), int(x_shift)]
            x_c, y_c = self.centerlize(x_shift, y_shift, H, W, cos, sin, tcl_contour)

            sin_c = pred_sin[int(y_c), int(x_c)]
            cos_c = pred_cos[int(y_c), int(x_c)]
            radii_c = pred_radii[int(y_c), int(x_c)]

            result.append(np.array([x_c, y_c, radii_c]))

            # Try progressively shorter steps until one lands inside the contour.
            for shrink in [1/2., 1/4., 1/8., 1/16., 1/32.]:
                t = shrink * radii_c
                x_shift_pos = x_c + cos_c * t * direct
                y_shift_pos = y_c + sin_c * t * direct
                x_shift_neg = x_c - cos_c * t * direct
                y_shift_neg = y_c - sin_c * t * direct

                if len(result) == 1:
                    x_shift, y_shift = x_shift_pos, y_shift_pos
                else:
                    # Keep moving away from the previous point so the walk
                    # does not turn back on itself.
                    dist_pos = norm2(result[-2][:2] - (x_shift_pos, y_shift_pos))
                    dist_neg = norm2(result[-2][:2] - (x_shift_neg, y_shift_neg))
                    if dist_pos > dist_neg:
                        x_shift, y_shift = x_shift_pos, y_shift_pos
                    else:
                        x_shift, y_shift = x_shift_neg, y_shift_neg

                if self._out_of_bounds(x_shift, y_shift, H, W):
                    continue
                if self.in_contour(tcl_contour, (x_shift, y_shift)):
                    break

            # Stop on leaving the map, or when the step budget is spent.
            if self._out_of_bounds(x_shift, y_shift, H, W):
                break
            if attempt > max_attempt:
                break

        # reshape keeps the (n, 3) layout even when no point was collected.
        return np.array(result).reshape(-1, 3)

    def build_tcl(self, tcl_pred, sin_pred, cos_pred, radii_pred):
        """
        Find the TCL center points and the radius at each point.
        :param tcl_pred: output tcl mask, 2d
        :param sin_pred: output sin map, 2d
        :param cos_pred: output cos map, 2d
        :param radii_pred: output radii map, 2d
        :return: (list), one non-empty tcl array per traced contour: (n, 3),
            3 denotes (x, y, radii)
        """
        all_tcls = []

        tcl_mask = fill_hole(tcl_pred)
        # [-2] selects the contour list for both the 2-value and the 3-value
        # return forms of cv2.findContours.
        tcl_contours = cv2.findContours(
            tcl_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        for cont in tcl_contours:
            init = self.find_innerpoint(cont)
            if init is None:
                continue

            x_init, y_init = init

            # Trace both ways from the start point, then join the two halves;
            # the start point itself is kept only once.
            tcl_left = self.mask_to_tcl(sin_pred, cos_pred, radii_pred, cont, (x_init, y_init), direct=1)
            tcl_right = self.mask_to_tcl(sin_pred, cos_pred, radii_pred, cont, (x_init, y_init), direct=-1)
            tcl = np.concatenate([tcl_left[::-1][:-1], tcl_right])

            # Nothing was traced (start point failed the integer contour test).
            if len(tcl) == 0:
                continue
            all_tcls.append(tcl)

        return all_tcls

    def detect_contours(self, image, tr_pred, tcl_pred, sin_pred, cos_pred, radii_pred):
        """
        Input: FCN output, Output: text detection after post-processing
        :param image: (np.array) input image (3, H, W)
        :param tr_pred: (np.array), text region prediction, (2, H, W)
        :param tcl_pred: (np.array), text center line prediction, (2, H, W)
        :param sin_pred: (np.array), sin prediction, (H, W)
        :param cos_pred: (np.array), cos prediction, (H, W)
        :param radii_pred: (np.array), radii prediction, (H, W)
        :return: result of `postprocessing` on the traced center lines
        """
        # Threshold channel 1 of each prediction; a TCL pixel only counts
        # when it also lies inside the text region.
        tr_pred_mask = tr_pred[1] > self.tr_thresh
        tcl_pred_mask = tcl_pred[1] > self.tcl_thresh
        tcl_mask = tcl_pred_mask * tr_pred_mask

        sin_pred, cos_pred = regularize_sin_cos(sin_pred, cos_pred)

        detect_result = self.build_tcl(tcl_mask, sin_pred, cos_pred, radii_pred)
        return self.postprocessing(image, detect_result, tr_pred_mask)

    def detect(self, image):
        """
        Run the model and post-process the prediction of the first batch item.
        :param image: torch tensor batch fed to the model; only index 0 is used
        :return: (contours, output) where `output` is a dict holding the numpy
            versions of 'image', 'tr', 'tcl', 'sin', 'cos' and 'radii'
        """
        # The result is converted to numpy right away, so no autograd graph
        # is needed for the forward pass.
        with torch.no_grad():
            prediction = self.model(image)

        image = image[0].detach().cpu().numpy()
        tr_pred = prediction[0, 0:2].softmax(dim=0).detach().cpu().numpy()
        tcl_pred = prediction[0, 2:4].softmax(dim=0).detach().cpu().numpy()
        sin_pred = prediction[0, 4].detach().cpu().numpy()
        cos_pred = prediction[0, 5].detach().cpu().numpy()
        radii_pred = prediction[0, 6].detach().cpu().numpy()

        contours = self.detect_contours(image, tr_pred, tcl_pred, sin_pred, cos_pred, radii_pred)

        output = {
            'image': image,
            'tr': tr_pred,
            'tcl': tcl_pred,
            'sin': sin_pred,
            'cos': cos_pred,
            'radii': radii_pred
        }
        return contours, output

    def merge_contours(self, all_contours):
        """ Merge overlapped instances to one instance with disjoint find / merge algorithm
        :param all_contours: (list(tuple)), each (contour, disks) with contour
            (n_points, 2) and disks (n, 3) as built by `postprocessing`
        :return: result of `merge_polygons` on the contours and the merge map
        """

        def stride(disks, other_contour, left, step=0.3):
            """Extend one end of the disk chain by `step` radii and test it against `other_contour`."""
            if len(disks) < 2:
                return False
            if left:
                last_point, before_point = disks[:2]
            else:
                before_point, last_point = disks[-2:]
            radius = last_point[2]
            direction = last_point[:2] - before_point[:2]
            cos = vector_cos(direction)
            sin = vector_sin(direction)
            new_point = last_point[:2] + radius * step * np.array([cos, sin])
            return self.in_contour(other_contour, new_point)

        def can_merge(disks, other_contour):
            """True when either end of the disk chain extends into `other_contour`."""
            return stride(disks, other_contour, left=True) or stride(disks, other_contour, left=False)

        # F is the merge map handed to disjoint_merge / merge_polygons,
        # initialised so that every contour is its own group.
        F = list(range(len(all_contours)))
        for i, (_, disk_i) in enumerate(all_contours):
            for j in range(i + 1, len(all_contours)):
                cont_j = all_contours[j][0]
                if can_merge(disk_i, cont_j):
                    disjoint_merge(i, j, F)

        return merge_polygons([cont for cont, _ in all_contours], F)

    def postprocessing(self, image, detect_result, tr_pred_mask):
        """ convert geometric info(center_x, center_y, radii) into contours
        :param image: (np.array), input image; only image.shape[1:] is used
        :param detect_result: (list), each with (n, 3), 3 denotes (x, y, radii)
        :param tr_pred_mask: (np.array), predicted text area mask, each with shape (H, W)
        :return: (np.ndarray list), polygon format contours
        """
        all_conts = []
        for disk in detect_result:
            # Paint every disk of the center line into an empty mask.
            reconstruct_mask = np.zeros(image.shape[1:], dtype=np.uint8)
            for x, y, r in disk:
                if cfg.post_process_expand > 0.0:
                    r *= (1. + cfg.post_process_expand)
                cv2.circle(reconstruct_mask, (int(x), int(y)), max(1, int(r)), 1, -1)

            # Reject instances with less than half their area inside the text region.
            if (reconstruct_mask * tr_pred_mask).sum() < reconstruct_mask.sum() * 0.5:
                continue

            # [-2] selects the contour list for both the 2-value and the 3-value
            # return forms of cv2.findContours.
            conts = cv2.findContours(reconstruct_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
            if len(conts) == 0:
                continue

            # The contours may come back as a tuple, so pick the largest one
            # with max() instead of sorting in place.
            largest = max(conts, key=cv2.contourArea)
            all_conts.append((largest[:, 0, :], disk))

        if cfg.post_process_merge:
            return self.merge_contours(all_conts)
        return [cont for cont, _ in all_conts]