@@ -86,7 +86,7 @@ def do_parse(
image_dir = str(os.path.basename(local_image_dir))
content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
- f"{pdf_file_name}_content_list.json",
+ f"{pdf_file_name}_content.json",
json.dumps(content_list, ensure_ascii=False, indent=4),
)
@@ -142,7 +142,8 @@ def do_parse(
image_dir = str(os.path.basename(local_image_dir))
content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
md_writer.write_string(
- f"{pdf_file_name}_content_list.json",
+ # f"{pdf_file_name}_content_list.json",
+ f"{pdf_file_name}_content.json", ## 文件名太长了,linux文件系统ext4超过255字节无法保存
json.dumps(content_list, ensure_ascii=False, indent=4),
)
@@ -3,6 +3,9 @@ from loguru import logger
from tqdm import tqdm
from collections import defaultdict
import numpy as np
+import time
+import torch
+import torch_npu
from .model_init import AtomModelSingleton
from ...utils.config_reader import get_formula_enable, get_table_enable
@@ -95,6 +98,7 @@ class BatchAnalyze:
})
# OCR检测处理
+ from concurrent.futures import ThreadPoolExecutor, as_completed
if self.enable_ocr_det_batch:
# 批处理模式 - 按语言和分辨率分组
# 收集所有需要OCR检测的裁剪图像
@@ -139,79 +143,73 @@ class BatchAnalyze:
)
# 按分辨率分组并同时完成padding
+ stride = 64
resolution_groups = defaultdict(list)
for crop_info in lang_crop_list:
cropped_img = crop_info[0]
h, w = cropped_img.shape[:2]
# 使用更大的分组容差,减少分组数量
# 将尺寸标准化到32的倍数
- normalized_h = ((h + 32) // 32) * 32 # 向上取整到32的倍数
- normalized_w = ((w + 32) // 32) * 32
+ normalized_h = ((h + stride) // stride) * stride # 向上取整到stride的倍数
+ normalized_w = ((w + stride) // stride) * stride
group_key = (normalized_h, normalized_w)
resolution_groups[group_key].append(crop_info)
- # 对每个分辨率组进行批处理
- for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
-
- # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
- max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
- max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
- target_h = ((max_h + 32 - 1) // 32) * 32
- target_w = ((max_w + 32 - 1) // 32) * 32
-
- # 对所有图像进行padding到统一尺寸
- batch_images = []
- for crop_info in group_crops:
- img = crop_info[0]
- h, w = img.shape[:2]
- # 创建目标尺寸的白色背景
- padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
- # 将原图像粘贴到左上角
- padded_img[:h, :w] = img
- batch_images.append(padded_img)
-
- # 批处理检测
- det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE) # 增加批处理大小
- # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
- batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
-
- # 处理批处理结果
- for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
- new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
-
- if dt_boxes is not None and len(dt_boxes) > 0:
- # 直接应用原始OCR流程中的关键处理步骤
- from mineru.utils.ocr_utils import (
- merge_det_boxes, update_det_boxes, sorted_boxes
- )
- # 1. 排序检测框
- if len(dt_boxes) > 0:
- dt_boxes_sorted = sorted_boxes(dt_boxes)
- else:
- dt_boxes_sorted = []
-
- # 2. 合并相邻检测框
- if dt_boxes_sorted:
- dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
- else:
- dt_boxes_merged = []
-
- # 3. 根据公式位置更新检测框(关键步骤!)
- if dt_boxes_merged and adjusted_mfdetrec_res:
- dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
- else:
- dt_boxes_final = dt_boxes_merged
-
- # 构造OCR结果格式
- ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
-
- if ocr_res:
- ocr_result_list = get_ocr_result_list(
- ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
- )
-
- ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+ def _run_one_group_ocr(group_key, group_crops):
+
+ max_h = max(ci[0].shape[0] for ci in group_crops)
+ max_w = max(ci[0].shape[1] for ci in group_crops)
+ target_h = ((max_h + stride - 1) // stride) * stride
+ target_w = ((max_w + stride - 1) // stride) * stride
+
+ batch_images = []
+ for ci in group_crops:
+ img = ci[0]
+ h, w = img.shape[:2]
+ padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+ padded_img[:h, :w] = img
+ batch_images.append(padded_img)
+
+ det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+
+ batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
+
+ for i, (ci, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
+ new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = ci
+ if dt_boxes is not None and len(dt_boxes) > 0:
+ from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
+
+ if len(dt_boxes) > 0:
+ dt_boxes_sorted = sorted_boxes(dt_boxes)
+ else:
+ dt_boxes_sorted = []
+
+ if dt_boxes_sorted:
+ dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
+ else:
+ dt_boxes_merged = []
+
+ if dt_boxes_merged and adjusted_mfdetrec_res:
+ dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
+ else:
+ dt_boxes_final = dt_boxes_merged
+
+ ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
+ if ocr_res:
+ ocr_result_list = get_ocr_result_list(
+ ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+ )
+ ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+
+ MAX_WORKERS = 4
+ start = time.time()
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
+ futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()]
+ for f in as_completed(futures):
+ f.result()
+ end = time.time()
+ logger.info(f"ocr det run time : {end -start}")
else:
# 原始单张处理模式
for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
@@ -247,7 +245,7 @@ class BatchAnalyze:
# 表格识别 table recognition
if self.table_enable:
- for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+ def _run_one_group_table(table_res_dict):
_lang = table_res_dict['lang']
table_model = atom_model_manager.get_atom_model(
atom_model_name='table',
@@ -271,6 +269,16 @@ class BatchAnalyze:
'table recognition processing fails, not get html return'
)
+
+ MAX_WORKERS = 4
+ start = time.time()
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
+ futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page]
+ for f in as_completed(futures):
+ f.result()
+ end = time.time()
+ logger.info(f"table run time : {end - start}")
+
# Create dictionaries to store items by language
need_ocr_lists_by_lang = {} # Dict of lists for each language
img_crop_lists_by_lang = {} # Dict of lists for each language
@@ -10,6 +10,20 @@ from fastapi.responses import JSONResponse
from typing import List, Optional
from loguru import logger
from base64 import b64encode
+import time
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.utils.torchair_utils import (
+ get_pdf_page_count,
+ rewrite_model_init,
+ set_batch_candidate,
+ )
+from mineru.utils.model_utils import get_vram
+from mineru.backend.pipeline.batch_analyze import (
+ YOLO_LAYOUT_BASE_BATCH_SIZE,
+ MFD_BASE_BATCH_SIZE,
+ MFR_BASE_BATCH_SIZE,
+ )
from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
from mineru.utils.cli_parser import arg_parse
@@ -18,6 +32,8 @@ from mineru.version import __version__
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
+batch_ratio = 16
+
def encode_image(image_path: str) -> str:
"""Encode image using base64"""
with open(image_path, "rb") as f:
@@ -64,6 +80,8 @@ async def parse_pdf(
pdf_file_names = []
pdf_bytes_list = []
+ pdfs_page_count = 0
+
for file in files:
content = await file.read()
file_path = Path(file.filename)
@@ -74,12 +92,14 @@ async def parse_pdf(
temp_path = Path(unique_dir) / file_path.name
with open(temp_path, "wb") as f:
f.write(content)
+ pdfs_page_count += get_pdf_page_count(temp_path)
try:
pdf_bytes = read_fn(temp_path)
pdf_bytes_list.append(pdf_bytes)
pdf_file_names.append(file_path.stem)
- os.remove(temp_path) # 删除临时文件
+ abs_path=os.path.abspath(temp_path)
+ os.remove(abs_path) # 删除临时文件
except Exception as e:
return JSONResponse(
status_code=400,
@@ -91,13 +111,20 @@ async def parse_pdf(
content={"error": f"Unsupported file type: {file_path.suffix}"}
)
+ batch_candidate = {
+ AtomicModel.Layout: [YOLO_LAYOUT_BASE_BATCH_SIZE, pdfs_page_count % YOLO_LAYOUT_BASE_BATCH_SIZE],
+ AtomicModel.MFD: [MFD_BASE_BATCH_SIZE, pdfs_page_count % MFD_BASE_BATCH_SIZE],
+ AtomicModel.MFR: batch_ratio * MFR_BASE_BATCH_SIZE,
+ }
+ set_batch_candidate(batch_candidate)
+
# 设置语言列表,确保与文件数量一致
actual_lang_list = lang_list
if len(actual_lang_list) != len(pdf_file_names):
# 如果语言列表长度不匹配,使用第一个语言或默认"ch"
actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
-
+ start_time = time.time()
# 调用异步处理函数
await aio_do_parse(
output_dir=unique_dir,
@@ -120,6 +147,9 @@ async def parse_pdf(
end_page_id=end_page_id,
**config
)
+ stop_time = time.time()
+ print(f"total process time: {(stop_time-start_time):.2f}s")
+ print(f"per page process time: {(stop_time-start_time)/pdfs_page_count:.2f}s")
# 构建结果路径
result_dict = {}
@@ -171,11 +201,12 @@ async def parse_pdf(
@click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.pass_context
@click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
-@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
+@click.option('--port', default=6543, type=int, help='Server port (default: 6543)')
@click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
def main(ctx, host, port, reload, **kwargs):
kwargs.update(arg_parse(ctx))
+ os.environ['MINERU_MODEL_SOURCE'] = 'local'
# 将配置参数存储到应用状态中
app.state.config = kwargs
@@ -66,6 +66,7 @@ class DocLayoutYOLOModel:
conf=self.conf,
iou=self.iou,
verbose=False,
+ half=True
)
for pred in predictions:
results.append(self._parse_prediction(pred))
@@ -31,7 +31,8 @@ class YOLOv8MFDModel:
conf=self.conf,
iou=self.iou,
verbose=False,
- device=self.device
+ device=self.device,
+ half=True
)
return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
@@ -1,7 +1,7 @@
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
-
+import numpy as np
class MathDataset(Dataset):
def __init__(self, image_paths, transform=None):
@@ -61,7 +61,7 @@ class UnimernetModel(object):
res["latex"] = latex
return formula_list
- def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
+ def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
images_formula_list = []
mf_image_list = []
backfill_list = []
@@ -137,3 +137,94 @@ class UnimernetModel(object):
res["latex"] = latex
return images_formula_list
+
+
+ def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
+
+ images_formula_list = []
+ mf_image_list = []
+ backfill_list = []
+ image_info = [] # Store (area, original_index, image) tuples
+
+ # Collect images with their original indices
+ for image_index in range(len(images_mfd_res)):
+ mfd_res = images_mfd_res[image_index]
+ pil_img = images[image_index]
+ # split代替多次索引
+ data = mfd_res.boxes.data.numpy()
+ xyxy, conf, cla = np.split(data, [4, 5], axis=-1)
+
+ cla = cla.reshape(-1).astype(int).tolist()
+ conf = np.round(conf.reshape(-1).astype(float), 2).tolist()
+
+ xyxy = xyxy.astype(np.int32)
+ xmin, ymin, xmax, ymax = xyxy[:, 0], xyxy[:, 1], xyxy[:, 2], xyxy[:, 3]
+ # area 直接矩阵运算
+ areas = (xmax - xmin) * (ymax - ymin)
+
+ num_boxes = len(conf)
+
+ formula_list = []
+ for i in range(num_boxes):
+ xmin_i, ymin_i, xmax_i, ymax_i = xyxy[i].tolist()
+ formula_list.append({
+ "category_id": 13 + cla[i],
+ "poly": [xmin_i, ymin_i, xmax_i, ymin_i,
+ xmax_i, ymax_i, xmin_i, ymax_i],
+ "score": conf[i],
+ "latex": "",
+ })
+
+ # bbox_img 截取
+ # bbox_img = pil_img[:, ymin_i:ymax_i, xmin_i:xmax_i]
+ bbox_img = pil_img.crop((xmin_i, ymin_i, xmax_i, ymax_i))
+ curr_idx = len(mf_image_list)
+ image_info.append((areas[i], curr_idx, bbox_img))
+ mf_image_list.append(bbox_img)
+
+ images_formula_list.append(formula_list)
+ backfill_list += formula_list
+
+ # Stable sort by area
+ image_info.sort(key=lambda x: x[0]) # sort by area
+ sorted_indices = [x[1] for x in image_info]
+ sorted_images = [x[2] for x in image_info]
+
+ # Create mapping for results
+ index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
+
+ # Create dataset with sorted images
+ dataset = MathDataset(sorted_images, transform=self.model.transform)
+
+ # 如果batch_size > len(sorted_images),则设置为不超过len(sorted_images)的2的幂
+ batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
+
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+
+ # Process batches and store results
+ mfr_res = []
+ # for mf_img in dataloader:
+
+ with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+ for index, mf_img in enumerate(dataloader):
+ mf_img = mf_img.to(dtype=self.model.dtype)
+ mf_img = mf_img.to(self.device)
+ with torch.no_grad():
+ output = self.model.generate({"image": mf_img}, batch_size=batch_size)
+ mfr_res.extend(output["fixed_str"])
+
+ # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
+ current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+ pbar.update(current_batch_size)
+
+ # Restore original order
+ unsorted_results = [""] * len(mfr_res)
+ for new_idx, latex in enumerate(mfr_res):
+ original_idx = index_mapping[new_idx]
+ unsorted_results[original_idx] = latex
+
+ # Fill results back
+ for res, latex in zip(backfill_list, unsorted_results):
+ res["latex"] = latex
+
+ return images_formula_list
@@ -5,7 +5,9 @@ import cv2
import albumentations as alb
from albumentations.pytorch import ToTensorV2
from torchvision.transforms.functional import resize
-
+import torch
+import torch_npu
+import torch.nn.functional as F
# TODO: dereference cv2 if possible
class UnimerSwinImageProcessor(BaseImageProcessor):
@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
]
)
- def __call__(self, item):
+ self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu")
+ self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu")
+ self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu")
+ self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu")
+
+ self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [3,H,W]
+ self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu") # 预分配 [1,H,W]
+
+
+ def ___call__(self, item):
image = self.prepare_input(item)
return self.transform(image=image)['image'][:1]
+ def pil_to_npu(self, pil_img, device="npu"):
+ img = torch.from_numpy(np.asarray(pil_img, dtype=np.float16))
+ img = img.to(device).permute(2, 0, 1) / self.NORMALIZE_DIVISOR
+ return img
+
+ def __call__(self, item):
+
+ img = self.crop_margin(item)
+ img = self.pil_to_npu(img)
+
+ _, h, w = img.shape
+ target_h, target_w = self.input_size
+ scale = min(target_h / h, target_w / w)
+ new_h, new_w = int(h*scale), int(w*scale)
+
+ img = img.view(1, *img.shape) # [1,C,H,W]
+ img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
+ img = img.view(*img.shape[1:])
+
+ dw, dh = target_w - new_w, target_h - new_h
+ dw /= 2
+ dh /= 2
+ left, right = int(dw), int(dw + 0.5)
+ top, bottom = int(dh), int(dh + 0.5)
+ img = F.pad(img, (left, right, top, bottom), value=0.0)
+
+ # RGB -> Gray
+ gray_tensor = (img * self.weights).sum(dim=0, keepdim=True) # [1, H, W]
+
+ # Normalize
+ gray_tensor.sub_(self.mean).div_(self.std)
+ return gray_tensor
+
+
@staticmethod
def crop_margin(img: Image.Image) -> Image.Image:
data = np.array(img.convert("L"))
@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
a, b, w, h = cv2.boundingRect(coords) # Find minimum spanning bounding box
return img.crop((a, b, w + a, h + b))
+ def crop_margin_tensor(self, img):
+ """
+ img: [C,H,W] tensor, uint8 或 float
+ """
+
+ gray = (img * self.weights).sum(dim=0)
+
+ gray = gray.to(torch.uint8)
+ max_val = gray.max()
+ min_val = gray.min()
+
+ if max_val == min_val:
+ return img
+
+ norm_gray = (gray - min_val) / (max_val - min_val)
+
+ mask = (norm_gray < self.threshold)
+
+ coords = mask.nonzero(as_tuple=False)
+ if coords.shape[0] == 0:
+ return img
+ ymin, xmin = coords.min(0)[0]
+ ymax, xmax = coords.max(0)[0]
+
+ return img[:, ymin:ymax+1, xmin:xmax+1]
+
@staticmethod
def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
"""Crop margins of image using NumPy operations"""
@@ -451,6 +451,16 @@ class UnimerSwinSelfAttention(nn.Module):
self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ wq = self.query.weight
+ wk = self.key.weight
+ wv = self.value.weight
+ self.qkv = nn.Linear(in_features=wk.shape[1], out_features=wq.shape[0] + wk.shape[0] + wv.shape[0])
+ self.qkv.weight = nn.Parameter(torch.concat([wq, wk, wv], dim=0), requires_grad=False)
+ wq_bias = self.query.bias if self.query.bias is not None else torch.zeros(wq.shape[0])
+ wk_bias = self.key.bias if self.key.bias is not None else torch.zeros(wk.shape[0])
+ wv_bias = self.key.bias if self.value.bias is not None else torch.zeros(wv.shape[0])
+ self.qkv.bias = nn.Parameter(torch.concat([wq_bias, wk_bias, wv_bias], dim=0), requires_grad=False)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -465,11 +475,15 @@ class UnimerSwinSelfAttention(nn.Module):
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
batch_size, dim, num_channels = hidden_states.shape
- mixed_query_layer = self.query(hidden_states)
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- query_layer = self.transpose_for_scores(mixed_query_layer)
+ # """融合qk为大矩阵,由于加入相对位置编码,PFA接口用不了,暂时只修改矩阵乘法"""
+ batch_size, dim, num_channels = hidden_states.shape
+ qkv = self.qkv(hidden_states)
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
+ key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
+ value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20):
self.net.eval()
self.net.to(self.device)
+
+ import threading
+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
+
def _batch_process_same_size(self, img_list):
"""
对相同尺寸的图像进行批处理
@@ -162,11 +166,11 @@ class TextDetector(BaseOCRV20):
return batch_results, time.time() - starttime
# 批处理推理
- with torch.no_grad():
- inp = torch.from_numpy(batch_tensor)
- inp = inp.to(self.device)
- outputs = self.net(inp)
-
+ with self._dev_lock:
+ with torch.no_grad():
+ inp = torch.from_numpy(batch_tensor)
+ inp = inp.to(self.device)
+ outputs = self.net(inp)
# 处理输出
preds = {}
if self.det_algorithm == "EAST":
@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20):
img = img.copy()
starttime = time.time()
- with torch.no_grad():
- inp = torch.from_numpy(img)
- inp = inp.to(self.device)
- outputs = self.net(inp)
+ with self._dev_lock:
+ with torch.no_grad():
+ inp = torch.from_numpy(img)
+ inp = inp.to(self.device)
+ outputs = self.net(inp)
preds = {}
if self.det_algorithm == "EAST":
@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20):
self.net.eval()
self.net.to(self.device)
+ import threading
+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
+
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20):
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
elapse = 0
- # for beg_img_no in range(0, img_num, batch_num):
- with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
- index = 0
- for beg_img_no in range(0, img_num, batch_num):
- end_img_no = min(img_num, beg_img_no + batch_num)
- norm_img_batch = []
- max_wh_ratio = 0
- for ino in range(beg_img_no, end_img_no):
- # h, w = img_list[ino].shape[0:2]
- h, w = img_list[indices[ino]].shape[0:2]
- wh_ratio = w * 1.0 / h
- max_wh_ratio = max(max_wh_ratio, wh_ratio)
- for ino in range(beg_img_no, end_img_no):
- if self.rec_algorithm == "SAR":
- norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
- img_list[indices[ino]], self.rec_image_shape)
- norm_img = norm_img[np.newaxis, :]
- valid_ratio = np.expand_dims(valid_ratio, axis=0)
- valid_ratios = []
- valid_ratios.append(valid_ratio)
- norm_img_batch.append(norm_img)
-
- elif self.rec_algorithm == "SVTR":
- norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
- self.rec_image_shape)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
- elif self.rec_algorithm == "SRN":
- norm_img = self.process_image_srn(img_list[indices[ino]],
- self.rec_image_shape, 8,
- self.max_text_length)
- encoder_word_pos_list = []
- gsrm_word_pos_list = []
- gsrm_slf_attn_bias1_list = []
- gsrm_slf_attn_bias2_list = []
- encoder_word_pos_list.append(norm_img[1])
- gsrm_word_pos_list.append(norm_img[2])
- gsrm_slf_attn_bias1_list.append(norm_img[3])
- gsrm_slf_attn_bias2_list.append(norm_img[4])
- norm_img_batch.append(norm_img[0])
- elif self.rec_algorithm == "CAN":
- norm_img = self.norm_img_can(img_list[indices[ino]],
- max_wh_ratio)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
- norm_image_mask = np.ones(norm_img.shape, dtype='float32')
- word_label = np.ones([1, 36], dtype='int64')
- norm_img_mask_batch = []
- word_label_list = []
- norm_img_mask_batch.append(norm_image_mask)
- word_label_list.append(word_label)
- else:
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
- norm_img_batch = np.concatenate(norm_img_batch)
- norm_img_batch = norm_img_batch.copy()
-
- if self.rec_algorithm == "SRN":
- starttime = time.time()
- encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
- gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
- gsrm_slf_attn_bias1_list = np.concatenate(
- gsrm_slf_attn_bias1_list)
- gsrm_slf_attn_bias2_list = np.concatenate(
- gsrm_slf_attn_bias2_list)
+ # for beg_img_no in range(0, img_num, batch_num):
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ def _rec_batch_worker(beg_img_no: int, end_img_no: int):
+
+
+ max_wh_ratio = 0.0
+ norm_img_batch = []
+ for ino in range(beg_img_no, end_img_no):
+ # h, w = img_list[ino].shape[0:2]
+ h, w = img_list[indices[ino]].shape[0:2]
+ wh_ratio = w * 1.0 / h
+ max_wh_ratio = max(max_wh_ratio, wh_ratio)
+ for ino in range(beg_img_no, end_img_no):
+ if self.rec_algorithm == "SAR":
+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+ img_list[indices[ino]], self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ valid_ratio = np.expand_dims(valid_ratio, axis=0)
+ valid_ratios = []
+ valid_ratios.append(valid_ratio)
+ norm_img_batch.append(norm_img)
+
+ elif self.rec_algorithm == "SVTR":
+ norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+ self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "SRN":
+ norm_img = self.process_image_srn(img_list[indices[ino]],
+ self.rec_image_shape, 8,
+ self.max_text_length)
+ encoder_word_pos_list = []
+ gsrm_word_pos_list = []
+ gsrm_slf_attn_bias1_list = []
+ gsrm_slf_attn_bias2_list = []
+ encoder_word_pos_list.append(norm_img[1])
+ gsrm_word_pos_list.append(norm_img[2])
+ gsrm_slf_attn_bias1_list.append(norm_img[3])
+ gsrm_slf_attn_bias2_list.append(norm_img[4])
+ norm_img_batch.append(norm_img[0])
+ elif self.rec_algorithm == "CAN":
+ norm_img = self.norm_img_can(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+ word_label = np.ones([1, 36], dtype='int64')
+ norm_img_mask_batch = []
+ word_label_list = []
+ norm_img_mask_batch.append(norm_image_mask)
+ word_label_list.append(word_label)
+ else:
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ norm_img_batch = np.concatenate(norm_img_batch)
+ norm_img_batch = norm_img_batch.copy()
+
+ starttime = time.time()
+
+ if self.rec_algorithm == "SRN":
+ starttime = time.time()
+ encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+ gsrm_slf_attn_bias1_list = np.concatenate(
+ gsrm_slf_attn_bias1_list)
+ gsrm_slf_attn_bias2_list = np.concatenate(
+ gsrm_slf_attn_bias2_list)
+
+ with self._dev_lock:
with torch.no_grad():
inp = torch.from_numpy(norm_img_batch)
encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20):
backbone_out = self.net.backbone(inp) # backbone_feat
prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
- # preds = {"predict": prob_out[2]}
- preds = {"predict": prob_out["predict"]}
-
- elif self.rec_algorithm == "SAR":
- starttime = time.time()
- # valid_ratios = np.concatenate(valid_ratios)
- # inputs = [
- # norm_img_batch,
- # valid_ratios,
- # ]
-
+ # preds = {"predict": prob_out[2]}
+ preds = {"predict": prob_out["predict"]}
+
+ elif self.rec_algorithm == "SAR":
+ starttime = time.time()
+ # valid_ratios = np.concatenate(valid_ratios)
+ # inputs = [
+ # norm_img_batch,
+ # valid_ratios,
+ # ]
+
+ with self._dev_lock:
with torch.no_grad():
inp = torch.from_numpy(norm_img_batch)
inp = inp.to(self.device)
preds = self.net(inp)
- elif self.rec_algorithm == "CAN":
- starttime = time.time()
- norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
- word_label_list = np.concatenate(word_label_list)
- inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
+ elif self.rec_algorithm == "CAN":
+ starttime = time.time()
+ norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+ word_label_list = np.concatenate(word_label_list)
+ inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
- inp = [torch.from_numpy(e_i) for e_i in inputs]
- inp = [e_i.to(self.device) for e_i in inp]
+ inp = [torch.from_numpy(e_i) for e_i in inputs]
+ inp = [e_i.to(self.device) for e_i in inp]
+ with self._dev_lock:
with torch.no_grad():
outputs = self.net(inp)
outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
- preds = outputs
-
- else:
- starttime = time.time()
+ preds = outputs
+ else:
+ with self._dev_lock:
with torch.no_grad():
- inp = torch.from_numpy(norm_img_batch)
- inp = inp.to(self.device)
+ inp = torch.from_numpy(norm_img_batch).to(self.device)
prob_out = self.net(inp)
+ preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy()
- if isinstance(prob_out, list):
- preds = [v.cpu().numpy() for v in prob_out]
- else:
- preds = prob_out.cpu().numpy()
+ rec_result = self.postprocess_op(preds)
- rec_result = self.postprocess_op(preds)
- for rno in range(len(rec_result)):
- rec_res[indices[beg_img_no + rno]] = rec_result[rno]
- elapse += time.time() - starttime
+ for rno in range(len(rec_result)):
+ global_idx = indices[beg_img_no + rno]
+ rec_res[global_idx] = rec_result[rno]
+
+ batch_elapse = time.time() - starttime
+ return len(rec_result), batch_elapse
+
+ MAX_WORKERS = 4
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \
+ tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+
+ futures = []
+ for beg_img_no in range(0, img_num, batch_num):
+ end_img_no = min(img_num, beg_img_no + batch_num)
+ futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no))
- # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
- current_batch_size = min(batch_num, img_num - index * batch_num)
- index += 1
- pbar.update(current_batch_size)
+ for fut in as_completed(futures):
+ n_done, batch_elapse = fut.result()
+ elapse += batch_elapse
+ pbar.update(n_done)
# Fix NaN values in recognition results
for i in range(len(rec_res)):
@@ -21,6 +21,8 @@ class RapidTableModel(object):
self.table_model = RapidTable(input_args)
self.ocr_engine = ocr_engine
+ import threading
+ self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
def predict(self, image):
bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
@@ -30,44 +32,45 @@ class RapidTableModel(object):
img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
img_is_portrait = img_aspect_ratio > 1.2
- if img_is_portrait:
+ with self._dev_lock:
+ if img_is_portrait:
- det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
- # Check if table is rotated by analyzing text box aspect ratios
- is_rotated = False
- if det_res:
- vertical_count = 0
+ det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+ # Check if table is rotated by analyzing text box aspect ratios
+ is_rotated = False
+ if det_res:
+ vertical_count = 0
- for box_ocr_res in det_res:
- p1, p2, p3, p4 = box_ocr_res
+ for box_ocr_res in det_res:
+ p1, p2, p3, p4 = box_ocr_res
- # Calculate width and height
- width = p3[0] - p1[0]
- height = p3[1] - p1[1]
+ # Calculate width and height
+ width = p3[0] - p1[0]
+ height = p3[1] - p1[1]
- aspect_ratio = width / height if height > 0 else 1.0
+ aspect_ratio = width / height if height > 0 else 1.0
- # Count vertical vs horizontal text boxes
- if aspect_ratio < 0.8: # Taller than wide - vertical text
- vertical_count += 1
- # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
- # horizontal_count += 1
+ # Count vertical vs horizontal text boxes
+ if aspect_ratio < 0.8: # Taller than wide - vertical text
+ vertical_count += 1
+ # elif aspect_ratio > 1.2: # Wider than tall - horizontal text
+ # horizontal_count += 1
- # If we have more vertical text boxes than horizontal ones,
- # and vertical ones are significant, table might be rotated
- if vertical_count >= len(det_res) * 0.3:
- is_rotated = True
+ # If we have more vertical text boxes than horizontal ones,
+ # and vertical ones are significant, table might be rotated
+ if vertical_count >= len(det_res) * 0.3:
+ is_rotated = True
- # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+ # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
- # Rotate image if necessary
- if is_rotated:
- # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
- image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
- bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ # Rotate image if necessary
+ if is_rotated:
+ # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
+ image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
+ bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
- # Continue with OCR on potentially rotated image
- ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+ # Continue with OCR on potentially rotated image
+ ocr_result = self.ocr_engine.ocr(bgr_image)[0]
if ocr_result:
ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
len(item) == 2 and isinstance(item[1], tuple)]
new file mode 100644
@@ -0,0 +1,164 @@
+from loguru import logger
+import pypdfium2 as pdfium
+
+import torch
+import torch_npu
+import torch.nn as nn
+import torchvision
+import torchvision_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.model.mfr.unimernet.unimernet_hf.unimer_swin.modeling_unimer_swin import UnimerSwinSelfAttention
+from mineru.backend.pipeline.model_init import (
+ AtomModelSingleton,
+ table_model_init,
+ mfd_model_init,
+ mfr_model_init,
+ doclayout_yolo_model_init,
+ ocr_model_init,
+ )
+
+from transformers.generation.utils import GenerationMixin
+
+batch_candidate = None
+
+def set_batch_candidate(bs):
+ global batch_candidate
+ batch_candidate = bs
+
+def atom_model_init_compile(model_name: str, **kwargs):
+ global batch_candidate
+ atom_model = None
+ if model_name == AtomicModel.Layout:
+ atom_model = doclayout_yolo_model_init(
+ kwargs.get('doclayout_yolo_weights'),
+ kwargs.get('device')
+ )
+ atom_model.model.model = compile_model(atom_model.model.model, False, True)
+ npu_input = torch.zeros((batch_candidate[AtomicModel.Layout][0], 3, atom_model.imgsz, atom_model.imgsz))
+ tng.inference.set_dim_gears(npu_input, {0: batch_candidate[AtomicModel.Layout]})
+
+ elif model_name == AtomicModel.MFD:
+ atom_model = mfd_model_init(
+ kwargs.get('mfd_weights'),
+ kwargs.get('device')
+ )
+ atom_model.model.model = compile_model(atom_model.model.model, False, True)
+ npu_input = torch.zeros((batch_candidate[AtomicModel.MFD][0], 3, atom_model.imgsz, atom_model.imgsz))
+ tng.inference.set_dim_gears(npu_input, {0: batch_candidate[AtomicModel.MFD]})
+
+ elif model_name == AtomicModel.MFR:
+ atom_model = mfr_model_init(
+ kwargs.get('mfr_weight_dir'),
+ kwargs.get('device')
+ )
+
+ atom_model.model.encoder = compile_model(atom_model.model.encoder, False, True)
+ atom_model.model.decoder = compile_model(atom_model.model.decoder, True, True)
+
+ elif model_name == AtomicModel.OCR:
+ atom_model = ocr_model_init(
+ kwargs.get('det_db_box_thresh'),
+ kwargs.get('lang'),
+ kwargs.get('det_limit_side_len'),
+ )
+
+ elif model_name == AtomicModel.Table:
+ atom_model = table_model_init(
+ kwargs.get('lang'),
+ )
+
+ else:
+ logger.error('model name not allow')
+ raise ValueError("model name not allow")
+
+ if atom_model is None:
+ logger.error('model init failed')
+ raise RuntimeError("model init failed")
+
+ return atom_model
+
+def compile_model(model, dynamic, fullgraph):
+ config = CompilerConfig()
+ config.experimental_config.frozen_parameter = True
+ config.experimental_config.tiling_schedule_optimize = True
+ npu_backend = tng.get_npu_backend(compiler_config=config)
+ compiled_model = torch.compile(model, dynamic=dynamic, fullgraph=fullgraph, backend=npu_backend)
+ return compiled_model
+
+
+def rewrite_model_init():
+ def _patched_getmodel(self, atom_model_name: str, **kwargs):
+ lang = kwargs.get('lang', None)
+ table_model_name = kwargs.get('table_model_name', None)
+
+ if atom_model_name in [AtomicModel.OCR]:
+ key = (atom_model_name, lang)
+ elif atom_model_name in [AtomicModel.Table]:
+ key = (atom_model_name, table_model_name, lang)
+ else:
+ key = atom_model_name
+
+ if key not in self._models:
+ self._models[key] = atom_model_init_compile(model_name=atom_model_name, **kwargs)
+ return self._models[key]
+ AtomModelSingleton.get_atom_model = _patched_getmodel
+
+
+def rewrite_mfr_encoder_forward():
+ def _patched_prepare_encoder_decoder_kwargs_for_generation(self,
+ inputs_tensor: torch.Tensor,
+ model_kwargs,
+ model_input_name,
+ generation_config,
+ ):
+ # 1. get encoder
+ encoder = self.get_encoder()
+
+ # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+ encoder_kwargs = {
+ argument: value
+ for argument, value in model_kwargs.items()
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
+ }
+ encoder_signature = set(inspect.signature(encoder.forward).parameters)
+ encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+ if not encoder_accepts_wildcard:
+ encoder_kwargs = {
+ argument: value
+ for argument, value in encoder_kwargs.items()
+ if argument in encoder_signature
+ }
+ encoder_kwargs["output_attentions"] = generation_config.output_attentions
+ encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+
+ # 3. make sure that encoder returns `ModelOutput`
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+ encoder_kwargs["return_dict"] = True
+
+ ####### 固定input_tensor形状
+ pad_count = 0
+ if batch_candidate[AtomicModel.MFR] != inputs_tensor.shape[0]:
+ pad_count = batch_candidate[AtomicModel.MFR] - inputs_tensor.shape[0]
+ padding_tensor = torch.zeros(pad_count, *inputs_tensor.shape[1:], dtype=inputs_tensor.dtype, device=inputs_tensor.device)
+ inputs_tensor = torch.cat((inputs_tensor, padding_tensor), dim=0)
+
+ encoder_kwargs[model_input_name] = inputs_tensor
+ output = encoder(**encoder_kwargs)# type: ignore
+ if pad_count != 0:
+ output.last_hidden_state = output.last_hidden_state[:-pad_count]
+ output.pooler_output = output.pooler_output[:-pad_count]
+ model_kwargs["encoder_outputs"] = output
+ return model_kwargs
+
+ GenerationMixin._prepare_encoder_decoder_kwargs_for_generation = _patched_prepare_encoder_decoder_kwargs_for_generation
+
+def get_pdf_page_count(pdf_path):
+ pdf = pdfium.PdfDocument(pdf_path)
+ try:
+ return len(pdf)
+ finally:
+ pdf.close()
\ No newline at end of file