misc.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import os
  2. import imghdr
  3. from io import BytesIO
  4. import cv2
  5. import khandy
  6. import numpy as np
  7. from PIL import Image
  8. def imread_pil(file_or_buffer, to_mode=None):
  9. """Improvement on Image.open to avoid ResourceWarning.
  10. """
  11. try:
  12. if isinstance(file_or_buffer, bytes):
  13. buffer = BytesIO()
  14. buffer.write(file_or_buffer)
  15. buffer.seek(0)
  16. file_or_buffer = buffer
  17. if hasattr(file_or_buffer, 'read'):
  18. image = Image.open(file_or_buffer)
  19. if to_mode is not None:
  20. image = image.convert(to_mode)
  21. else:
  22. # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  23. with open(file_or_buffer, 'rb') as f:
  24. image = Image.open(f)
  25. # If convert outside with statement, will raise "seek of closed file" as
  26. # https://github.com/microsoft/Swin-Transformer/issues/66
  27. if to_mode is not None:
  28. image = image.convert(to_mode)
  29. return image
  30. except Exception as e:
  31. print(e)
  32. return None
  33. def imread_cv(file_or_buffer, flags=-1):
  34. """Improvement on cv2.imread, make it support filename including chinese character.
  35. """
  36. try:
  37. if isinstance(file_or_buffer, bytes):
  38. return cv2.imdecode(np.frombuffer(file_or_buffer, dtype=np.uint8), flags)
  39. else:
  40. # support type: file or str or Path
  41. return cv2.imdecode(np.fromfile(file_or_buffer, dtype=np.uint8), flags)
  42. except Exception as e:
  43. print(e)
  44. return None
  45. def imwrite_cv(filename, image, params=None):
  46. """Improvement on cv2.imwrite, make it support filename including chinese character.
  47. """
  48. try:
  49. cv2.imencode(os.path.splitext(filename)[-1], image, params)[1].tofile(filename)
  50. return True
  51. except:
  52. return False
  53. def imwrite_bytes(filename, image_bytes, update_extension=True):
  54. extension = imghdr.what('', image_bytes)
  55. if extension is None:
  56. raise ValueError('image_bytes is not image')
  57. extension = '.' + extension
  58. file_extension = khandy.get_path_extension(filename)
  59. if extension.lower() != file_extension.lower():
  60. if update_extension:
  61. filename = khandy.replace_path_extension(filename, extension)
  62. else:
  63. image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
  64. image_bytes = cv2.imencode(file_extension, image)[1]
  65. with open(filename, "wb") as f:
  66. f.write(image_bytes)
  67. return filename
  68. def normalize_image_dtype(image, keep_num_channels=False):
  69. """Normalize image dtype to uint8 (usually for visualization).
  70. Args:
  71. image : ndarray
  72. Input image.
  73. keep_num_channels : bool, optional
  74. If this is set to True, the result is an array which has
  75. the same shape as input image, otherwise the result is
  76. an array whose channels number is 3.
  77. Returns:
  78. out: ndarray
  79. Image whose dtype is np.uint8.
  80. """
  81. assert (image.ndim == 3 and image.shape[-1] in [1, 3]) or (image.ndim == 2)
  82. image = image.astype(np.float32)
  83. image = khandy.minmax_normalize(image, axis=None, copy=False)
  84. image = np.array(image * 255, dtype=np.uint8)
  85. if not keep_num_channels:
  86. if image.ndim == 2:
  87. image = np.expand_dims(image, -1)
  88. if image.shape[-1] == 1:
  89. image = np.tile(image, (1,1,3))
  90. return image
  91. def normalize_image_shape(image, swap_rb=False):
  92. """Normalize image shape to (h, w, 3).
  93. Args:
  94. image : ndarray
  95. Input image.
  96. swap_rb : bool, optional
  97. whether swap red and blue channel or not
  98. Returns:
  99. out: ndarray
  100. Image whose shape is (h, w, 3).
  101. """
  102. if image.ndim == 2:
  103. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  104. elif image.ndim == 3:
  105. num_channels = image.shape[-1]
  106. if num_channels == 1:
  107. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  108. elif num_channels == 3:
  109. if swap_rb:
  110. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  111. elif num_channels == 4:
  112. if swap_rb:
  113. image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
  114. else:
  115. image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
  116. else:
  117. raise ValueError('Unsupported!')
  118. else:
  119. raise ValueError('Unsupported!')
  120. return image
  121. def stack_image_list(image_list, dtype=np.float32):
  122. """Join a sequence of image along a new axis before first axis.
  123. References:
  124. `im_list_to_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
  125. """
  126. assert isinstance(image_list, (tuple, list))
  127. max_dimension = np.array([image.ndim for image in image_list]).max()
  128. assert max_dimension in [2, 3]
  129. max_shape = np.array([image.shape[:2] for image in image_list]).max(axis=0)
  130. num_channels = []
  131. for image in image_list:
  132. if image.ndim == 2:
  133. num_channels.append(1)
  134. else:
  135. num_channels.append(image.shape[-1])
  136. assert len(set(num_channels) - set([1])) in [0, 1]
  137. max_num_channels = np.max(num_channels)
  138. blob = np.empty((len(image_list), max_shape[0], max_shape[1], max_num_channels), dtype=dtype)
  139. for k, image in enumerate(image_list):
  140. blob[k, :image.shape[0], :image.shape[1], :] = np.atleast_3d(image).astype(dtype, copy=False)
  141. if max_dimension == 2:
  142. blob = np.squeeze(blob, axis=-1)
  143. return blob
  144. def is_numpy_image(image):
  145. return isinstance(image, np.ndarray) and image.ndim in {2, 3}
  146. def is_gray_image(image, tol=3):
  147. assert is_numpy_image(image)
  148. if image.ndim == 2:
  149. return True
  150. elif image.ndim == 3:
  151. num_channels = image.shape[-1]
  152. if num_channels == 1:
  153. return True
  154. elif num_channels == 4:
  155. rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
  156. gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
  157. gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  158. mae = np.mean(cv2.absdiff(rgb, gray3))
  159. return mae <= tol
  160. elif num_channels == 3:
  161. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  162. gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  163. mae = np.mean(cv2.absdiff(image, gray3))
  164. return mae <= tol
  165. else:
  166. return False
  167. else:
  168. return False