# Copyright 2026 Huawei Technologies Co., Ltd

#

# Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.

# You may obtain a copy of the License at

#

#     http://www.apache.org/licenses/LICENSE-2.0

#

# Unless required by applicable law or agreed to in writing, software

# distributed under the License is distributed on an "AS IS" BASIS,

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

# See the License for the specific language governing permissions and

# limitations under the License.



import argparse

import os

import time

from pathlib import Path



from tqdm import tqdm



from paddleocr import PaddleOCRVL



# Image file suffixes accepted as input. Entries are lowercase and include the
# leading dot because they are compared against ``Path.suffix.lower()``.
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}

# PDF file suffixes accepted as input (same lowercase, dot-prefixed form).
PDF_EXTS = {".pdf"}





def parse_args():
    """Parse the command-line options of this script.

    Returns:
        The ``argparse.Namespace`` produced from the process arguments, with
        the attributes ``data_dir``, ``output_path``,
        ``layout_detection_model_name``, ``layout_detection_model_dir`` and
        ``vllm_ip`` (all strings).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_dir",
        type=str,
        default="OmniDocBenchV1.5/pdfs",
        help="Input path: a directory, a pdf file, or an image file",
    )

    # The remaining options are plain string flags without help text; they are
    # registered in this order so the generated usage message is unchanged.
    plain_string_options = (
        ("--output_path", "OmniDocBenchV1.5_out_pdf"),
        ("--layout_detection_model_name", "PP-DocLayoutV3"),
        ("--layout_detection_model_dir", "./PP-DocLayoutV3-weight"),
        ("--vllm_ip", "http://localhost:8000/v1"),
    )
    for flag, fallback in plain_string_options:
        cli.add_argument(flag, type=str, default=fallback)

    return cli.parse_args()





def is_supported_input_file(p: Path) -> bool:
    """Tell whether ``p`` is a regular file with an accepted image/PDF suffix.

    The suffix comparison is case-insensitive. Anything that is not a regular
    file (directories, missing paths) yields ``False``.
    """
    return p.is_file() and p.suffix.lower() in (IMG_EXTS | PDF_EXTS)





def list_input_files(input_path: str) -> list[str]:
    """Resolve ``input_path`` into the list of files to process.

    ``input_path`` may point at:

      - one pdf/image file, which is returned as a single-element list, or
      - a directory, whose direct children (no recursion) are filtered down
        to supported pdf/image files.

    Returns:
        File paths as strings; directory results are ordered by file name.

    Raises:
        FileNotFoundError: ``input_path`` does not exist.
        ValueError: ``input_path`` is a file with an unsupported suffix.
    """
    root = Path(input_path)
    if not root.exists():
        raise FileNotFoundError(f"Input path not found: {input_path}")

    if not root.is_file():
        # Directory case: keep only supported direct children, ordered by name.
        supported = filter(is_supported_input_file, root.iterdir())
        return [str(entry) for entry in sorted(supported, key=lambda entry: entry.name)]

    if is_supported_input_file(root):
        return [str(root)]
    raise ValueError(f"Unsupported file type: {root}")





def main():
    """Run the PaddleOCR-VL pipeline over the inputs and print a timing summary.

    The files to process are exactly those returned by ``list_input_files``
    for ``--data_dir``, so a single pdf/image path works as well as a
    directory, unsupported files in a directory are skipped, and the summary
    count matches what was actually processed. Every result yielded by the
    pipeline is saved as markdown under ``--output_path``.
    """
    args = parse_args()

    # Resolve and validate the input first so a bad path fails immediately,
    # before the pipeline is constructed.
    inputs = list_input_files(args.data_dir)
    print(f"Found {len(inputs)} file(s) from: {args.data_dir}")

    pipeline = PaddleOCRVL(
        layout_detection_model_name=args.layout_detection_model_name,
        layout_detection_model_dir=args.layout_detection_model_dir,
        vl_rec_backend="vllm-server",
        vl_rec_server_url=args.vllm_ip,
        device="npu",
    )

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()
    for file_path in tqdm(inputs):
        for res in pipeline.predict_iter(file_path):
            res.save_to_markdown(args.output_path)
    total_time = time.perf_counter() - start_time

    total_files = len(inputs)
    print("\n========== Summary ==========")
    print(f"Total files processed: {total_files}")
    print(f"Total time (s): {total_time:.2f}")
    if total_files > 0:
        print(f"Avg time per file (s): {total_time / total_files:.3f}")


if __name__ == "__main__":
    main()