crop_or_pad.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import numbers
  2. import numpy as np
  3. def crop_or_pad(image, x_min, y_min, x_max, y_max, border_value=0):
  4. """
  5. See Also:
  6. translate_image
  7. References:
  8. tf.image.resize_image_with_crop_or_pad
  9. """
  10. assert image.ndim in [2, 3]
  11. assert isinstance(x_min, numbers.Integral) and isinstance(y_min, numbers.Integral)
  12. assert isinstance(x_max, numbers.Integral) and isinstance(y_max, numbers.Integral)
  13. assert (x_min <= x_max) and (y_min <= y_max)
  14. src_height, src_width = image.shape[:2]
  15. dst_height, dst_width = y_max - y_min + 1, x_max - x_min + 1
  16. channels = 1 if image.ndim == 2 else image.shape[2]
  17. if image.ndim == 2:
  18. dst_image_shape = (dst_height, dst_width)
  19. else:
  20. dst_image_shape = (dst_height, dst_width, channels)
  21. if isinstance(border_value, numbers.Real):
  22. dst_image = np.full(dst_image_shape, border_value, dtype=image.dtype)
  23. elif isinstance(border_value, tuple):
  24. assert len(border_value) == channels, \
  25. 'Expected the num of elements in tuple equals the channels' \
  26. 'of input image. Found {} vs {}'.format(
  27. len(border_value), channels)
  28. if channels == 1:
  29. dst_image = np.full(dst_image_shape, border_value[0], dtype=image.dtype)
  30. else:
  31. border_value = np.asarray(border_value, dtype=image.dtype)
  32. dst_image = np.empty(dst_image_shape, dtype=image.dtype)
  33. dst_image[:] = border_value
  34. else:
  35. raise ValueError(
  36. 'Invalid type {} for `border_value`.'.format(type(border_value)))
  37. src_x_begin = max(x_min, 0)
  38. src_x_end = min(x_max + 1, src_width)
  39. dst_x_begin = src_x_begin - x_min
  40. dst_x_end = src_x_end - x_min
  41. src_y_begin = max(y_min, 0)
  42. src_y_end = min(y_max + 1, src_height)
  43. dst_y_begin = src_y_begin - y_min
  44. dst_y_end = src_y_end - y_min
  45. if (src_x_begin >= src_x_end) or (src_y_begin >= src_y_end):
  46. return dst_image
  47. dst_image[dst_y_begin: dst_y_end, dst_x_begin: dst_x_end, ...] = \
  48. image[src_y_begin: src_y_end, src_x_begin: src_x_end, ...]
  49. return dst_image
  50. def crop_or_pad_coords(boxes, image_width, image_height):
  51. """
  52. References:
  53. `mmcv.impad`
  54. `pad` in https://github.com/kpzhang93/MTCNN_face_detection_alignment
  55. `MtcnnDetector.pad` in https://github.com/AITTSMD/MTCNN-Tensorflow
  56. """
  57. x_mins = boxes[:, 0]
  58. y_mins = boxes[:, 1]
  59. x_maxs = boxes[:, 2]
  60. y_maxs = boxes[:, 3]
  61. dst_widths = x_maxs - x_mins + 1
  62. dst_heights = y_maxs - y_mins + 1
  63. src_x_begin = np.maximum(x_mins, 0)
  64. src_x_end = np.minimum(x_maxs + 1, image_width)
  65. dst_x_begin = src_x_begin - x_mins
  66. dst_x_end = src_x_end - x_mins
  67. src_y_begin = np.maximum(y_mins, 0)
  68. src_y_end = np.minimum(y_maxs + 1, image_height)
  69. dst_y_begin = src_y_begin - y_mins
  70. dst_y_end = src_y_end - y_mins
  71. coords = np.stack([dst_y_begin, dst_y_end, dst_x_begin, dst_x_end,
  72. src_y_begin, src_y_end, src_x_begin, src_x_end,
  73. dst_heights, dst_widths], axis=0)
  74. return coords
  75. def center_crop(image, dst_width, dst_height, strict=True):
  76. """
  77. strict:
  78. when True, raise error if src size is less than dst size.
  79. when False, remain unchanged if src size is less than dst size, otherwise center crop.
  80. """
  81. assert image.ndim in [2, 3]
  82. assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
  83. src_height, src_width = image.shape[:2]
  84. if strict:
  85. assert (src_height >= dst_height) and (src_width >= dst_width)
  86. crop_top = max((src_height - dst_height) // 2, 0)
  87. crop_left = max((src_width - dst_width) // 2, 0)
  88. cropped = image[crop_top: dst_height + crop_top,
  89. crop_left: dst_width + crop_left, ...]
  90. return cropped
  91. def center_pad(image, dst_width, dst_height, strict=True):
  92. """
  93. strict:
  94. when True, raise error if src size is greater than dst size.
  95. when False, remain unchanged if src size is greater than dst size, otherwise center pad.
  96. """
  97. assert image.ndim in [2, 3]
  98. assert isinstance(dst_width, numbers.Integral) and isinstance(dst_height, numbers.Integral)
  99. src_height, src_width = image.shape[:2]
  100. if strict:
  101. assert (src_height <= dst_height) and (src_width <= dst_width)
  102. padding_x = max(dst_width - src_width, 0)
  103. padding_y = max(dst_height - src_height, 0)
  104. padding_top = padding_y // 2
  105. padding_left = padding_x // 2
  106. if image.ndim == 2:
  107. padding = ((padding_top, padding_y - padding_top),
  108. (padding_left, padding_x - padding_left))
  109. else:
  110. padding = ((padding_top, padding_y - padding_top),
  111. (padding_left, padding_x - padding_left), (0, 0))
  112. return np.pad(image, padding, 'constant')