boxes_filter.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import numpy as np
  2. def filter_small_boxes(boxes, min_width, min_height):
  3. """Filters all boxes with side smaller than min size.
  4. Args:
  5. boxes: a numpy array with shape [N, 4] holding N boxes.
  6. min_width (float): minimum width
  7. min_height (float): minimum height
  8. Returns:
  9. keep: indices of the boxes that have width larger than
  10. min_width and height larger than min_height.
  11. References:
  12. `_filter_boxes` in py-faster-rcnn
  13. `prune_small_boxes` in TensorFlow object detection API.
  14. `structures.Boxes.nonempty` in detectron2
  15. `ops.boxes.remove_small_boxes` in torchvision
  16. """
  17. widths = boxes[:, 2] - boxes[:, 0]
  18. heights = boxes[:, 3] - boxes[:, 1]
  19. keep = (widths >= min_width)
  20. keep &= (heights >= min_height)
  21. return np.nonzero(keep)[0]
  22. def filter_boxes_outside(boxes, reference_box):
  23. """Filters bounding boxes that fall outside reference box.
  24. References:
  25. `prune_outside_window` in TensorFlow object detection API.
  26. """
  27. x_min, y_min, x_max, y_max = reference_box[:4]
  28. keep = ((boxes[:, 0] >= x_min) & (boxes[:, 1] >= y_min) &
  29. (boxes[:, 2] <= x_max) & (boxes[:, 3] <= y_max))
  30. return np.nonzero(keep)[0]
  31. def filter_boxes_completely_outside(boxes, reference_box):
  32. """Filters bounding boxes that fall completely outside of reference box.
  33. References:
  34. `prune_completely_outside_window` in TensorFlow object detection API.
  35. """
  36. x_min, y_min, x_max, y_max = reference_box[:4]
  37. keep = ((boxes[:, 0] < x_max) & (boxes[:, 1] < y_max) &
  38. (boxes[:, 2] > x_min) & (boxes[:, 3] > y_min))
  39. return np.nonzero(keep)[0]
  40. def non_max_suppression(boxes, scores, thresh, ratio_type="iou"):
  41. """Greedily select boxes with high confidence
  42. Args:
  43. boxes: [[x_min, y_min, x_max, y_max], ...]
  44. scores: object confidence
  45. thresh: retain overlap_ratio <= thresh
  46. Returns:
  47. indexes to keep
  48. References:
  49. `py_cpu_nms` in py-faster-rcnn
  50. """
  51. x_mins = boxes[:, 0]
  52. y_mins = boxes[:, 1]
  53. x_maxs = boxes[:, 2]
  54. y_maxs = boxes[:, 3]
  55. areas = (x_maxs - x_mins) * (y_maxs - y_mins)
  56. order = scores.argsort()[::-1]
  57. keep = []
  58. while order.size > 0:
  59. i = order[0]
  60. keep.append(i)
  61. max_x_mins = np.maximum(x_mins[i], x_mins[order[1:]])
  62. max_y_mins = np.maximum(y_mins[i], y_mins[order[1:]])
  63. min_x_maxs = np.minimum(x_maxs[i], x_maxs[order[1:]])
  64. min_y_maxs = np.minimum(y_maxs[i], y_maxs[order[1:]])
  65. widths = np.maximum(0, min_x_maxs - max_x_mins)
  66. heights = np.maximum(0, min_y_maxs - max_y_mins)
  67. intersect_area = widths * heights
  68. if ratio_type in ["union", 'iou']:
  69. ratio = intersect_area / (areas[i] + areas[order[1:]] - intersect_area)
  70. elif ratio_type == "min":
  71. ratio = intersect_area / np.minimum(areas[i], areas[order[1:]])
  72. else:
  73. raise ValueError('Unsupported ratio_type. Got {}'.format(ratio_type))
  74. inds = np.nonzero(ratio <= thresh)[0]
  75. order = order[inds + 1]
  76. return keep