|
@@ -1,8 +1,7 @@
|
|
|
-import cv2
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
-def crop_or_pad(image, x_min, y_min, x_max, y_max, pad_val=None):
|
|
|
+def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
|
|
|
"""
|
|
|
References:
|
|
|
tf.image.resize_image_with_crop_or_pad
|
|
@@ -16,28 +15,38 @@ def crop_or_pad(image, x_min, y_min, x_max, y_max, pad_val=None):
|
|
|
dst_height, dst_width = y_max - y_min + 1, x_max - x_min + 1
|
|
|
channels = 1 if image.ndim == 2 else image.shape[2]
|
|
|
|
|
|
- if pad_val is not None:
|
|
|
- if isinstance(pad_val, (int, float)):
|
|
|
- pad_val = [pad_val for _ in range(channels)]
|
|
|
- assert len(pad_val) == channels
|
|
|
-
|
|
|
- src_x_begin = max(x_min, 0)
|
|
|
- src_y_begin = max(y_min, 0, )
|
|
|
- src_x_end = min(x_max + 1, src_width)
|
|
|
- src_y_end = min(y_max + 1, src_height)
|
|
|
- dst_x_begin = src_x_begin - x_min
|
|
|
- dst_y_begin = src_y_begin - y_min
|
|
|
- dst_x_end = src_x_end - x_min
|
|
|
- dst_y_end = src_y_end - y_min
|
|
|
-
|
|
|
if image.ndim == 2:
|
|
|
dst_image_shape = (dst_height, dst_width)
|
|
|
else:
|
|
|
dst_image_shape = (dst_height, dst_width, channels)
|
|
|
- if pad_val is None:
|
|
|
- dst_image = np.zeros(dst_image_shape, image.dtype)
|
|
|
+
|
|
|
+ if isinstance(border_value, (int, float)):
|
|
|
+ dst_image = np.full(dst_image_shape, border_value, dtype=image.dtype)
|
|
|
+ elif isinstance(border_value, tuple):
|
|
|
+ assert len(border_value) == channels, \
|
|
|
+ 'Expected the num of elements in tuple equals the channels' \
|
|
|
+ 'of input image. Found {} vs {}'.format(
|
|
|
+ len(border_value), channels)
|
|
|
+ if channels == 1:
|
|
|
+ dst_image = np.full(dst_image_shape, border_value[0], dtype=image.dtype)
|
|
|
+ else:
|
|
|
+ border_value = np.asarray(border_value, dtype=image.dtype)
|
|
|
+ dst_image = np.empty(dst_image_shape, dtype=image.dtype)
|
|
|
+ dst_image[:] = border_value
|
|
|
else:
|
|
|
- dst_image = np.full(dst_image_shape, pad_val, dtype=image.dtype)
|
|
|
+ raise ValueError(
|
|
|
+ 'Invalid type {} for `border_value`.'.format(type(border_value)))
|
|
|
+
|
|
|
+ src_x_begin = max(x_min, 0)
|
|
|
+ src_x_end = min(x_max + 1, src_width)
|
|
|
+ dst_x_begin = src_x_begin - x_min
|
|
|
+ dst_x_end = src_x_end - x_min
|
|
|
+
|
|
|
+ src_y_begin = max(y_min, 0, )
|
|
|
+ src_y_end = min(y_max + 1, src_height)
|
|
|
+ dst_y_begin = src_y_begin - y_min
|
|
|
+ dst_y_end = src_y_end - y_min
|
|
|
+
|
|
|
dst_image[dst_y_begin: dst_y_end, dst_x_begin: dst_x_end, ...] = \
|
|
|
image[src_y_begin: src_y_end, src_x_begin: src_x_end, ...]
|
|
|
return dst_image
|