boxes_and_indices.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import numpy as np
  2. def _concat(arr_list, axis=0):
  3. """Avoids a copy if there is only a single element in a list.
  4. """
  5. if len(arr_list) == 1:
  6. return arr_list[0]
  7. return np.concatenate(arr_list, axis)
  8. def convert_boxes_list_to_boxes_and_indices(boxes_list):
  9. """
  10. Args:
  11. boxes_list (np.ndarray): list or tuple of ndarray with shape (N_i, 4+K)
  12. Returns:
  13. boxes (ndarray): shape (M, 4+K) where M is sum of N_i.
  14. indices (ndarray): shape (M, 1) where M is sum of N_i.
  15. References:
  16. `mmdet.core.bbox.bbox2roi` in mmdetection
  17. `convert_boxes_to_roi_format` in TorchVision
  18. `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
  19. """
  20. assert isinstance(boxes_list, (list, tuple))
  21. boxes = _concat(boxes_list, axis=0)
  22. indices_list = [np.full((len(b), 1), i, boxes.dtype)
  23. for i, b in enumerate(boxes_list)]
  24. indices = _concat(indices_list, axis=0)
  25. return boxes, indices
  26. def convert_boxes_and_indices_to_boxes_list(boxes, indices, num_indices):
  27. """
  28. Args:
  29. boxes (np.ndarray): shape (N, 4+K)
  30. indices (np.ndarray): shape (N,) or (N, 1), maybe batch index
  31. in mini-batch or class label index.
  32. num_indices (int): number of index.
  33. Returns:
  34. list (ndarray): boxes list of each index
  35. References:
  36. `mmdet.core.bbox2result` in mmdetection
  37. `mmdet.core.bbox.roi2bbox` in mmdetection
  38. `convert_boxes_to_roi_format` in TorchVision
  39. `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
  40. """
  41. if boxes.shape[0] == 0:
  42. return [np.zeros((0, boxes.shape[1]), dtype=np.float32)
  43. for i in range(num_indices)]
  44. else:
  45. if indices.ndim == 2:
  46. indices = np.squeeze(indices, axis=-1)
  47. return [boxes[indices == i, :] for i in range(num_indices)]