Quellcode durchsuchen

refactor pairwise_intersection

quarrying vor 3 Jahren
Ursprung
Commit
3ebbdc6ce4
1 geänderte Dateien mit 3 neuen und 6 gelöschten Zeilen
  1. 3 6
      khandy/boxes/boxes_overlap.py

+ 3 - 6
khandy/boxes/boxes_overlap.py

@@ -39,17 +39,14 @@ def pairwise_intersection(boxes1, boxes2):
         `core.box_list_ops.intersection` in Tensorflow object detection API
         `utils.box_list_ops.intersection` in Tensorflow object detection API
     """
-    rows = boxes1.shape[0]
-    cols = boxes2.shape[0]
-    if rows * cols == 0:
-        return np.zeros((rows, cols), dtype=boxes1.dtype)
+    if boxes1.shape[0] * boxes2.shape[0] == 0:
+        return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
 
-    intersect_areas = np.empty((rows, cols), dtype=boxes1.dtype)
     swap = False
     if boxes1.shape[0] > boxes2.shape[0]:
         boxes1, boxes2 = boxes2, boxes1
-        intersect_areas = np.empty((cols, rows), dtype=boxes1.dtype)
         swap = True
+    intersect_areas = np.empty((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
 
     for i in range(boxes1.shape[0]):
         max_x_mins = np.maximum(boxes1[i, 0], boxes2[:, 0])