batched_boxes.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import numpy as np
  2. def _concat(arr_list, axis=0):
  3. """Avoids a copy if there is only a single element in a list.
  4. """
  5. if len(arr_list) == 1:
  6. return arr_list[0]
  7. return np.concatenate(arr_list, axis)
  8. def convert_boxes_list_to_batched_boxes(boxes_list):
  9. """
  10. Args:
  11. boxes_list: list or tuple of ndarray with shape (N_i, 4+K)
  12. Returns:
  13. ndarray with shape (M, 5+K) where M is sum of N_i.
  14. References:
  15. `mmdet.core.bbox.bbox2roi` in mmdetection
  16. `convert_boxes_to_roi_format` in TorchVision
  17. `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
  18. """
  19. assert isinstance(boxes_list, (list, tuple))
  20. concat_boxes = _concat(boxes_list, axis=0)
  21. indices_list = [np.full((len(b), 1), i, concat_boxes.dtype)
  22. for i, b in enumerate(boxes_list)]
  23. indices = _concat(indices_list, axis=0)
  24. batched_boxes = np.hstack([indices, concat_boxes])
  25. return batched_boxes
  26. def convert_batched_boxes_to_boxes_list(batched_boxes):
  27. """
  28. References:
  29. `mmdet.core.bbox.roi2bbox` in mmdetection
  30. `convert_boxes_to_roi_format` in TorchVision
  31. `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
  32. """
  33. assert isinstance(batched_boxes, np.ndarray)
  34. assert batched_boxes.ndim == 2 and batched_boxes.shape[-1] >= 5
  35. boxes_list = []
  36. indices = np.unique(batched_boxes[:, 0])
  37. for index in indices:
  38. inds = (batched_boxes[:, 0] == index)
  39. boxes = batched_boxes[inds, 1:]
  40. boxes_list.append(boxes)
  41. return boxes_list