Przeglądaj źródła

refactor non_max_suppression

quarrying 3 lat temu
rodzic
commit
53b7f7c316
1 zmienionych plików z 19 dodań i 5 usunięć
  1. 19 5
      khandy/boxes/boxes_filter.py

+ 19 - 5
khandy/boxes/boxes_filter.py

@@ -50,26 +50,41 @@ def filter_boxes_completely_outside(boxes, reference_box):
     return np.nonzero(keep)[0]
     
 
-def non_max_suppression(boxes, scores, thresh, ratio_type="iou"):
+def non_max_suppression(boxes, scores, thresh, classes=None, ratio_type="iou"):
     """Greedily select boxes with high confidence
     Args:
         boxes: [[x_min, y_min, x_max, y_max], ...]
         scores: object confidence
         thresh: retain overlap_ratio <= thresh
+        classes: class labels
         
     Returns:
         indexes to keep
         
     References:
         `py_cpu_nms` in py-faster-rcnn
+        torchvision.ops.nms
+        torchvision.ops.batched_nms
     """
+
+    if boxes.size == 0:
+        return np.empty((0,), dtype=np.int64)
+    if classes is not None:
+        # strategy: in order to perform NMS independently per class,
+        # we add an offset to all the boxes. The offset is dependent
+        # only on the class idx, and is large enough so that boxes
+        # from different classes do not overlap
+        max_coordinate = np.max(boxes)
+        offsets = classes * (max_coordinate + 1)
+        boxes = boxes + offsets[:, None]
+    
     x_mins = boxes[:, 0]
     y_mins = boxes[:, 1]
     x_maxs = boxes[:, 2]
     y_maxs = boxes[:, 3]
     areas = (x_maxs - x_mins) * (y_maxs - y_mins)
-    order = scores.argsort()[::-1]
-
+    order = scores.flatten().argsort()[::-1]
+    
     keep = []
     while order.size > 0:
         i = order[0]
@@ -92,6 +107,5 @@ def non_max_suppression(boxes, scores, thresh, ratio_type="iou"):
             
         inds = np.nonzero(ratio <= thresh)[0]
         order = order[inds + 1]
-
-    return keep
+    return np.asarray(keep)