diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/data/loaders.py doclayout_yolo/data/loaders.py
@@ -14,6 +14,7 @@
import requests
import torch
from PIL import Image
+from torchvision.transforms import functional as TF
from doclayout_yolo.data.utils import IMG_FORMATS, VID_FORMATS
from doclayout_yolo.utils import LOGGER, is_colab, is_kaggle, ops
@@ -411,7 +412,7 @@
self.bs = len(self.im0)
@staticmethod
- def _single_check(im):
+ def __single_check(im): ## origin _single_check
"""Validate and format an image to numpy array."""
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
if isinstance(im, Image.Image):
@@ -419,6 +420,18 @@
im = im.convert("RGB")
im = np.asarray(im)[:, :, ::-1]
im = np.ascontiguousarray(im) # contiguous
+
+ return im
+
+    @staticmethod
+    def _single_check(im):
+        """Type-check the input and hand it back as a numpy array; PIL images are converted to RGB first."""
+        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
+        if not isinstance(im, Image.Image):
+            return im  # numpy input is passed through unchanged
+        rgb = im if im.mode == "RGB" else im.convert("RGB")
+        im = np.asarray(rgb)  # no channel flip and no contiguous copy on this path
+
         return im
def __len__(self):
diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/model.py doclayout_yolo/engine/model.py
@@ -143,6 +143,8 @@
else:
self._load(model, task=task)
+ self.model.half()
+
def __call__(
self,
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/engine/predictor.py doclayout_yolo/engine/predictor.py
@@ -47,6 +47,8 @@
from doclayout_yolo.utils.files import increment_path
from doclayout_yolo.utils.torch_utils import select_device, smart_inference_mode
+import torch.nn.functional as F
+
STREAM_WARNING = """
WARNING ⚠️ inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.doclayout_yolo.com/modes/predict/ for help.
@@ -112,7 +114,7 @@
self._lock = threading.Lock() # for automatic thread-safe inference
callbacks.add_integration_callbacks(self)
- def preprocess(self, im):
+ def _preprocess(self, im): ### origin preprocess
"""
Prepares input image before inference.
@@ -132,6 +134,46 @@
im /= 255 # 0 - 255 to 0.0 - 1.0
return im
+
+    def preprocess(self, images):  ### adapt preprocess
+        """
+        Letterbox each HWC image on self.device and stack the results into one batch tensor.
+
+        Args:
+            images (List(np.ndarray)): [(HWC) x B]; every item must be a numpy array (torch.from_numpy is used).
+        """
+        # Fixed: the int case used to read `new_shape` before it was assigned (UnboundLocalError).
+        new_shape = (self.imgsz, self.imgsz) if isinstance(self.imgsz, int) else self.imgsz
+        tensors = []
+        for im in images:
+            # HWC -> CHW on the target device, scaled by 1/255 (assumes 0-255 pixel values).
+            im = torch.from_numpy(im).to(self.device).permute((2, 0, 1)) / 255.0
+            _, h, w = im.shape
+            # Largest scale that keeps the image inside new_shape (height, width).
+            r = min(new_shape[0] / h, new_shape[1] / w)
+            new_unpad = (int(round(w * r)), int(round(h * r)))  # (width, height) after resize
+
+            if (w, h) != new_unpad:
+                im = F.interpolate(im.unsqueeze(0), size=(new_unpad[1], new_unpad[0]),
+                                   mode="bilinear", align_corners=False).squeeze(0)
+
+            # Split the leftover width/height between both sides; an odd pixel goes right/bottom.
+            dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
+            dw /= 2
+            dh /= 2
+            left, right = int(dw), int(dw + 0.5)
+            top, bottom = int(dh), int(dh + 0.5)
+            im = F.pad(im, (left, right, top, bottom), value=114/255.0)
+
+            _, H, W = im.shape
+            assert (H, W) == (new_shape[0], new_shape[1]), f"Expected image size do not match: padding image size:{(H, W)} != expected image size: {(new_shape[0], new_shape[1])}"
+
+            im = im.half() if self.model.fp16 else im.float()  # match the model's precision
+
+            tensors.append(im)
+        return torch.stack(tensors, dim=0)
+
+
def inference(self, im, *args, **kwargs):
"""Runs inference on a given image using the specified model and arguments."""
visualize = (
@@ -152,7 +194,8 @@
(list): A list of transformed images.
"""
same_shapes = len({x.shape for x in im}) == 1
- letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
+ letterbox = LetterBox(self.imgsz, auto=False, stride=self.model.stride)
+ # letterbox = LetterBox(self.imgsz, auto=same_shapes and self.model.pt, stride=self.model.stride)
return [letterbox(image=x) for x in im]
def postprocess(self, preds, img, orig_imgs):
@@ -225,7 +268,8 @@
# Warmup model
if not self.done_warmup:
- self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
+ # self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
+ self.model.warmup(imgsz=(self.dataset.bs, 3, *self.imgsz))
self.done_warmup = True
self.seen, self.windows, self.batch = 0, [], None
diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/nn/modules/block.py doclayout_yolo/nn/modules/block.py
@@ -230,7 +230,9 @@
def forward(self, x):
"""Forward pass through C2f layer."""
y = list(self.cv1(x).chunk(2, 1))
- y.extend(m(y[-1]) for m in self.m)
+ # Explicit loop in place of the generator passed to extend(): each module in self.m takes the last tensor in y.
+ for m in self.m:
+ y.append(m(y[-1]))
return self.cv2(torch.cat(y, 1))
def forward_split(self, x):
diff -ruN doclayout_yolo-0.0.4/doclayout_yolo/utils/tal.py doclayout_yolo/utils/tal.py
@@ -328,7 +328,8 @@
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
- stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
+ # stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
+ stride_tensor.append(torch.ones((h * w, 1), dtype=dtype, device=device)*stride)
return torch.cat(anchor_points), torch.cat(stride_tensor)