import argparse
import copy
import os
import cv2
import shutil
import math
import numpy as np
import torch
import torch.nn.functional as F
from ctpn import config
from ctpn.ctpn import CTPN_Model
from ctpn.utils import gen_anchor, transform_bbox, clip_bbox, filter_bbox, nms, TextProposalConnectorOriented
from tqdm import tqdm
# Command-line interface: where the test images live and which checkpoint to
# evaluate. The help texts describe the actual meaning of each option.
parser = argparse.ArgumentParser(description='PyTorch CTPN Testing')
parser.add_argument('--data-path', default='/home/dockerHome/ctpn/ctpn_8p/imagedata/', type=str,
                    help='root directory of the dataset; must contain '
                         'Challenge2_Test_Task12_Images/')
parser.add_argument('--model-path', default='./output_models/checkpoint-200.pth.tar', type=str,
                    help='path to the trained CTPN checkpoint (.pth.tar) to load')
# NOTE: arguments are parsed at import time, so importing this module reads
# sys.argv of the running process.
args = parser.parse_args()
def load_state_dict(state_dicts):
    """Return a new mapping whose keys have one leading ``'module.'`` removed.

    Keys that do not start with ``'module.'`` are copied unchanged. Only a
    single leading occurrence is stripped, and the values are passed through
    untouched. The input mapping is not modified.

    Args:
        state_dicts: Mapping from parameter name (str) to value.

    Returns:
        A plain ``dict`` with the renamed keys, in the input's iteration order.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dicts.items()
    }
# Images whose longer side exceeds this many pixels are downscaled in
# get_text_boxes before being fed to the network.
long_len = 1008
# Device string handed to .to() for both the model and the input tensors.
# NOTE(review): 'npu:0' presumably targets an Ascend NPU and needs a torch
# build/plugin that registers that device type - confirm.
device = 'npu:0'
weights = args.model_path
model = CTPN_Model().to(device)
# Deserialize onto the CPU first; the checkpoint must be a mapping that has a
# 'state_dict' entry.
state_dict = torch.load(weights, map_location='cpu')['state_dict']
# Strip a leading 'module.' from the parameter names before loading them.
new_state_dict = load_state_dict(state_dict)
model.load_state_dict(new_state_dict)
# Put the module into evaluation mode for inference.
model.eval()
def get_text_boxes(image, img_name=None, display=True, prob_thresh=0.5):
    """Detect text lines in one image with the module-level CTPN ``model``.

    The image is channel-reversed, downscaled so that its longer side is at
    most ``long_len``, zero-padded (centred) to a multiple of 16 in both
    dimensions, mean-subtracted and run through the network. Anchors whose
    foreground probability exceeds ``prob_thresh`` are filtered, passed
    through NMS and connected into text lines.

    Each returned row is used here as eight coordinates (x at even indices,
    y at odd indices of the first eight columns) followed by a score in the
    last column. NOTE(review): this layout is inferred from how the rows are
    indexed in this function; confirm against TextProposalConnectorOriented.

    Args:
        image: H x W x C image array, e.g. as returned by ``cv2.imread``.
        img_name: File name used for the annotated image under ``./display/``.
            When ``None`` the annotated image is not written to disk.
        display: If true, draw every detected quadrilateral and its score on
            a copy of the input image (and save it when ``img_name`` is set).
        prob_thresh: Minimum foreground probability for an anchor to be kept.

    Returns:
        Tuple ``(text_n, text_ori, image_c)``:
            text_n: Text lines in the coordinates of the resized and padded
                network input.
            text_ori: The same text lines mapped back to pixel coordinates of
                the original image; the score column is left unscaled.
            image_c: Copy of the input image, annotated when ``display`` is
                true.
    """
    h, w = image.shape[:2]
    image_c = copy.deepcopy(image)
    # Reverse the channel axis; make the result contiguous because a
    # negative-stride view is not a safe input for every OpenCV build.
    image = np.ascontiguousarray(image[:, :, ::-1])

    # Downscale only; images already within long_len keep their size.
    rescale_fac = max(h, w) / long_len
    if rescale_fac > 1.0:
        h = int(h / rescale_fac)
        w = int(w / rescale_fac)
        image = cv2.resize(image, (w, h))
    h, w = image.shape[:2]

    # Pad symmetrically up to the next multiple of 16 (the anchor stride
    # used below).
    target_w = int(math.ceil(w / 16)) * 16
    target_h = int(math.ceil(h / 16)) * 16
    w_pad = (target_w - w) // 2
    h_pad = (target_h - h) // 2
    w_pad2 = target_w - w - w_pad
    h_pad2 = target_h - h - h_pad
    image = np.pad(image, ((h_pad, h_pad2), (w_pad, w_pad2), (0, 0)))

    image = image.astype(np.float32) - config.IMAGE_MEAN
    image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float().to(device)
    h, w = target_h, target_w

    with torch.no_grad():
        cls, regr = model(image)
        cls_prob = F.softmax(cls, dim=-1).cpu().numpy()
        regr = regr.cpu().numpy()

    anchor = gen_anchor((h // 16, w // 16), 16)
    bbox = transform_bbox(anchor, regr)
    bbox = clip_bbox(bbox, [h, w])

    # Keep anchors classified as foreground (class index 1).
    fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0]
    select_anchor = bbox[fg, :].astype(np.int32)
    select_score = cls_prob[0, fg, 1]

    keep_index = filter_bbox(select_anchor, 16)
    select_anchor = select_anchor[keep_index]
    select_score = select_score[keep_index]
    select_score = np.reshape(select_score, (select_score.shape[0], 1))

    nmsbox = np.hstack((select_anchor, select_score))
    keep = nms(nmsbox, 0.3)
    select_anchor = select_anchor[keep]
    select_score = select_score[keep]

    textConn = TextProposalConnectorOriented()
    text = textConn.get_text_lines(select_anchor, select_score, [h, w])
    # Snapshot in network-input coordinates before undoing pad and resize.
    text_n = copy.deepcopy(text)

    # Undo the centred padding on the x and y coordinates.
    text[:, [0, 2, 4, 6]] = text[:, [0, 2, 4, 6]] - w_pad
    text[:, [1, 3, 5, 7]] = text[:, [1, 3, 5, 7]] - h_pad

    if rescale_fac > 1.0:
        text_ori = text * rescale_fac
        # Only the eight coordinates are geometric; restore the trailing
        # column(s) so the score is not multiplied by the resize factor.
        text_ori[:, 8:] = text[:, 8:]
    else:
        text_ori = text

    if display:
        green = (0, 255, 0)
        for row in text_ori:
            label = str(round(row[-1] * 100, 2)) + '%'
            pts = [int(v) for v in row]
            cv2.line(image_c, (pts[0], pts[1]), (pts[2], pts[3]), green, 2)
            cv2.line(image_c, (pts[0], pts[1]), (pts[4], pts[5]), green, 2)
            cv2.line(image_c, (pts[6], pts[7]), (pts[2], pts[3]), green, 2)
            cv2.line(image_c, (pts[4], pts[5]), (pts[6], pts[7]), green, 2)
            cv2.putText(image_c, label, (pts[0] + 13, pts[1] + 13), cv2.FONT_HERSHEY_SIMPLEX, 1,
                        (0, 0, 255), 2, cv2.LINE_AA)
        if img_name is not None:
            # Without a name there is no valid target path ('./display/None'
            # has no image extension), so only save when one was given.
            os.makedirs('./display', exist_ok=True)
            cv2.imwrite(f'./display/{img_name}', image_c)

    return text_n, text_ori, image_c
def make_dir(path):
    """Create ``path`` as an empty directory.

    Anything already present at ``path`` is removed recursively first.
    Missing parent directories are created as well, so nested output paths
    work without creating each level separately.

    Args:
        path: Directory path to (re)create.
    """
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
def predict_2_txt():
    """Run detection on every test image and write one result file per image.

    Images are read from ``<data-path>/Challenge2_Test_Task12_Images``. The
    output directories ``./display/``, ``./results/`` and
    ``./results/predict_txt_2/`` are recreated empty on every call. For each
    image a file ``res_<image stem>.txt`` is written with one line
    ``x_min,y_min,x_max,y_max`` per detected text line: the integer
    axis-aligned bounding box of the detected quadrilateral in original image
    coordinates. Directory entries that OpenCV cannot decode are skipped with
    a message instead of aborting the run.
    """
    image_dir = os.path.join(args.data_path, 'Challenge2_Test_Task12_Images')
    img_name_list = os.listdir(image_dir)
    make_dir('./display/')
    make_dir('./results/')
    make_dir('./results/predict_txt_2/')

    for img_name in tqdm(img_name_list):
        input_img = cv2.imread(os.path.join(image_dir, img_name))
        if input_img is None:
            # cv2.imread signals an unreadable or non-image file by returning
            # None rather than raising.
            tqdm.write('skipping unreadable image: {}'.format(img_name))
            continue

        _, text_ori, _ = get_text_boxes(input_img, img_name=img_name, display=True)

        # splitext copes with extensions of any length (e.g. '.jpeg').
        stem = os.path.splitext(img_name)[0]
        with open('./results/predict_txt_2/res_{}.txt'.format(stem), 'w') as f:
            for quad in text_ori:
                xs, ys = quad[0:8:2], quad[1:8:2]
                f.write('{},{},{},{}\n'.format(
                    int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())))
# Script entry point: run the batch prediction only when executed directly.
if __name__ == '__main__':
    predict_2_txt()