diff --git a/demo/demo.py b/demo/demo.py
index 36433c45..6f28620f 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -86,7 +86,7 @@ def do_parse(
                 image_dir = str(os.path.basename(local_image_dir))
                 content_list = pipeline_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
                 md_writer.write_string(
-                    f"{pdf_file_name}_content_list.json",
+                    f"{pdf_file_name}_content.json",
                     json.dumps(content_list, ensure_ascii=False, indent=4),
                 )
 
@@ -142,7 +142,8 @@ def do_parse(
                 image_dir = str(os.path.basename(local_image_dir))
                 content_list = vlm_union_make(pdf_info, MakeMode.CONTENT_LIST, image_dir)
                 md_writer.write_string(
-                    f"{pdf_file_name}_content_list.json",
+                    # Was "<name>_content_list.json"; shortened because ext4 cannot store file names longer than 255 bytes.
+                    f"{pdf_file_name}_content.json",
                     json.dumps(content_list, ensure_ascii=False, indent=4),
                 )
 
diff --git a/mineru/backend/pipeline/batch_analyze.py b/mineru/backend/pipeline/batch_analyze.py
index c88a52a3..561b055c 100644
--- a/mineru/backend/pipeline/batch_analyze.py
+++ b/mineru/backend/pipeline/batch_analyze.py
@@ -3,6 +3,9 @@ from loguru import logger
 from tqdm import tqdm
 from collections import defaultdict
 import numpy as np
+import time
+import torch
+import torch_npu
 
 from .model_init import AtomModelSingleton
 from ...utils.config_reader import get_formula_enable, get_table_enable
@@ -95,6 +98,7 @@ class BatchAnalyze:
                                               })
 
         # OCR检测处理
+        from concurrent.futures import ThreadPoolExecutor, as_completed
         if self.enable_ocr_det_batch:
             # 批处理模式 - 按语言和分辨率分组
             # 收集所有需要OCR检测的裁剪图像
@@ -139,79 +143,73 @@ class BatchAnalyze:
                 )
 
                 # 按分辨率分组并同时完成padding
+                stride = 64
                 resolution_groups = defaultdict(list)
                 for crop_info in lang_crop_list:
                     cropped_img = crop_info[0]
                     h, w = cropped_img.shape[:2]
                     # 使用更大的分组容差,减少分组数量
                     # 将尺寸标准化到32的倍数
-                    normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
-                    normalized_w = ((w + 32) // 32) * 32
+                    normalized_h = ((h + stride) // stride) * stride  # smallest multiple of stride strictly greater than h
+                    normalized_w = ((w + stride) // stride) * stride
                     group_key = (normalized_h, normalized_w)
                     resolution_groups[group_key].append(crop_info)
 
-                # 对每个分辨率组进行批处理
-                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
-
-                    # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                    max_h = max(crop_info[0].shape[0] for crop_info in group_crops)
-                    max_w = max(crop_info[0].shape[1] for crop_info in group_crops)
-                    target_h = ((max_h + 32 - 1) // 32) * 32
-                    target_w = ((max_w + 32 - 1) // 32) * 32
-
-                    # 对所有图像进行padding到统一尺寸
-                    batch_images = []
-                    for crop_info in group_crops:
-                        img = crop_info[0]
-                        h, w = img.shape[:2]
-                        # 创建目标尺寸的白色背景
-                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
-                        # 将原图像粘贴到左上角
-                        padded_img[:h, :w] = img
-                        batch_images.append(padded_img)
-
-                    # 批处理检测
-                    det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)  # 增加批处理大小
-                    # logger.debug(f"OCR-det batch: {det_batch_size} images, target size: {target_h}x{target_w}")
-                    batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
-
-                    # 处理批处理结果
-                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
-                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
-
-                        if dt_boxes is not None and len(dt_boxes) > 0:
-                            # 直接应用原始OCR流程中的关键处理步骤
-                            from mineru.utils.ocr_utils import (
-                                merge_det_boxes, update_det_boxes, sorted_boxes
-                            )
 
-                            # 1. 排序检测框
-                            if len(dt_boxes) > 0:
-                                dt_boxes_sorted = sorted_boxes(dt_boxes)
-                            else:
-                                dt_boxes_sorted = []
-
-                            # 2. 合并相邻检测框
-                            if dt_boxes_sorted:
-                                dt_boxes_merged = merge_det_boxes(dt_boxes_sorted)
-                            else:
-                                dt_boxes_merged = []
-
-                            # 3. 根据公式位置更新检测框(关键步骤!)
-                            if dt_boxes_merged and adjusted_mfdetrec_res:
-                                dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
-                            else:
-                                dt_boxes_final = dt_boxes_merged
-
-                            # 构造OCR结果格式
-                            ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
-
-                            if ocr_res:
-                                ocr_result_list = get_ocr_result_list(
-                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
-                                )
-
-                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+                def _run_one_group_ocr(group_key, group_crops):
+                    """Detect text in one resolution group and append the results per crop.
+
+                    Crops are padded onto a shared white canvas so the detector can batch them.
+                    """
+                    from mineru.utils.ocr_utils import merge_det_boxes, update_det_boxes, sorted_boxes
+
+                    # Canvas size: the group's largest crop, rounded up to a multiple of stride.
+                    max_h = max(ci[0].shape[0] for ci in group_crops)
+                    max_w = max(ci[0].shape[1] for ci in group_crops)
+                    target_h = ((max_h + stride - 1) // stride) * stride
+                    target_w = ((max_w + stride - 1) // stride) * stride
+
+                    batch_images = []
+                    for ci in group_crops:
+                        img = ci[0]
+                        h, w = img.shape[:2]
+                        # Paste the crop into the top-left corner of the white canvas.
+                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                        padded_img[:h, :w] = img
+                        batch_images.append(padded_img)
+
+                    det_batch_size = min(len(batch_images), self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+                    batch_results = ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
+
+                    for ci, (dt_boxes, _elapse) in zip(group_crops, batch_results):
+                        new_image, useful_list, ocr_res_list_dict, _res, adjusted_mfdetrec_res, _lang = ci
+                        if dt_boxes is None or len(dt_boxes) == 0:
+                            continue
+                        # Sort, merge adjacent boxes, then adjust them around detected formulas.
+                        dt_boxes_sorted = sorted_boxes(dt_boxes)
+                        dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
+                        if dt_boxes_merged and adjusted_mfdetrec_res:
+                            dt_boxes_final = update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
+                        else:
+                            dt_boxes_final = dt_boxes_merged
+
+                        ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
+                        if ocr_res:
+                            ocr_result_list = get_ocr_result_list(
+                                ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                            )
+                            ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+
+                # Dispatch inside the language loop so each language's groups run with its own
+                # ocr_model; after the loop only the last language's groups would be processed.
+                MAX_WORKERS = 4
+                start = time.time()
+                with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
+                    futures = [ex.submit(_run_one_group_ocr, gk, gcs) for gk, gcs in resolution_groups.items()]
+                    for f in as_completed(futures):
+                        f.result()  # re-raise any worker exception
+                end = time.time()
+                logger.info(f"ocr det run time : {end -start}")
         else:
             # 原始单张处理模式
             for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
@@ -247,7 +245,7 @@ class BatchAnalyze:
 
         # 表格识别 table recognition
         if self.table_enable:
-            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+            def _run_one_group_table(table_res_dict):
                 _lang = table_res_dict['lang']
                 table_model = atom_model_manager.get_atom_model(
                     atom_model_name='table',
@@ -271,6 +269,16 @@ class BatchAnalyze:
                         'table recognition processing fails, not get html return'
                     )
 
+            # Fan table recognition out to worker threads. NOTE(review): assumes the table model is safe to call concurrently — confirm.
+            MAX_WORKERS = 4
+            start = time.time()
+            with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
+                futures = [ex.submit(_run_one_group_table, table_res_dict) for table_res_dict in table_res_list_all_page]
+                for f in as_completed(futures):
+                    f.result()  # re-raise any worker exception
+            end = time.time()
+            logger.info(f"table run time : {end - start}")
+
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
         img_crop_lists_by_lang = {}  # Dict of lists for each language
diff --git a/mineru/cli/fast_api.py b/mineru/cli/fast_api.py
index 1bfc3e5d..4466260e 100644
--- a/mineru/cli/fast_api.py
+++ b/mineru/cli/fast_api.py
@@ -10,6 +10,20 @@ from fastapi.responses import JSONResponse
 from typing import List, Optional
 from loguru import logger
 from base64 import b64encode
+import time
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.utils.torchair_utils import (
+    get_pdf_page_count,
+    rewrite_model_init,
+    set_batch_candidate,
+    )
+from mineru.utils.model_utils import get_vram
+from mineru.backend.pipeline.batch_analyze import (
+    YOLO_LAYOUT_BASE_BATCH_SIZE,
+    MFD_BASE_BATCH_SIZE,
+    MFR_BASE_BATCH_SIZE,
+    )
 
 from mineru.cli.common import aio_do_parse, read_fn, pdf_suffixes, image_suffixes
 from mineru.utils.cli_parser import arg_parse
@@ -18,6 +32,8 @@ from mineru.version import __version__
 app = FastAPI()
 app.add_middleware(GZipMiddleware, minimum_size=1000)
 
+batch_ratio = 16
+
 def encode_image(image_path: str) -> str:
     """Encode image using base64"""
     with open(image_path, "rb") as f:
@@ -64,6 +80,8 @@ async def parse_pdf(
         pdf_file_names = []
         pdf_bytes_list = []
 
+        pdfs_page_count = 0
+
         for file in files:
             content = await file.read()
             file_path = Path(file.filename)
@@ -74,12 +92,14 @@ async def parse_pdf(
                 temp_path = Path(unique_dir) / file_path.name
                 with open(temp_path, "wb") as f:
                     f.write(content)
+                pdfs_page_count += get_pdf_page_count(temp_path)
 
                 try:
                     pdf_bytes = read_fn(temp_path)
                     pdf_bytes_list.append(pdf_bytes)
                     pdf_file_names.append(file_path.stem)
-                    os.remove(temp_path)  # 删除临时文件
+                    abs_path=os.path.abspath(temp_path)
+                    os.remove(abs_path)  # 删除临时文件
                 except Exception as e:
                     return JSONResponse(
                         status_code=400,
@@ -91,13 +111,20 @@ async def parse_pdf(
                     content={"error": f"Unsupported file type: {file_path.suffix}"}
                 )
 
+        batch_candidate = {
+            AtomicModel.Layout: [YOLO_LAYOUT_BASE_BATCH_SIZE, pdfs_page_count % YOLO_LAYOUT_BASE_BATCH_SIZE],
+            AtomicModel.MFD: [MFD_BASE_BATCH_SIZE, pdfs_page_count % MFD_BASE_BATCH_SIZE],
+            AtomicModel.MFR: batch_ratio * MFR_BASE_BATCH_SIZE,
+        }
+        set_batch_candidate(batch_candidate)
+
 
         # 设置语言列表,确保与文件数量一致
         actual_lang_list = lang_list
         if len(actual_lang_list) != len(pdf_file_names):
             # 如果语言列表长度不匹配,使用第一个语言或默认"ch"
             actual_lang_list = [actual_lang_list[0] if actual_lang_list else "ch"] * len(pdf_file_names)
-
+        start_time = time.time()
         # 调用异步处理函数
         await aio_do_parse(
             output_dir=unique_dir,
@@ -120,6 +147,9 @@ async def parse_pdf(
             end_page_id=end_page_id,
             **config
         )
+        stop_time = time.time()
+        print(f"total process time: {(stop_time-start_time):.2f}s")
+        print(f"per page process time: {(stop_time-start_time)/max(pdfs_page_count, 1):.2f}s")  # guard: page count can be 0
 
         # 构建结果路径
         result_dict = {}
@@ -171,11 +201,12 @@ async def parse_pdf(
 @click.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
 @click.pass_context
 @click.option('--host', default='127.0.0.1', help='Server host (default: 127.0.0.1)')
-@click.option('--port', default=8000, type=int, help='Server port (default: 8000)')
+@click.option('--port', default=6543, type=int, help='Server port (default: 6543)')
 @click.option('--reload', is_flag=True, help='Enable auto-reload (development mode)')
 def main(ctx, host, port, reload, **kwargs):
 
     kwargs.update(arg_parse(ctx))
+    os.environ['MINERU_MODEL_SOURCE'] = 'local'
 
     # 将配置参数存储到应用状态中
     app.state.config = kwargs
diff --git a/mineru/model/layout/doclayout_yolo.py b/mineru/model/layout/doclayout_yolo.py
index 5667a909..fc5056bb 100644
--- a/mineru/model/layout/doclayout_yolo.py
+++ b/mineru/model/layout/doclayout_yolo.py
@@ -66,6 +66,7 @@ class DocLayoutYOLOModel:
                     conf=self.conf,
                     iou=self.iou,
                     verbose=False,
+                    half=True
                 )
                 for pred in predictions:
                     results.append(self._parse_prediction(pred))
diff --git a/mineru/model/mfd/yolo_v8.py b/mineru/model/mfd/yolo_v8.py
index 33dac091..1fb4b50e 100644
--- a/mineru/model/mfd/yolo_v8.py
+++ b/mineru/model/mfd/yolo_v8.py
@@ -31,7 +31,8 @@ class YOLOv8MFDModel:
             conf=self.conf,
             iou=self.iou,
             verbose=False,
-            device=self.device
+            device=self.device,
+            half=True
         )
         return [pred.cpu() for pred in preds] if is_batch else preds[0].cpu()
 
diff --git a/mineru/model/mfr/unimernet/Unimernet.py b/mineru/model/mfr/unimernet/Unimernet.py
index ae3879da..23e56f2a 100644
--- a/mineru/model/mfr/unimernet/Unimernet.py
+++ b/mineru/model/mfr/unimernet/Unimernet.py
@@ -1,7 +1,7 @@
 import torch
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
-
+import numpy as np
 
 class MathDataset(Dataset):
     def __init__(self, image_paths, transform=None):
@@ -61,7 +61,7 @@ class UnimernetModel(object):
             res["latex"] = latex
         return formula_list
 
-    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
+    def _batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
         images_formula_list = []
         mf_image_list = []
         backfill_list = []
@@ -137,3 +137,94 @@ class UnimernetModel(object):
             res["latex"] = latex
 
         return images_formula_list
+
+
+    def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
+        """Recognize every detected formula and store its LaTeX in the result dicts.
+
+        Formula crops from all images are pooled, ordered by box area, recognized
+        in batches, and the outputs are written back in detection order.
+
+        Args:
+            images_mfd_res: Per-image detection results; ``boxes.data`` is read as
+                rows of (x1, y1, x2, y2, conf, cls).
+            images: Images aligned with ``images_mfd_res``; each must provide
+                ``crop((left, upper, right, lower))``.
+            batch_size: Upper bound on the recognition batch size.
+
+        Returns:
+            One list per input image of formula dicts with the keys
+            ``category_id``, ``poly``, ``score`` and ``latex``.
+        """
+        images_formula_list = []
+        backfill_list = []  # Flat list of every formula dict, in detection order.
+        image_info = []  # (area, detection-order index, cropped image) tuples.
+
+        for mfd_res, pil_img in zip(images_mfd_res, images):
+            # Convert once and split by column instead of indexing each box tensor.
+            # .cpu() is a no-op for host tensors and avoids a crash on device tensors.
+            data = mfd_res.boxes.data.cpu().numpy()
+            xyxy, conf, cla = np.split(data, [4, 5], axis=-1)
+
+            cla = cla.reshape(-1).astype(int).tolist()
+            conf = np.round(conf.reshape(-1).astype(float), 2).tolist()
+
+            xyxy = xyxy.astype(np.int32)
+            # Vectorized box areas for the whole image.
+            widths = xyxy[:, 2] - xyxy[:, 0]
+            heights = xyxy[:, 3] - xyxy[:, 1]
+            areas = widths * heights
+
+            # Build one result dict and one crop per detected box.
+            formula_list = []
+            for box, cls_id, score, area in zip(xyxy.tolist(), cla, conf, areas):
+                xmin, ymin, xmax, ymax = box
+                formula_list.append({
+                    "category_id": 13 + cls_id,
+                    "poly": [xmin, ymin, xmax, ymin,
+                             xmax, ymax, xmin, ymax],
+                    "score": score,
+                    "latex": "",
+                })
+
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                image_info.append((area, len(image_info), bbox_img))
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
+
+        # Stable sort by area, so consecutive batches hold crops of similar size.
+        image_info.sort(key=lambda item: item[0])
+        sorted_indices = [item[1] for item in image_info]
+        sorted_images = [item[2] for item in image_info]
+
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
+
+        # Cap the batch size at the largest power of two that does not exceed the
+        # number of crops (1 when there are none).
+        batch_size = min(batch_size, max(1, 2 ** (len(sorted_images).bit_length() - 1))) if sorted_images else 1
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+
+        mfr_res = []
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img}, batch_size=batch_size)
+                mfr_res.extend(output["fixed_str"])
+
+                # The last batch may be short, so advance by its real length.
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)
+
+        # Undo the area sort: output k belongs to detection-order slot sorted_indices[k].
+        unsorted_results = [""] * len(sorted_indices)
+        for original_idx, latex in zip(sorted_indices, mfr_res):
+            unsorted_results[original_idx] = latex
+
+        # backfill_list shares its dicts with images_formula_list, so this fills the return value.
+        for res, latex in zip(backfill_list, unsorted_results):
+            res["latex"] = latex
+
+        return images_formula_list
diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py
index 98d1deee..3866a257 100644
--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py
+++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py
@@ -5,7 +5,9 @@ import cv2
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
 from torchvision.transforms.functional import resize
-
+import torch
+import torch_npu
+import torch.nn.functional as F
 
 # TODO: dereference cv2 if possible
 class UnimerSwinImageProcessor(BaseImageProcessor):
@@ -25,10 +27,53 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
             ]
         )
 
-    def __call__(self, item):
+        self.NORMALIZE_DIVISOR = torch.tensor(255.0, dtype=torch.float16, device="npu")
+        self.weights = torch.tensor([[[0.2989]], [[0.5870]], [[0.1140]]], dtype=torch.float16, device="npu")
+        self.mean = torch.tensor(0.7931, dtype=torch.float16, device="npu")
+        self.std = torch.tensor(0.1738, dtype=torch.float16, device="npu")
+
+        self._mul_buf = torch.empty((3, *self.input_size), dtype=torch.float16, device="npu")     # 预分配 [3,H,W]
+        self._gray_buf = torch.empty((1, *self.input_size), dtype=torch.float16, device="npu")     # 预分配 [1,H,W]
+
+
+    def ___call__(self, item):
         image = self.prepare_input(item)
         return self.transform(image=image)['image'][:1]
 
+    def pil_to_npu(self, pil_img, device="npu"):
+        """Return ``pil_img`` as a float16 CHW tensor on ``device``, divided by ``NORMALIZE_DIVISOR``."""
+        hwc = torch.from_numpy(np.asarray(pil_img, dtype=np.float16)).to(device)
+        return hwc.permute(2, 0, 1) / self.NORMALIZE_DIVISOR
+
+    def __call__(self, item):
+        """Turn a PIL formula crop into a normalized 1-channel tensor of size ``input_size``."""
+        # Force 3 channels: grayscale/RGBA crops would break the CHW permute and the RGB weights.
+        img = self.pil_to_npu(self.crop_margin(item.convert("RGB")))
+
+        _, h, w = img.shape
+        target_h, target_w = self.input_size
+        # Aspect-preserving fit; clamp so an extreme aspect ratio never collapses to 0 pixels.
+        scale = min(target_h / h, target_w / w)
+        new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
+
+        img = img.unsqueeze(0)   # [1,C,H,W]
+        img = F.interpolate(img, size=(new_h, new_w), mode='bilinear', align_corners=False)
+        img = img.squeeze(0)
+
+        # Centre the resized crop; an odd leftover pixel goes to the right/bottom edge.
+        dw, dh = (target_w - new_w) / 2, (target_h - new_h) / 2
+        left, right = int(dw), int(dw + 0.5)
+        top, bottom = int(dh), int(dh + 0.5)
+        img = F.pad(img, (left, right, top, bottom), value=0.0)
+
+        # RGB -> Gray
+        gray_tensor = (img * self.weights).sum(dim=0, keepdim=True)  # [1, H, W]
+
+        # Normalize
+        gray_tensor.sub_(self.mean).div_(self.std)
+        return gray_tensor
+
+
     @staticmethod
     def crop_margin(img: Image.Image) -> Image.Image:
         data = np.array(img.convert("L"))
@@ -44,6 +89,32 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
         a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
         return img.crop((a, b, w + a, h + b))
 
+    def crop_margin_tensor(self, img):
+        """Crop ``img`` to the bounding box of its below-threshold pixels.
+
+        img: [C,H,W] tensor, uint8 or float. Returned unchanged when it is flat or no pixel is below the threshold.
+        """
+        gray = (img * self.weights).sum(dim=0)
+        # NOTE(review): the uint8 cast truncates, so a float image in [0, 1] collapses to 0/1 — confirm expected range.
+        gray = gray.to(torch.uint8)
+        max_val = gray.max()
+        min_val = gray.min()
+
+        if max_val == min_val:
+            return img
+
+        norm_gray = (gray - min_val) / (max_val - min_val)
+        # NOTE(review): self.threshold is not set in the visible code; it is compared on the 0..1 scale of norm_gray.
+        mask = (norm_gray < self.threshold)
+
+        coords = mask.nonzero(as_tuple=False)
+        if coords.shape[0] == 0:
+            return img
+        ymin, xmin = coords.min(0)[0]
+        ymax, xmax = coords.max(0)[0]
+
+        return img[:, ymin:ymax+1, xmin:xmax+1]
+
     @staticmethod
     def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
         """Crop margins of image using NumPy operations"""
diff --git a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py
index 1b808e8b..2abac99a 100644
--- a/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py
+++ b/mineru/model/mfr/unimernet/unimernet_hf/unimer_swin/modeling_unimer_swin.py
@@ -451,6 +451,16 @@ class UnimerSwinSelfAttention(nn.Module):
         self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+        # Fused Q/K/V projection: concatenate the three weight matrices into one Linear.
+        # NOTE(review): torch.concat copies the weights as they are at construction time;
+        # confirm the fused layer is rebuilt (or loaded) after checkpoint weights are applied.
+        wq, wk, wv = self.query.weight, self.key.weight, self.value.weight
+        self.qkv = nn.Linear(in_features=wk.shape[1], out_features=wq.shape[0] + wk.shape[0] + wv.shape[0])
+        self.qkv.weight = nn.Parameter(torch.concat([wq, wk, wv], dim=0), requires_grad=False)
+        wq_bias = self.query.bias if self.query.bias is not None else torch.zeros(wq.shape[0])
+        wk_bias = self.key.bias if self.key.bias is not None else torch.zeros(wk.shape[0])
+        wv_bias = self.value.bias if self.value.bias is not None else torch.zeros(wv.shape[0])
+        self.qkv.bias = nn.Parameter(torch.concat([wq_bias, wk_bias, wv_bias], dim=0), requires_grad=False)
 
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
@@ -465,11 +475,15 @@ class UnimerSwinSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
         batch_size, dim, num_channels = hidden_states.shape
-        mixed_query_layer = self.query(hidden_states)
 
-        key_layer = self.transpose_for_scores(self.key(hidden_states))
-        value_layer = self.transpose_for_scores(self.value(hidden_states))
-        query_layer = self.transpose_for_scores(mixed_query_layer)
+        # Fused Q/K/V projection: one matmul instead of three. The relative position bias
+        # prevents use of the PFA interface, so for now only the projection matmul is fused.
+        qkv = self.qkv(hidden_states)
+        q, k, v = qkv.chunk(3, dim=-1)
+
+        query_layer = q.view(*q.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
+        key_layer = k.view(*k.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
+        value_layer = v.view(*v.shape[:2], self.num_attention_heads, -1).permute(0, 2, 1, 3)
 
         # Take the dot product between "query" and "key" to get the raw attention scores.
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py
index 3de483ac..712e4127 100755
--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py
+++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_det.py
@@ -117,6 +117,10 @@ class TextDetector(BaseOCRV20):
         self.net.eval()
         self.net.to(self.device)
 
+
+        import threading
+        self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
+
     def _batch_process_same_size(self, img_list):
         """
             对相同尺寸的图像进行批处理
@@ -162,11 +166,11 @@ class TextDetector(BaseOCRV20):
             return batch_results, time.time() - starttime
 
         # 批处理推理
-        with torch.no_grad():
-            inp = torch.from_numpy(batch_tensor)
-            inp = inp.to(self.device)
-            outputs = self.net(inp)
-
+        with self._dev_lock:
+            with torch.no_grad():
+                inp = torch.from_numpy(batch_tensor)
+                inp = inp.to(self.device)
+                outputs = self.net(inp)
         # 处理输出
         preds = {}
         if self.det_algorithm == "EAST":
@@ -304,10 +308,11 @@ class TextDetector(BaseOCRV20):
         img = img.copy()
         starttime = time.time()
 
-        with torch.no_grad():
-            inp = torch.from_numpy(img)
-            inp = inp.to(self.device)
-            outputs = self.net(inp)
+        with self._dev_lock:
+            with torch.no_grad():
+                inp = torch.from_numpy(img)
+                inp = inp.to(self.device)
+                outputs = self.net(inp)
 
         preds = {}
         if self.det_algorithm == "EAST":
diff --git a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py
index c06ca5fe..d865b201 100755
--- a/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py
+++ b/mineru/model/ocr/paddleocr2pytorch/tools/infer/predict_rec.py
@@ -94,6 +94,9 @@ class TextRecognizer(BaseOCRV20):
         self.net.eval()
         self.net.to(self.device)
 
+        import threading
+        self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
+
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
         if self.rec_algorithm == 'NRTR' or self.rec_algorithm == 'ViTSTR':
@@ -301,74 +304,78 @@ class TextRecognizer(BaseOCRV20):
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         elapse = 0
-        # for beg_img_no in range(0, img_num, batch_num):
-        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
-            index = 0
-            for beg_img_no in range(0, img_num, batch_num):
-                end_img_no = min(img_num, beg_img_no + batch_num)
-                norm_img_batch = []
-                max_wh_ratio = 0
-                for ino in range(beg_img_no, end_img_no):
-                    # h, w = img_list[ino].shape[0:2]
-                    h, w = img_list[indices[ino]].shape[0:2]
-                    wh_ratio = w * 1.0 / h
-                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
-                for ino in range(beg_img_no, end_img_no):
-                    if self.rec_algorithm == "SAR":
-                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
-                            img_list[indices[ino]], self.rec_image_shape)
-                        norm_img = norm_img[np.newaxis, :]
-                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
-                        valid_ratios = []
-                        valid_ratios.append(valid_ratio)
-                        norm_img_batch.append(norm_img)
-
-                    elif self.rec_algorithm == "SVTR":
-                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
-                                                             self.rec_image_shape)
-                        norm_img = norm_img[np.newaxis, :]
-                        norm_img_batch.append(norm_img)
-                    elif self.rec_algorithm == "SRN":
-                        norm_img = self.process_image_srn(img_list[indices[ino]],
-                                                          self.rec_image_shape, 8,
-                                                          self.max_text_length)
-                        encoder_word_pos_list = []
-                        gsrm_word_pos_list = []
-                        gsrm_slf_attn_bias1_list = []
-                        gsrm_slf_attn_bias2_list = []
-                        encoder_word_pos_list.append(norm_img[1])
-                        gsrm_word_pos_list.append(norm_img[2])
-                        gsrm_slf_attn_bias1_list.append(norm_img[3])
-                        gsrm_slf_attn_bias2_list.append(norm_img[4])
-                        norm_img_batch.append(norm_img[0])
-                    elif self.rec_algorithm == "CAN":
-                        norm_img = self.norm_img_can(img_list[indices[ino]],
-                                                     max_wh_ratio)
-                        norm_img = norm_img[np.newaxis, :]
-                        norm_img_batch.append(norm_img)
-                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
-                        word_label = np.ones([1, 36], dtype='int64')
-                        norm_img_mask_batch = []
-                        word_label_list = []
-                        norm_img_mask_batch.append(norm_image_mask)
-                        word_label_list.append(word_label)
-                    else:
-                        norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                        max_wh_ratio)
-                        norm_img = norm_img[np.newaxis, :]
-                        norm_img_batch.append(norm_img)
-                norm_img_batch = np.concatenate(norm_img_batch)
-                norm_img_batch = norm_img_batch.copy()
-
-                if self.rec_algorithm == "SRN":
-                    starttime = time.time()
-                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
-                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
-                    gsrm_slf_attn_bias1_list = np.concatenate(
-                        gsrm_slf_attn_bias1_list)
-                    gsrm_slf_attn_bias2_list = np.concatenate(
-                        gsrm_slf_attn_bias2_list)
 
+        # for beg_img_no in range(0, img_num, batch_num):
+        from concurrent.futures import ThreadPoolExecutor, as_completed
+        def _rec_batch_worker(beg_img_no: int, end_img_no: int):
+
+
+            max_wh_ratio = 0.0
+            norm_img_batch = []
+            for ino in range(beg_img_no, end_img_no):
+                # h, w = img_list[ino].shape[0:2]
+                h, w = img_list[indices[ino]].shape[0:2]
+                wh_ratio = w * 1.0 / h
+                max_wh_ratio = max(max_wh_ratio, wh_ratio)
+            for ino in range(beg_img_no, end_img_no):
+                if self.rec_algorithm == "SAR":
+                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                        img_list[indices[ino]], self.rec_image_shape)
+                    norm_img = norm_img[np.newaxis, :]
+                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                    valid_ratios = []
+                    valid_ratios.append(valid_ratio)
+                    norm_img_batch.append(norm_img)
+
+                elif self.rec_algorithm == "SVTR":
+                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                            self.rec_image_shape)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                elif self.rec_algorithm == "SRN":
+                    norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                        self.rec_image_shape, 8,
+                                                        self.max_text_length)
+                    encoder_word_pos_list = []
+                    gsrm_word_pos_list = []
+                    gsrm_slf_attn_bias1_list = []
+                    gsrm_slf_attn_bias2_list = []
+                    encoder_word_pos_list.append(norm_img[1])
+                    gsrm_word_pos_list.append(norm_img[2])
+                    gsrm_slf_attn_bias1_list.append(norm_img[3])
+                    gsrm_slf_attn_bias2_list.append(norm_img[4])
+                    norm_img_batch.append(norm_img[0])
+                elif self.rec_algorithm == "CAN":
+                    norm_img = self.norm_img_can(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+                    word_label = np.ones([1, 36], dtype='int64')
+                    norm_img_mask_batch = []
+                    word_label_list = []
+                    norm_img_mask_batch.append(norm_image_mask)
+                    word_label_list.append(word_label)
+                else:
+                    norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                    max_wh_ratio)
+                    norm_img = norm_img[np.newaxis, :]
+                    norm_img_batch.append(norm_img)
+            norm_img_batch = np.concatenate(norm_img_batch)
+            norm_img_batch = norm_img_batch.copy()
+
+            starttime = time.time()
+
+            if self.rec_algorithm == "SRN":
+                starttime = time.time()
+                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                gsrm_slf_attn_bias1_list = np.concatenate(
+                    gsrm_slf_attn_bias1_list)
+                gsrm_slf_attn_bias2_list = np.concatenate(
+                    gsrm_slf_attn_bias2_list)
+
+                with self._dev_lock:
                     with torch.no_grad():
                         inp = torch.from_numpy(norm_img_batch)
                         encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
@@ -384,58 +391,67 @@ class TextRecognizer(BaseOCRV20):
 
                         backbone_out = self.net.backbone(inp) # backbone_feat
                         prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
-                    # preds = {"predict": prob_out[2]}
-                    preds = {"predict": prob_out["predict"]}
-
-                elif self.rec_algorithm == "SAR":
-                    starttime = time.time()
-                    # valid_ratios = np.concatenate(valid_ratios)
-                    # inputs = [
-                    #     norm_img_batch,
-                    #     valid_ratios,
-                    # ]
-
+                # preds = {"predict": prob_out[2]}
+                preds = {"predict": prob_out["predict"]}
+
+            elif self.rec_algorithm == "SAR":
+                starttime = time.time()
+                # valid_ratios = np.concatenate(valid_ratios)
+                # inputs = [
+                #     norm_img_batch,
+                #     valid_ratios,
+                # ]
+
+                with self._dev_lock:
                     with torch.no_grad():
                         inp = torch.from_numpy(norm_img_batch)
                         inp = inp.to(self.device)
                         preds = self.net(inp)
 
-                elif self.rec_algorithm == "CAN":
-                    starttime = time.time()
-                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
-                    word_label_list = np.concatenate(word_label_list)
-                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
+            elif self.rec_algorithm == "CAN":
+                starttime = time.time()
+                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+                word_label_list = np.concatenate(word_label_list)
+                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
 
-                    inp = [torch.from_numpy(e_i) for e_i in inputs]
-                    inp = [e_i.to(self.device) for e_i in inp]
+                inp = [torch.from_numpy(e_i) for e_i in inputs]
+                inp = [e_i.to(self.device) for e_i in inp]
+                with self._dev_lock:
                     with torch.no_grad():
                         outputs = self.net(inp)
                         outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
 
-                    preds = outputs
-
-                else:
-                    starttime = time.time()
+                preds = outputs
 
+            else:
+                with self._dev_lock:
                     with torch.no_grad():
-                        inp = torch.from_numpy(norm_img_batch)
-                        inp = inp.to(self.device)
+                        inp = torch.from_numpy(norm_img_batch).to(self.device)
                         prob_out = self.net(inp)
+                preds = [v.cpu().numpy() for v in prob_out] if isinstance(prob_out, list) else prob_out.cpu().numpy()
 
-                    if isinstance(prob_out, list):
-                        preds = [v.cpu().numpy() for v in prob_out]
-                    else:
-                        preds = prob_out.cpu().numpy()
+            rec_result = self.postprocess_op(preds)
 
-                rec_result = self.postprocess_op(preds)
-                for rno in range(len(rec_result)):
-                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-                elapse += time.time() - starttime
+            for rno in range(len(rec_result)):
+                global_idx = indices[beg_img_no + rno]
+                rec_res[global_idx] = rec_result[rno]
+
+            batch_elapse = time.time() - starttime
+            return len(rec_result), batch_elapse
+
+        MAX_WORKERS = 4
+        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex, \
+            tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+
+            futures = []
+            for beg_img_no in range(0, img_num, batch_num):
+                end_img_no = min(img_num, beg_img_no + batch_num)
+                futures.append(ex.submit(_rec_batch_worker, beg_img_no, end_img_no))
 
-                # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
-                current_batch_size = min(batch_num, img_num - index * batch_num)
-                index += 1
-                pbar.update(current_batch_size)
+            for fut in as_completed(futures):
+                n_done, batch_elapse = fut.result()
+                elapse += batch_elapse
+                pbar.update(n_done)
 
         # Fix NaN values in recognition results
         for i in range(len(rec_res)):
diff --git a/mineru/model/table/rapid_table.py b/mineru/model/table/rapid_table.py
index 174a8052..dd796bcc 100644
--- a/mineru/model/table/rapid_table.py
+++ b/mineru/model/table/rapid_table.py
@@ -21,6 +21,8 @@ class RapidTableModel(object):
         self.table_model = RapidTable(input_args)
         self.ocr_engine = ocr_engine
 
+        import threading
+        self._dev_lock = getattr(self, "_dev_lock", threading.Lock())
 
     def predict(self, image):
         bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
@@ -30,44 +32,45 @@ class RapidTableModel(object):
         img_aspect_ratio = img_height / img_width if img_width > 0 else 1.0
         img_is_portrait = img_aspect_ratio > 1.2
 
-        if img_is_portrait:
+        with self._dev_lock:
+            if img_is_portrait:
 
-            det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
-            # Check if table is rotated by analyzing text box aspect ratios
-            is_rotated = False
-            if det_res:
-                vertical_count = 0
+                det_res = self.ocr_engine.ocr(bgr_image, rec=False)[0]
+                # Check if table is rotated by analyzing text box aspect ratios
+                is_rotated = False
+                if det_res:
+                    vertical_count = 0
 
-                for box_ocr_res in det_res:
-                    p1, p2, p3, p4 = box_ocr_res
+                    for box_ocr_res in det_res:
+                        p1, p2, p3, p4 = box_ocr_res
 
-                    # Calculate width and height
-                    width = p3[0] - p1[0]
-                    height = p3[1] - p1[1]
+                        # Calculate width and height
+                        width = p3[0] - p1[0]
+                        height = p3[1] - p1[1]
 
-                    aspect_ratio = width / height if height > 0 else 1.0
+                        aspect_ratio = width / height if height > 0 else 1.0
 
-                    # Count vertical vs horizontal text boxes
-                    if aspect_ratio < 0.8:  # Taller than wide - vertical text
-                        vertical_count += 1
-                    # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
-                    #     horizontal_count += 1
+                        # Count vertical vs horizontal text boxes
+                        if aspect_ratio < 0.8:  # Taller than wide - vertical text
+                            vertical_count += 1
+                        # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
+                        #     horizontal_count += 1
 
-                # If we have more vertical text boxes than horizontal ones,
-                # and vertical ones are significant, table might be rotated
-                if vertical_count >= len(det_res) * 0.3:
-                    is_rotated = True
+                    # If we have more vertical text boxes than horizontal ones,
+                    # and vertical ones are significant, table might be rotated
+                    if vertical_count >= len(det_res) * 0.3:
+                        is_rotated = True
 
-                # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
+                    # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
 
-            # Rotate image if necessary
-            if is_rotated:
-                # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
-                image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
-                bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+                # Rotate image if necessary
+                if is_rotated:
+                    # logger.debug("Table appears to be in portrait orientation, rotating 90 degrees clockwise")
+                    image = cv2.rotate(np.asarray(image), cv2.ROTATE_90_CLOCKWISE)
+                    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
 
-        # Continue with OCR on potentially rotated image
-        ocr_result = self.ocr_engine.ocr(bgr_image)[0]
+            # Continue with OCR on potentially rotated image
+            ocr_result = self.ocr_engine.ocr(bgr_image)[0]
         if ocr_result:
             ocr_result = [[item[0], escape_html(item[1][0]), item[1][1]] for item in ocr_result if
                       len(item) == 2 and isinstance(item[1], tuple)]
diff --git a/mineru/utils/torchair_utils.py b/mineru/utils/torchair_utils.py
new file mode 100644
index 00000000..70c7b0f5
--- /dev/null
+++ b/mineru/utils/torchair_utils.py
@@ -0,0 +1,164 @@
+from loguru import logger
+import pypdfium2 as pdfium
+
+import torch
+import torch_npu
+import torch.nn as nn
+import torchvision
+import torchvision_npu
+import torchair as tng
+from torchair.configs.compiler_config import CompilerConfig
+
+from mineru.backend.pipeline.model_list import AtomicModel
+from mineru.model.mfr.unimernet.unimernet_hf.unimer_swin.modeling_unimer_swin import UnimerSwinSelfAttention
+from mineru.backend.pipeline.model_init import (
+    AtomModelSingleton,
+    table_model_init,
+    mfd_model_init,
+    mfr_model_init,
+    doclayout_yolo_model_init,
+    ocr_model_init,
+    )
+
+from transformers.generation.utils import GenerationMixin
+
+batch_candidate = None
+
+def set_batch_candidate(bs):  # Store `bs` in the module-level `batch_candidate` read by the helpers in this file.
+    global batch_candidate  # rebind the module global rather than creating a local
+    batch_candidate = bs
+
+def atom_model_init_compile(model_name: str, **kwargs):
+    """Initialise the atomic model ``model_name`` and torchair-compile it where applicable.
+
+    Layout / MFD: the inner network is compiled (``dynamic=False``, ``fullgraph=True``)
+    and the dim-0 candidates stored in ``batch_candidate`` are handed to ``set_dim_gears``.
+    MFR: the encoder is compiled with ``dynamic=False``, the decoder with ``dynamic=True``.
+    OCR / Table: initialised only; nothing is compiled here.
+
+    Raises:
+        ValueError: ``model_name`` is not one of the handled ``AtomicModel`` values.
+        RuntimeError: the selected initialiser returned ``None``.
+    """
+    if model_name == AtomicModel.Layout:
+        model = doclayout_yolo_model_init(kwargs.get('doclayout_yolo_weights'), kwargs.get('device'))
+        model.model.model = compile_model(model.model.model, False, True)
+        gears = batch_candidate[AtomicModel.Layout]
+        dummy = torch.zeros((gears[0], 3, model.imgsz, model.imgsz))
+        tng.inference.set_dim_gears(dummy, {0: gears})
+
+    elif model_name == AtomicModel.MFD:
+        model = mfd_model_init(kwargs.get('mfd_weights'), kwargs.get('device'))
+        model.model.model = compile_model(model.model.model, False, True)
+        gears = batch_candidate[AtomicModel.MFD]
+        dummy = torch.zeros((gears[0], 3, model.imgsz, model.imgsz))
+        tng.inference.set_dim_gears(dummy, {0: gears})
+
+    elif model_name == AtomicModel.MFR:
+        model = mfr_model_init(kwargs.get('mfr_weight_dir'), kwargs.get('device'))
+        model.model.encoder = compile_model(model.model.encoder, False, True)
+        model.model.decoder = compile_model(model.model.decoder, True, True)
+
+    elif model_name == AtomicModel.OCR:
+        model = ocr_model_init(
+            kwargs.get('det_db_box_thresh'),
+            kwargs.get('lang'),
+            kwargs.get('det_limit_side_len'),
+        )
+
+    elif model_name == AtomicModel.Table:
+        model = table_model_init(kwargs.get('lang'))
+
+    else:
+        logger.error('model name not allow')
+        raise ValueError("model name not allow")
+
+    # Guard against an initialiser that handed back None instead of a model.
+    if model is None:
+        logger.error('model init failed')
+        raise RuntimeError("model init failed")
+
+    return model
+
+def compile_model(model, dynamic, fullgraph):
+    """Return ``model`` wrapped by ``torch.compile`` using a torchair NPU backend."""
+    cfg = CompilerConfig()
+    cfg.experimental_config.frozen_parameter = True
+    cfg.experimental_config.tiling_schedule_optimize = True
+    backend = tng.get_npu_backend(compiler_config=cfg)
+    return torch.compile(model, dynamic=dynamic, fullgraph=fullgraph, backend=backend)
+
+
+def rewrite_model_init():
+    """Replace ``AtomModelSingleton.get_atom_model`` with a variant that builds models via ``atom_model_init_compile``."""
+    def _patched_getmodel(self, atom_model_name: str, **kwargs):
+        # Cache key: OCR models are keyed by language, table models by table model name and language.
+        if atom_model_name in (AtomicModel.OCR,):
+            cache_key = (atom_model_name, kwargs.get('lang', None))
+        elif atom_model_name in (AtomicModel.Table,):
+            cache_key = (atom_model_name, kwargs.get('table_model_name', None), kwargs.get('lang', None))
+        else:
+            cache_key = atom_model_name
+
+        # Build the model only on the first request; later calls reuse the cached instance.
+        if cache_key not in self._models:
+            self._models[cache_key] = atom_model_init_compile(model_name=atom_model_name, **kwargs)
+        return self._models[cache_key]
+    AtomModelSingleton.get_atom_model = _patched_getmodel
+
+
+def rewrite_mfr_encoder_forward():
+    """Monkey-patch ``GenerationMixin._prepare_encoder_decoder_kwargs_for_generation``.
+
+    The replacement zero-pads the encoder input along dim 0 up to
+    ``batch_candidate[AtomicModel.MFR]`` before calling the encoder, then trims
+    the padded rows from the encoder output again.
+    """
+    import inspect  # local import: `inspect` is not imported at module level in this file
+
+    def _patched_prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name, generation_config,
+    ):
+        # 1. get encoder
+        encoder = self.get_encoder()
+
+        # 2. Prepare encoder args and encoder kwargs from model kwargs and generation config.
+        irrelevant_prefix = ("decoder_", "cross_attn", "use_cache")
+        encoder_kwargs = {k: v for k, v in model_kwargs.items() if not k.startswith(irrelevant_prefix)}
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        if "kwargs" not in encoder_signature and "model_kwargs" not in encoder_signature:
+            # No wildcard parameter: keep only arguments the encoder's forward() names.
+            encoder_kwargs = {k: v for k, v in encoder_kwargs.items() if k in encoder_signature}
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+        encoder_kwargs["return_dict"] = True
+
+        # Pad the batch up to the fixed MFR batch size. Only pad when the batch is
+        # smaller: a larger batch would otherwise ask torch.zeros for a negative size.
+        pad_count = max(0, batch_candidate[AtomicModel.MFR] - inputs_tensor.shape[0])
+        if pad_count:
+            padding_tensor = torch.zeros(
+                pad_count, *inputs_tensor.shape[1:], dtype=inputs_tensor.dtype, device=inputs_tensor.device
+            )
+            inputs_tensor = torch.cat((inputs_tensor, padding_tensor), dim=0)
+
+        encoder_kwargs[model_input_name] = inputs_tensor
+        output = encoder(**encoder_kwargs)  # type: ignore
+        if pad_count:
+            # Drop the rows that correspond to the zero padding added above.
+            output.last_hidden_state = output.last_hidden_state[:-pad_count]
+            output.pooler_output = output.pooler_output[:-pad_count]
+        model_kwargs["encoder_outputs"] = output
+        return model_kwargs
+
+    GenerationMixin._prepare_encoder_decoder_kwargs_for_generation = _patched_prepare_encoder_decoder_kwargs_for_generation
+
+def get_pdf_page_count(pdf_path):  # Return len() of the pdfium document opened from `pdf_path`.
+    pdf = pdfium.PdfDocument(pdf_path)
+    try:
+        return len(pdf)
+    finally:
+        pdf.close()  # always close the document, even if len() raises
\ No newline at end of file