utils_draw.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import numpy as np
  2. import PIL
  3. from PIL import Image
  4. from PIL import ImageDraw
  5. from PIL import ImageFont
  6. from PIL import ImageColor
  7. def _is_legal_color(color):
  8. if color is None:
  9. return True
  10. if isinstance(color, str):
  11. return True
  12. return isinstance(color, (tuple, list)) and len(color) == 3
  13. def _normalize_color(color, pil_mode, swap_rgb=False):
  14. if color is None:
  15. return color
  16. if isinstance(color, str):
  17. color = ImageColor.getrgb(color)
  18. gray = color[0]
  19. if swap_rgb:
  20. color = (color[2], color[1], color[0])
  21. if pil_mode == 'L':
  22. color = gray
  23. return color
  24. def draw_text(image, text, position, color=(255,0,0), font=None, font_size=15):
  25. """Draws text on given image.
  26. Args:
  27. image (ndarray).
  28. text (str): text to be drawn.
  29. position (Tuple[int, int]): position where to be drawn.
  30. color (List[Union[str, Tuple[int, int, int]]]): text color.
  31. font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
  32. filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
  33. or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
  34. font_size (int): The requested font size in points.
  35. References:
  36. torchvision.utils.draw_bounding_boxes
  37. """
  38. if isinstance(image, np.ndarray):
  39. pil_image = Image.fromarray(image)
  40. elif isinstance(image, PIL.Image.Image):
  41. pil_image = image
  42. else:
  43. raise ValueError('Unsupported image type!')
  44. assert pil_image.mode in ['L', 'RGB', 'RGBA']
  45. assert _is_legal_color(color)
  46. color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
  47. if font is None:
  48. font_object = ImageFont.load_default()
  49. else:
  50. font_object = ImageFont.truetype(font, size=font_size)
  51. draw = ImageDraw.Draw(pil_image)
  52. draw.text((position[0], position[1]), text,
  53. fill=color, font=font_object)
  54. if isinstance(image, np.ndarray):
  55. return np.asarray(pil_image)
  56. return pil_image
  57. def draw_bounding_boxes(image, boxes, labels=None, colors=None,
  58. fill=False, width=1, font=None, font_size=15):
  59. """Draws bounding boxes on given image.
  60. Args:
  61. image (ndarray).
  62. boxes (ndarray): ndarray of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format.
  63. labels (List[str]): List containing the labels of bounding boxes.
  64. colors (List[Union[str, Tuple[int, int, int]]]): List containing the colors of bounding boxes or labels.
  65. fill (bool): If `True` fills the bounding box with specified color.
  66. width (int): Width of bounding box.
  67. font (str): A filename or file-like object containing a TrueType font. If the file is not found in this
  68. filename, the loader may also search in other directories, such as the `fonts/` directory on Windows
  69. or `/Library/Fonts/`, `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS.
  70. font_size (int): The requested font size in points.
  71. References:
  72. torchvision.utils.draw_bounding_boxes
  73. """
  74. if isinstance(image, np.ndarray):
  75. pil_image = Image.fromarray(image)
  76. elif isinstance(image, PIL.Image.Image):
  77. pil_image = image
  78. else:
  79. raise ValueError('Unsupported image type!')
  80. pil_image = pil_image.convert('RGB')
  81. if font is None:
  82. font_object = ImageFont.load_default()
  83. else:
  84. font_object = ImageFont.truetype(font, size=font_size)
  85. if fill:
  86. draw = ImageDraw.Draw(pil_image, "RGBA")
  87. else:
  88. draw = ImageDraw.Draw(pil_image)
  89. for i, bbox in enumerate(boxes):
  90. if colors is None:
  91. color = None
  92. else:
  93. color = colors[i]
  94. assert _is_legal_color(color)
  95. color = _normalize_color(color, pil_image.mode, isinstance(image, np.ndarray))
  96. if fill:
  97. if color is None:
  98. fill_color = (255, 255, 255, 100)
  99. elif isinstance(color, str):
  100. # This will automatically raise Error if rgb cannot be parsed.
  101. fill_color = ImageColor.getrgb(color) + (100,)
  102. elif isinstance(color, tuple):
  103. fill_color = color + (100,)
  104. # the first argument of ImageDraw.rectangle:
  105. # in old version only supports [(x0, y0), (x1, y1)]
  106. # in new version supports either [(x0, y0), (x1, y1)] or [x0, y0, x1, y1]
  107. draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color, fill=fill_color)
  108. else:
  109. draw.rectangle([(bbox[0], bbox[1]), (bbox[2], bbox[3])], width=width, outline=color)
  110. if labels is not None:
  111. margin = width + 1
  112. draw.text((bbox[0] + margin, bbox[1] + margin), labels[i], fill=color, font=font_object)
  113. if isinstance(image, np.ndarray):
  114. return np.asarray(pil_image)
  115. return pil_image