boxes_filter.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import numpy as np
  2. def filter_small_boxes(boxes, min_width, min_height):
  3. """Filters all boxes with side smaller than min size.
  4. Args:
  5. boxes: a numpy array with shape [N, 4] holding N boxes.
  6. min_width (float): minimum width
  7. min_height (float): minimum height
  8. Returns:
  9. keep: indices of the boxes that have width larger than
  10. min_width and height larger than min_height.
  11. References:
  12. `_filter_boxes` in py-faster-rcnn
  13. `prune_small_boxes` in TensorFlow object detection API.
  14. `structures.Boxes.nonempty` in detectron2
  15. `ops.boxes.remove_small_boxes` in torchvision
  16. """
  17. widths = boxes[:, 2] - boxes[:, 0]
  18. heights = boxes[:, 3] - boxes[:, 1]
  19. keep = (widths >= min_width)
  20. keep &= (heights >= min_height)
  21. return np.nonzero(keep)[0]
  22. def filter_boxes_outside(boxes, reference_box):
  23. """Filters bounding boxes that fall outside reference box.
  24. References:
  25. `prune_outside_window` in TensorFlow object detection API.
  26. """
  27. x_min, y_min, x_max, y_max = reference_box[:4]
  28. keep = ((boxes[:, 0] >= x_min) & (boxes[:, 1] >= y_min) &
  29. (boxes[:, 2] <= x_max) & (boxes[:, 3] <= y_max))
  30. return np.nonzero(keep)[0]
  31. def filter_boxes_completely_outside(boxes, reference_box):
  32. """Filters bounding boxes that fall completely outside of reference box.
  33. References:
  34. `prune_completely_outside_window` in TensorFlow object detection API.
  35. """
  36. x_min, y_min, x_max, y_max = reference_box[:4]
  37. keep = ((boxes[:, 0] < x_max) & (boxes[:, 1] < y_max) &
  38. (boxes[:, 2] > x_min) & (boxes[:, 3] > y_min))
  39. return np.nonzero(keep)[0]
  40. def non_max_suppression(boxes, scores, thresh, classes=None, ratio_type="iou"):
  41. """Greedily select boxes with high confidence
  42. Args:
  43. boxes: [[x_min, y_min, x_max, y_max], ...]
  44. scores: object confidence
  45. thresh: retain overlap_ratio <= thresh
  46. classes: class labels
  47. Returns:
  48. indexes to keep
  49. References:
  50. `py_cpu_nms` in py-faster-rcnn
  51. torchvision.ops.nms
  52. torchvision.ops.batched_nms
  53. """
  54. if boxes.size == 0:
  55. return np.empty((0,), dtype=np.int64)
  56. if classes is not None:
  57. # strategy: in order to perform NMS independently per class,
  58. # we add an offset to all the boxes. The offset is dependent
  59. # only on the class idx, and is large enough so that boxes
  60. # from different classes do not overlap
  61. max_coordinate = np.max(boxes)
  62. offsets = classes * (max_coordinate + 1)
  63. boxes = boxes + offsets[:, None]
  64. x_mins = boxes[:, 0]
  65. y_mins = boxes[:, 1]
  66. x_maxs = boxes[:, 2]
  67. y_maxs = boxes[:, 3]
  68. areas = (x_maxs - x_mins) * (y_maxs - y_mins)
  69. order = scores.flatten().argsort()[::-1]
  70. keep = []
  71. while order.size > 0:
  72. i = order[0]
  73. keep.append(i)
  74. max_x_mins = np.maximum(x_mins[i], x_mins[order[1:]])
  75. max_y_mins = np.maximum(y_mins[i], y_mins[order[1:]])
  76. min_x_maxs = np.minimum(x_maxs[i], x_maxs[order[1:]])
  77. min_y_maxs = np.minimum(y_maxs[i], y_maxs[order[1:]])
  78. widths = np.maximum(0, min_x_maxs - max_x_mins)
  79. heights = np.maximum(0, min_y_maxs - max_y_mins)
  80. intersect_area = widths * heights
  81. if ratio_type in ["union", 'iou']:
  82. ratio = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
  83. elif ratio_type == "min":
  84. ratio = intersect_area / np.minimum(areas[i], areas[order[1:]])
  85. else:
  86. raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
  87. inds = np.nonzero(ratio <= thresh)[0]
  88. order = order[inds + 1]
  89. return np.asarray(keep)