|
@@ -62,6 +62,40 @@ def normalize_image_dtype(image, keep_num_channels=False):
|
|
|
return image
|
|
|
|
|
|
|
|
|
+def normalize_image_shape(image, swap_rb=False):
|
|
|
+ """Normalize image shape to (h, w, 3).
|
|
|
+
|
|
|
+ Args:
|
|
|
+ image : ndarray
|
|
|
+ Input image.
|
|
|
+ swap_rb : bool, optional
|
|
|
+ whether swap red and blue channel or not
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ out: ndarray
|
|
|
+ Image whose shape is (h, w, 3).
|
|
|
+ """
|
|
|
+ if image.ndim == 2:
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
|
|
+ elif image.ndim == 3:
|
|
|
+ num_channels = image.shape[-1]
|
|
|
+ if num_channels == 1:
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
|
|
+ elif num_channels == 3:
|
|
|
+ if swap_rb:
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
|
+ elif num_channels == 4:
|
|
|
+ if swap_rb:
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
|
|
|
+ else:
|
|
|
+ image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
|
|
+ else:
|
|
|
+ raise ValueError('Unsupported!')
|
|
|
+ else:
|
|
|
+ raise ValueError('Unsupported!')
|
|
|
+ return image
|
|
|
+
|
|
|
+
|
|
|
def stack_image_list(image_list, dtype=np.float32):
|
|
|
"""Join a sequence of image along a new axis before first axis.
|
|
|
|