Explorar o código

refactor pairwise_intersection

quarrying %!s(int64=3) %!d(string=hai) anos
pai
achega
3ebbdc6ce4
Modificáronse 1 ficheiros con 3 adicións e 6 borrados
  1. 3 6
      khandy/boxes/boxes_overlap.py

+ 3 - 6
khandy/boxes/boxes_overlap.py

@@ -39,17 +39,14 @@ def pairwise_intersection(boxes1, boxes2):
         `core.box_list_ops.intersection` in Tensorflow object detection API
         `utils.box_list_ops.intersection` in Tensorflow object detection API
     """
-    rows = boxes1.shape[0]
-    cols = boxes2.shape[0]
-    if rows * cols == 0:
-        return np.zeros((rows, cols), dtype=boxes1.dtype)
+    if boxes1.shape[0] * boxes2.shape[0] == 0:
+        return np.zeros((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
 
-    intersect_areas = np.empty((rows, cols), dtype=boxes1.dtype)
     swap = False
     if boxes1.shape[0] > boxes2.shape[0]:
         boxes1, boxes2 = boxes2, boxes1
-        intersect_areas = np.empty((cols, rows), dtype=boxes1.dtype)
         swap = True
+    intersect_areas = np.empty((boxes1.shape[0], boxes2.shape[0]), dtype=boxes1.dtype)
 
     for i in range(boxes1.shape[0]):
         max_x_mins = np.maximum(boxes1[i, 0], boxes2[:, 0])