misc.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import os
  2. import imghdr
  3. import warnings
  4. from io import BytesIO
  5. import cv2
  6. import khandy
  7. import numpy as np
  8. from PIL import Image
  9. def imread(file_or_buffer, flags=-1):
  10. """Improvement on cv2.imread, make it support filename including chinese character.
  11. """
  12. try:
  13. if isinstance(file_or_buffer, bytes):
  14. return cv2.imdecode(np.frombuffer(file_or_buffer, dtype=np.uint8), flags)
  15. else:
  16. # support type: file or str or Path
  17. return cv2.imdecode(np.fromfile(file_or_buffer, dtype=np.uint8), flags)
  18. except Exception as e:
  19. print(e)
  20. return None
  21. def imread_cv(file_or_buffer, flags=-1):
  22. warnings.warn('khandy.imread_cv will be deprecated, use khandy.imread instead!')
  23. return imread(file_or_buffer, flags)
  24. def imwrite(filename, image, params=None):
  25. """Improvement on cv2.imwrite, make it support filename including chinese character.
  26. """
  27. cv2.imencode(os.path.splitext(filename)[-1], image, params)[1].tofile(filename)
  28. def imwrite_cv(filename, image, params=None):
  29. warnings.warn('khandy.imwrite_cv will be deprecated, use khandy.imwrite instead!')
  30. return imwrite(filename, image, params)
  31. def imread_pil(file_or_buffer, to_mode=None):
  32. """Improvement on Image.open to avoid ResourceWarning.
  33. """
  34. try:
  35. if isinstance(file_or_buffer, bytes):
  36. buffer = BytesIO()
  37. buffer.write(file_or_buffer)
  38. buffer.seek(0)
  39. file_or_buffer = buffer
  40. if hasattr(file_or_buffer, 'read'):
  41. image = Image.open(file_or_buffer)
  42. if to_mode is not None:
  43. image = image.convert(to_mode)
  44. else:
  45. # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
  46. with open(file_or_buffer, 'rb') as f:
  47. image = Image.open(f)
  48. # If convert outside with statement, will raise "seek of closed file" as
  49. # https://github.com/microsoft/Swin-Transformer/issues/66
  50. if to_mode is not None:
  51. image = image.convert(to_mode)
  52. return image
  53. except Exception as e:
  54. print(e)
  55. return None
  56. def imwrite_bytes(filename, image_bytes: bytes, update_extension: bool = True):
  57. """Write image bytes to file.
  58. Args:
  59. filename: str
  60. filename which image_bytes is written into.
  61. image_bytes: bytes
  62. image content to be written.
  63. update_extension: bool
  64. whether update extension according to image_bytes or not.
  65. the cost of update extension is smaller than update image format.
  66. """
  67. extension = imghdr.what('', image_bytes)
  68. file_extension = khandy.get_path_extension(filename)
  69. # imghdr.what fails to determine image format sometimes!
  70. # so when its return value is None, never update extension.
  71. if extension is None:
  72. image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
  73. image_bytes = cv2.imencode(file_extension, image)[1]
  74. elif (extension.lower() != file_extension.lower()[1:]):
  75. if update_extension:
  76. filename = khandy.replace_path_extension(filename, extension)
  77. else:
  78. image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
  79. image_bytes = cv2.imencode(file_extension, image)[1]
  80. with open(filename, "wb") as f:
  81. f.write(image_bytes)
  82. return filename
  83. def normalize_image_dtype(image, keep_num_channels=False):
  84. """Normalize image dtype to uint8 (usually for visualization).
  85. Args:
  86. image : ndarray
  87. Input image.
  88. keep_num_channels : bool, optional
  89. If this is set to True, the result is an array which has
  90. the same shape as input image, otherwise the result is
  91. an array whose channels number is 3.
  92. Returns:
  93. out: ndarray
  94. Image whose dtype is np.uint8.
  95. """
  96. assert (image.ndim == 3 and image.shape[-1] in [1, 3]) or (image.ndim == 2)
  97. image = image.astype(np.float32)
  98. image = khandy.minmax_normalize(image, axis=None, copy=False)
  99. image = np.array(image * 255, dtype=np.uint8)
  100. if not keep_num_channels:
  101. if image.ndim == 2:
  102. image = np.expand_dims(image, -1)
  103. if image.shape[-1] == 1:
  104. image = np.tile(image, (1,1,3))
  105. return image
  106. def normalize_image_shape(image, swap_rb=False):
  107. """Normalize image shape to (h, w, 3).
  108. Args:
  109. image : ndarray
  110. Input image.
  111. swap_rb : bool, optional
  112. whether swap red and blue channel or not
  113. Returns:
  114. out: ndarray
  115. Image whose shape is (h, w, 3).
  116. """
  117. if image.ndim == 2:
  118. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  119. elif image.ndim == 3:
  120. num_channels = image.shape[-1]
  121. if num_channels == 1:
  122. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  123. elif num_channels == 3:
  124. if swap_rb:
  125. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  126. elif num_channels == 4:
  127. if swap_rb:
  128. image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
  129. else:
  130. image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
  131. else:
  132. raise ValueError('Unsupported!')
  133. else:
  134. raise ValueError('Unsupported!')
  135. return image
  136. def stack_image_list(image_list, dtype=np.float32):
  137. """Join a sequence of image along a new axis before first axis.
  138. References:
  139. `im_list_to_blob` in `py-faster-rcnn-master/lib/utils/blob.py`
  140. """
  141. assert isinstance(image_list, (tuple, list))
  142. max_dimension = np.array([image.ndim for image in image_list]).max()
  143. assert max_dimension in [2, 3]
  144. max_shape = np.array([image.shape[:2] for image in image_list]).max(axis=0)
  145. num_channels = []
  146. for image in image_list:
  147. if image.ndim == 2:
  148. num_channels.append(1)
  149. else:
  150. num_channels.append(image.shape[-1])
  151. assert len(set(num_channels) - set([1])) in [0, 1]
  152. max_num_channels = np.max(num_channels)
  153. blob = np.empty((len(image_list), max_shape[0], max_shape[1], max_num_channels), dtype=dtype)
  154. for k, image in enumerate(image_list):
  155. blob[k, :image.shape[0], :image.shape[1], :] = np.atleast_3d(image).astype(dtype, copy=False)
  156. if max_dimension == 2:
  157. blob = np.squeeze(blob, axis=-1)
  158. return blob
  159. def is_numpy_image(image):
  160. return isinstance(image, np.ndarray) and image.ndim in {2, 3}
  161. def is_gray_image(image, tol=3):
  162. assert is_numpy_image(image)
  163. if image.ndim == 2:
  164. return True
  165. elif image.ndim == 3:
  166. num_channels = image.shape[-1]
  167. if num_channels == 1:
  168. return True
  169. elif num_channels == 4:
  170. rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
  171. gray = cv2.cvtColor(rgb, cv2.COLOR_BGR2GRAY)
  172. gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  173. mae = np.mean(cv2.absdiff(rgb, gray3))
  174. return mae <= tol
  175. elif num_channels == 3:
  176. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  177. gray3 = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  178. mae = np.mean(cv2.absdiff(image, gray3))
  179. return mae <= tol
  180. else:
  181. return False
  182. else:
  183. return False