diff --git a/demo.py b/demo.py
index eaf6680..9b4b269 100644
--- a/demo.py
+++ b/demo.py
@@ -5,19 +5,27 @@ import cv2
 import numpy as np
 import torch
 
-from utils import IMG_SIZE, bilinear_unwarping, load_model
+from utils import IMG_SIZE, bilinear_unwarping
+
+import acl
+from ais_bench.infer.interface import InferSession
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
 
 
 def unwarp_img(ckpt_path, img_path, img_size):
     """
     Unwarp a document image using the model from ckpt_path.
     """
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("npu")
+    device_id = 0
+    torch.npu.set_device("npu:0")
+    # 保存PTA的context
+    context, ret = acl.rt.get_context()
 
     # Load model
-    model = load_model(ckpt_path)
-    model.to(device)
-    model.eval()
+    # OM推理会重新生成新的context
+    model = InferSession(device_id, ckpt_path)
 
     # Load image
     img = cv2.imread(img_path)
@@ -25,8 +33,10 @@ def unwarp_img(ckpt_path, img_path, img_size):
     inp = torch.from_numpy(cv2.resize(img, img_size).transpose(2, 0, 1)).unsqueeze(0)
 
     # Make prediction
-    inp = inp.to(device)
-    point_positions2D, _ = model(inp)
+    point_positions2D, _ = model.infer([inp])
+    # 恢复PTA的context
+    ret = acl.rt.set_context(context)
+    point_positions2D = torch.from_numpy(point_positions2D).to(device)
 
     # Unwarp
     size = img.shape[:2][::-1]
@@ -46,7 +56,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
+        "--ckpt-path", type=str, default="./model/best_model.om", help="Path to the model weights as om."
     )
     parser.add_argument("--img-path", type=str, help="Path to the document image to unwarp.")
 
diff --git a/docUnet_pred.py b/docUnet_pred.py
index 7121e4f..55fa3ac 100644
--- a/docUnet_pred.py
+++ b/docUnet_pred.py
@@ -10,7 +10,16 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from utils import IMG_SIZE, bilinear_unwarping, load_model
+from utils import IMG_SIZE, bilinear_unwarping
+
+import acl
+from ais_bench.infer.interface import InferSession
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.set_device("npu:0")
+# 保存PTA的context
+context, ret = acl.rt.get_context()
 
 
 def get_processor_name():
@@ -73,15 +82,15 @@ def infer_docUnet(model, dataloader, device, save_path):
     Unwarp all images in the DocUNet benchmark and save them.
     Also measure the times it takes to perform this operation.
     """
-    model.eval()
     inference_times = []
     inferenceGPU_times = []
     for img_RGB, im_names in tqdm(dataloader):
         # Inference
         start_toGPU = time.time()
-        img_RGB = img_RGB.to(device)
         start_inf = time.time()
-        point_positions2D, _ = model(img_RGB)
+        point_positions2D, _ = model.infer([img_RGB])
+        ret = acl.rt.set_context(context)
+        point_positions2D = torch.from_numpy(point_positions2D).to(device)
         end_inf = time.time()
 
         # Warped image need to be re-open to get full resolution (downsampled in data loader)
@@ -124,19 +133,18 @@ def infer_docUnet(model, dataloader, device, save_path):
     # Computes average inference time and the number of parameters of the model
     avg_inference_time = np.mean(inference_times)
     avg_inferenceGPU_time = np.mean(inferenceGPU_times)
-    n_params = count_parameters(model)
-    return avg_inference_time, avg_inferenceGPU_time, n_params
+    return avg_inference_time, avg_inferenceGPU_time
 
 
 def create_results(ckpt_path, docUnet_path, crop, img_size):
     """
     Create results for the DocUNet benchmark.
     """
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("npu")
+    device_id = 0
 
     # Load model, create dataset and save directory
-    model = load_model(ckpt_path)
-    model.to(device)
+    model = InferSession(device_id, ckpt_path)
 
     dataset = docUnetLoader(docUnet_path, crop, img_size=img_size)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)
@@ -146,7 +154,7 @@ def create_results(ckpt_path, docUnet_path, crop, img_size):
     print(f"    Results will be saved at {save_path}", flush=True)
 
     # Infer results from the model and saves metadata
-    inference_time, inferenceGPU_time, n_params = infer_docUnet(model, dataloader, device, save_path)
+    inference_time, inferenceGPU_time = infer_docUnet(model, dataloader, device, save_path)
     with open(os.path.join(save_path, "model_info.txt"), "w") as f:
         f.write("\n---Model and Hardware Information---\n")
         f.write(f"Inference Time : {inference_time:.5f}s\n")
@@ -155,8 +163,7 @@ def create_results(ckpt_path, docUnet_path, crop, img_size):
         f.write(f"  FPS : {1/inferenceGPU_time:.1f}\n")
         f.write("Using :\n")
         f.write(f"  CPU : {get_processor_name()}\n")
-        f.write(f"  GPU : {torch.cuda.get_device_name(0)}\n")
-        f.write(f"Number of Parameters : {n_params:,}\n")
+        f.write(f"  GPU : {torch.npu.get_device_name(0)}\n")
     return save_path
 
 
@@ -164,7 +171,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
+        "--ckpt-path", type=str, default="./model/best_model.om", help="Path to the model weights as om."
     )
     parser.add_argument("--docunet-path", type=str, default="./data/DocUNet", help="Path to the docunet benchmark.")
     parser.add_argument(
diff --git a/utils.py b/utils.py
index 523539e..6b98820 100644
--- a/utils.py
+++ b/utils.py
@@ -2,6 +2,8 @@ import os
 
 import torch
 import torch.nn.functional as F
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
 
 from model import UVDocnet
 
diff --git a/uvdocBenchmark_metric.py b/uvdocBenchmark_metric.py
index 898077e..1812d2c 100644
--- a/uvdocBenchmark_metric.py
+++ b/uvdocBenchmark_metric.py
@@ -6,7 +6,7 @@ import hdf5storage as h5
 import numpy as np
 import torch
 import torch.nn.functional as F
-from skimage.morphology import binary_erosion
+from skimage.morphology import erosion
 from tqdm import tqdm
 
 from utils import bilinear_unwarping_from_numpy
@@ -47,7 +47,7 @@ def warp_texture(texture, uvmap):
     grey = np.all(warped_texture == 0.5, axis=-1)
     warped_texture[grey] = np.nan
     mask = 1 - np.all(np.isnan(warped_texture), axis=-1).astype(int)
-    mask_small = binary_erosion(mask).astype(int)
+    mask_small = erosion(mask).astype(int)
     mask_small = np.expand_dims(mask_small, axis=-1)
     warped_texture[np.repeat(~mask_small.astype(bool), 3, axis=-1)] = 1
     warped_texture = (warped_texture * 255).astype(np.uint8)
diff --git a/uvdocBenchmark_pred.py b/uvdocBenchmark_pred.py
index 7f97094..b461238 100644
--- a/uvdocBenchmark_pred.py
+++ b/uvdocBenchmark_pred.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+import time
 
 import cv2
 import hdf5storage as h5
@@ -7,7 +8,16 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from utils import IMG_SIZE, bilinear_unwarping, load_model
+from utils import IMG_SIZE, bilinear_unwarping
+
+import acl
+from ais_bench.infer.interface import InferSession
+import torch_npu
+from torch_npu.contrib import transfer_to_npu
+
+torch.npu.set_device("npu:0")
+# 保存PTA的context
+context, ret = acl.rt.get_context()
 
 
 class UVDocBenchmarkLoader(torch.utils.data.Dataset):
@@ -39,16 +49,20 @@ def infer_uvdoc(model, dataloader, device, save_path):
     """
     Unwarp all images in the UVDoc benchmark and save them, along with the mappings.
     """
-    model.eval()
-
     os.makedirs(os.path.join(save_path, "uwp_img"), exist_ok=True)
     os.makedirs(os.path.join(save_path, "bm"), exist_ok=True)
     os.makedirs(os.path.join(save_path, "uwp_texture"), exist_ok=True)
 
+    inference_times = []
+    inferenceGPU_times = []
     for img_RGB, im_names in tqdm(dataloader):
         # Inference
-        img_RGB = img_RGB.to(device)
-        point_positions2D, _ = model(img_RGB)
+        start_toGPU = time.time()
+        start_inf = time.time()
+        point_positions2D, _ = model.infer([img_RGB])
+        ret = acl.rt.set_context(context)
+        point_positions2D = torch.from_numpy(point_positions2D).to(device)
+        end_inf = time.time()
 
         # Warped image need to be re-open to get full resolution (downsampled in data loader)
         warped = cv2.imread(os.path.join(dataloader.dataset.dataroot, "img", im_names[0]))
@@ -56,34 +70,38 @@ def infer_uvdoc(model, dataloader, device, save_path):
         warped = torch.from_numpy(warped.transpose(2, 0, 1) / 255.0).float()
         size = warped.shape[1:][::-1]
 
+        # Unwarp and save the texture
+        warp_texture = cv2.imread(os.path.join(dataloader.dataset.dataroot, "warped_textures", im_names[0]))
+        warp_texture = cv2.cvtColor(warp_texture, cv2.COLOR_BGR2RGB)
+        warp_texture = torch.from_numpy(warp_texture.transpose(2, 0, 1) / 255.0).float()
+        size_texture = warp_texture.shape[1:][::-1]
+
         # Unwarping
+        start_unwarp = time.time()
         unwarped = bilinear_unwarping(
             warped_img=torch.unsqueeze(warped, dim=0).to(device),
             point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
             img_size=tuple(size),
         )
+        unwarped_texture = bilinear_unwarping(
+            warped_img=torch.unsqueeze(warp_texture, dim=0).to(device),
+            point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
+            img_size=tuple(size_texture),
+        )
+        end_unwarp = time.time()
+
         unwarped = (unwarped[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
         unwarped_BGR = cv2.cvtColor(unwarped, cv2.COLOR_RGB2BGR)
 
+        unwarped_texture = (unwarped_texture[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+        unwarped_texture_BGR = cv2.cvtColor(unwarped_texture, cv2.COLOR_RGB2BGR)
+        end_toGPU = time.time()
+
         cv2.imwrite(
             os.path.join(save_path, "uwp_img", im_names[0].split(" ")[0].split(".")[0] + ".png"),
             unwarped_BGR,
         )
 
-        # Unwarp and save the texture
-        warp_texture = cv2.imread(os.path.join(dataloader.dataset.dataroot, "warped_textures", im_names[0]))
-        warp_texture = cv2.cvtColor(warp_texture, cv2.COLOR_BGR2RGB)
-        warp_texture = torch.from_numpy(warp_texture.transpose(2, 0, 1) / 255.0).float()
-        size = warp_texture.shape[1:][::-1]
-
-        unwarped_texture = bilinear_unwarping(
-            warped_img=torch.unsqueeze(warp_texture, dim=0).to(device),
-            point_positions=torch.unsqueeze(point_positions2D[0], dim=0),
-            img_size=tuple(size),
-        )
-        unwarped_texture = (unwarped_texture[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
-        unwarped_texture_BGR = cv2.cvtColor(unwarped_texture, cv2.COLOR_RGB2BGR)
-
         cv2.imwrite(
             os.path.join(save_path, "uwp_texture", im_names[0].split(" ")[0].split(".")[0] + ".png"),
             unwarped_texture_BGR,
@@ -95,16 +113,24 @@ def infer_uvdoc(model, dataloader, device, save_path):
             {"bm": point_positions2D[0].detach().cpu().numpy().transpose(1, 2, 0)},
         )
 
+        inference_times.append(end_inf - start_inf + end_unwarp - start_unwarp)
+        inferenceGPU_times.append(end_inf - start_toGPU + end_toGPU - start_unwarp)
+    
+    # Computes average inference time and the number of parameters of the model
+    avg_inference_time = np.mean(inference_times)
+    avg_inferenceGPU_time = np.mean(inferenceGPU_times)
+    return avg_inference_time, avg_inferenceGPU_time
+
 
 def create_uvdoc_results(ckpt_path, uvdoc_path, img_size):
     """
     Create results for the UVDoc benchmark.
     """
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = torch.device("npu")
+    device_id = 0
 
     # Load model, create dataset and save directory
-    model = load_model(ckpt_path)
-    model.to(device)
+    model = InferSession(device_id, ckpt_path)
 
     dataset = UVDocBenchmarkLoader(data_path=uvdoc_path, img_size=img_size)
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, drop_last=False)
@@ -114,14 +140,20 @@ def create_uvdoc_results(ckpt_path, uvdoc_path, img_size):
     print(f"    Results will be saved at {save_path}", flush=True)
 
     # Infer results
-    infer_uvdoc(model, dataloader, "cuda:0", save_path)
+    inference_time, inferenceGPU_time = infer_uvdoc(model, dataloader, device, save_path)
+    print()
+    print(f"Inference Time : {inference_time:.5f}s\n")
+    print(f"  FPS : {1/inference_time:.1f}\n")
+    print(f"Inference Time (Include Loading To/From GPU) : {inferenceGPU_time:.5f}s\n")
+    print(f"  FPS : {1/inferenceGPU_time:.1f}\n")
+    print()
     return save_path
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--ckpt-path", type=str, default="./model/best_model.pkl", help="Path to the model weights as pkl."
+        "--ckpt-path", type=str, default="./model/best_model.om", help="Path to the model weights as om."
     )
     parser.add_argument(
         "--uvdoc-path", type=str, default="./data/UVDoc_benchmark/", help="Path to the UVDocBenchmark dataset."