Procházet zdrojové kódy

refactor convert_boxes_list_to_batched_boxes

quarrying před 4 roky
rodič
revize
a6ba99a8cd
1 změnil soubory, kde provedl 13 přidání a 8 odebrání
  1. 13 8
      khandy/boxes/batched_boxes.py

+ 13 - 8
khandy/boxes/batched_boxes.py

@@ -1,6 +1,14 @@
 import numpy as np
 
 
+def _concat(arr_list, axis=0):
+    """Avoids a copy if there is only a single element in a list.
+    """
+    if len(arr_list) == 1:
+        return arr_list[0]
+    return np.concatenate(arr_list, axis)
+    
+    
 def convert_boxes_list_to_batched_boxes(boxes_list):
     """
     Args:
@@ -10,19 +18,15 @@ def convert_boxes_list_to_batched_boxes(boxes_list):
         ndarray with shape (M, 5+K) where M is sum of N_i.
         
     References:
-        `convert_boxes_to_roi_format` in TorchVision
         `mmdet.core.bbox.bbox2roi` in mmdetection
+        `convert_boxes_to_roi_format` in TorchVision
         `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
     """
     assert isinstance(boxes_list, (list, tuple))
-    # avoids a copy if there is only a single element in a list
-    if len(boxes_list) == 1:
-        concat_boxes = boxes_list[0]
-    else:
-        concat_boxes = np.concatenate(boxes_list, axis=0)
+    concat_boxes = _concat(boxes_list, axis=0)
     indices_list = [np.full((len(b), 1), i, concat_boxes.dtype) 
                     for i, b in enumerate(boxes_list)]
-    indices = np.concatenate(indices_list, axis=0)
+    indices = _concat(indices_list, axis=0)
     batched_boxes = np.hstack([indices, concat_boxes])
     return batched_boxes
     
@@ -30,8 +34,9 @@ def convert_boxes_list_to_batched_boxes(boxes_list):
 def convert_batched_boxes_to_boxes_list(batched_boxes):
     """
     References:
-        `convert_boxes_to_roi_format` in TorchVision
         `mmdet.core.bbox.roi2bbox` in mmdetection
+        `convert_boxes_to_roi_format` in TorchVision
+        `modeling.poolers.convert_boxes_to_pooler_format` in detectron2
     """
     assert isinstance(batched_boxes, np.ndarray)
     assert batched_boxes.ndim == 2 and batched_boxes.shape[-1] >= 5