Переглянути джерело

refactor crop_or_pad and translate_image

quarrying 4 роки тому
батько
коміт
e2ce312f76
2 змінені файли: 48 додано та 42 видалено
  1. 28 19
      khandy/image/crop_or_pad.py
  2. 20 23
      khandy/image/translate.py

+ 28 - 19
khandy/image/crop_or_pad.py

@@ -1,8 +1,7 @@
-import cv2
 import numpy as np
 
 
-def crop_or_pad(image, x_min, y_min, x_max, y_max, pad_val=None):
+def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
     """
     References:
         tf.image.resize_image_with_crop_or_pad
@@ -16,28 +15,38 @@ def crop_or_pad(image, x_min, y_min, x_max, y_max, pad_val=None):
     dst_height, dst_width = y_max - y_min + 1, x_max - x_min + 1
     channels = 1 if image.ndim == 2 else image.shape[2]
     
-    if pad_val is not None:
-        if isinstance(pad_val, (int, float)):
-            pad_val = [pad_val for _ in range(channels)]
-        assert len(pad_val) == channels
-        
-    src_x_begin = max(x_min, 0)
-    src_y_begin = max(y_min, 0, )
-    src_x_end = min(x_max + 1, src_width)
-    src_y_end = min(y_max + 1, src_height)
-    dst_x_begin = src_x_begin - x_min
-    dst_y_begin = src_y_begin - y_min
-    dst_x_end = src_x_end - x_min
-    dst_y_end = src_y_end - y_min
-    
     if image.ndim == 2: 
         dst_image_shape = (dst_height, dst_width)
     else:
         dst_image_shape = (dst_height, dst_width, channels)
-    if pad_val is None:
-        dst_image = np.zeros(dst_image_shape, image.dtype)
+
+    if isinstance(border_value, (int, float)):
+        dst_image = np.full(dst_image_shape, border_value, dtype=image.dtype)
+    elif isinstance(border_value, tuple):
+        assert len(border_value) == channels, \
+            'Expected the num of elements in tuple equals the channels' \
+            'of input image. Found {} vs {}'.format(
+                len(border_value), channels)
+        if channels == 1:
+            dst_image = np.full(dst_image_shape, border_value[0], dtype=image.dtype)
+        else:
+            border_value = np.asarray(border_value, dtype=image.dtype)
+            dst_image = np.empty(dst_image_shape, dtype=image.dtype)
+            dst_image[:] = border_value
     else:
-        dst_image = np.full(dst_image_shape, pad_val, dtype=image.dtype)
+        raise ValueError(
+            'Invalid type {} for `border_value`.'.format(type(border_value)))
+        
+    src_x_begin = max(x_min, 0)
+    src_x_end   = min(x_max + 1, src_width)
+    dst_x_begin = src_x_begin - x_min
+    dst_x_end   = src_x_end - x_min
+
+    src_y_begin = max(y_min, 0, )
+    src_y_end   = min(y_max + 1, src_height)
+    dst_y_begin = src_y_begin - y_min
+    dst_y_end   = src_y_end - y_min
+    
     dst_image[dst_y_begin: dst_y_end, dst_x_begin: dst_x_end, ...] = \
         image[src_y_begin: src_y_end, src_x_begin: src_x_end, ...]
     return dst_image

+ 20 - 23
khandy/image/translate.py

@@ -16,47 +16,44 @@ def translate_image(image, x_shift, y_shift, border_value=0):
     Returns:
         ndarray: The translated image.
     """
+    assert image.ndim in [2, 3]
     assert isinstance(x_shift, int)
     assert isinstance(y_shift, int)
     image_height, image_width = image.shape[:2]
-
-    if image.ndim == 2:
-        channels = 1
-    elif image.ndim == 3:
-        channels = image.shape[-1]
-        
-    if isinstance(border_value, int):
-        new_image = np.full_like(image, border_value)
+    channels = 1 if image.ndim == 2 else image.shape[2]
+    
+    if isinstance(border_value, (int, float)):
+        dst_image = np.full_like(image, border_value)
     elif isinstance(border_value, tuple):
         assert len(border_value) == channels, \
             'Expected the num of elements in tuple equals the channels' \
             'of input image. Found {} vs {}'.format(
                 len(border_value), channels)
         if channels == 1:
-            new_image = np.full_like(image, border_value[0])
+            dst_image = np.full_like(image, border_value[0])
         else:
             border_value = np.asarray(border_value, dtype=image.dtype)
-            new_image = np.empty_like(image)
-            new_image[:] = border_value
+            dst_image = np.empty_like(image)
+            dst_image[:] = border_value
     else:
         raise ValueError(
             'Invalid type {} for `border_value`.'.format(type(border_value)))
         
     if (abs(x_shift) >= image_width) or (abs(y_shift) >= image_height):
-        return new_image
+        return dst_image
         
-    src_x_start = max(0, -x_shift)
-    src_x_end   = min(image_width, image_width - x_shift)
-    dst_x_start = max(0, x_shift)
-    dst_x_end   = min(image_width, image_width + x_shift)
+    src_x_begin = max(-x_shift, 0)
+    src_x_end   = min(image_width - x_shift, image_width)
+    dst_x_begin = max(x_shift, 0)
+    dst_x_end   = min(image_width + x_shift, image_width)
     
-    src_y_start = max(0, -y_shift)
-    src_y_end   = min(image_height, image_height - y_shift)
-    dst_y_start = max(0, y_shift)
-    dst_y_end   = min(image_height, image_height + y_shift)
+    src_y_begin = max(-y_shift, 0)
+    src_y_end   = min(image_height - y_shift, image_height)
+    dst_y_begin = max(y_shift, 0)
+    dst_y_end   = min(image_height + y_shift, image_height)
     
-    new_image[dst_y_start:dst_y_end, dst_x_start:dst_x_end] = \
-        image[src_y_start:src_y_end, src_x_start:src_x_end]
-    return new_image
+    dst_image[dst_y_begin:dst_y_end, dst_x_begin:dst_x_end] = \
+        image[src_y_begin:src_y_end, src_x_begin:src_x_end]
+    return dst_image