boxes_filter.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import numpy as np
  2. def filter_small_boxes(boxes, min_width, min_height):
  3. """Filters all boxes with side smaller than min size.
  4. Args:
  5. boxes: a numpy array with shape [N, 4] holding N boxes.
  6. min_width (float): minimum width
  7. min_height (float): minimum height
  8. Returns:
  9. keep: indices of the boxes that have width larger than
  10. min_width and height larger than min_height.
  11. References:
  12. `_filter_boxes` in py-faster-rcnn
  13. `prune_small_boxes` in TensorFlow object detection API.
  14. `structures.Boxes.nonempty` in detectron2
  15. `ops.boxes.remove_small_boxes` in torchvision
  16. """
  17. widths = boxes[:, 2] - boxes[:, 0]
  18. heights = boxes[:, 3] - boxes[:, 1]
  19. keep = (widths >= min_width)
  20. keep &= (heights >= min_height)
  21. return np.nonzero(keep)[0]
  22. def filter_boxes_outside(boxes, reference_box):
  23. """Filters bounding boxes that fall outside reference box.
  24. References:
  25. `prune_outside_window` in TensorFlow object detection API.
  26. """
  27. x_min, y_min, x_max, y_max = reference_box[:4]
  28. keep = ((boxes[:, 0] >= x_min) & (boxes[:, 1] >= y_min) &
  29. (boxes[:, 2] <= x_max) & (boxes[:, 3] <= y_max))
  30. return np.nonzero(keep)[0]
  31. def filter_boxes_completely_outside(boxes, reference_box):
  32. """Filters bounding boxes that fall completely outside of reference box.
  33. References:
  34. `prune_completely_outside_window` in TensorFlow object detection API.
  35. """
  36. x_min, y_min, x_max, y_max = reference_box[:4]
  37. keep = ((boxes[:, 0] < x_max) & (boxes[:, 1] < y_max) &
  38. (boxes[:, 2] > x_min) & (boxes[:, 3] > y_min))
  39. return np.nonzero(keep)[0]