Browse Source

add FasterRcnnBoxCoder

quarrying 3 years ago
parent
commit
810b11a495
2 changed files with 70 additions and 0 deletions
  1. 1 0
      khandy/boxes/__init__.py
  2. 69 0
      khandy/boxes/boxes_coder.py

+ 1 - 0
khandy/boxes/__init__.py

@@ -2,6 +2,7 @@ from .boxes_clip import *
 from .boxes_overlap import *
 from .boxes_filter import *
 from .boxes_convert import *
+from .boxes_coder import *
 
 from .boxes_transform_flip import *
 from .boxes_transform_rotate import *

+ 69 - 0
khandy/boxes/boxes_coder.py

@@ -0,0 +1,69 @@
+import numpy as np
+
+
class FasterRcnnBoxCoder:
    """Faster RCNN box coder.

    Encodes boxes as offsets relative to reference boxes and decodes such
    offsets back into boxes or points.

    Notes:
        boxes should be in cxcywh format.
    """

    # Added to widths/heights in ``encode`` so that the division and the
    # logarithm stay finite for degenerate (zero-sized) boxes.
    _EPS = 1e-8

    def __init__(self, stddevs=None):
        """Constructor for FasterRcnnBoxCoder.

        Args:
          stddevs: Sequence of 4 positive scalars used to scale tx, ty, tw
            and th (same order as the cxcywh box layout). If set to None
            (or an empty sequence), does not perform scaling. For Faster
            RCNN, the open-source implementation recommends using
            [0.1, 0.1, 0.2, 0.2].

        Raises:
          ValueError: If ``stddevs`` is given but does not hold exactly 4
            strictly positive scalars.
        """
        # Explicit None/length checks instead of truthiness so that NumPy
        # arrays are accepted (their truth value is ambiguous).
        if stddevs is not None and len(stddevs) > 0:
            if len(stddevs) != 4:
                raise ValueError('stddevs must contain exactly 4 scalars')
            if not all(scalar > 0 for scalar in stddevs):
                raise ValueError('stddevs must all be positive')
        self.stddevs = stddevs

    def _stddev_array(self):
        """Return the scaling factors as an array, or None if scaling is off."""
        stddevs = self.stddevs
        if stddevs is None or len(stddevs) == 0:
            return None
        return np.asarray(stddevs)

    def encode(self, boxes, reference_boxes, copy=True):
        """Encode boxes with respect to reference boxes.

        Args:
          boxes: Array whose last axis holds (cx, cy, w, h). The arithmetic
            is performed in place, so it needs a floating dtype.
          reference_boxes: Array broadcastable against ``boxes`` whose last
            axis holds (cx, cy, w, h). It is never modified.
          copy: If True (default), work on a copy of ``boxes``; otherwise
            ``boxes`` itself is overwritten with the codes.

        Returns:
          Array of the same shape as ``boxes`` holding (tx, ty, tw, th),
          divided by ``stddevs`` when scaling is enabled.
        """
        if copy:
            boxes = boxes.copy()

        # Build the padded reference sizes in a new array: the caller's
        # reference boxes must not drift by epsilon on every call.
        ref_centers = reference_boxes[..., 0:2]
        ref_sizes = reference_boxes[..., 2:4] + self._EPS

        centers = boxes[..., 0:2]
        sizes = boxes[..., 2:4]

        sizes += self._EPS
        centers -= ref_centers
        centers /= ref_sizes
        sizes /= ref_sizes
        np.log(sizes, out=sizes)

        stddevs = self._stddev_array()
        if stddevs is not None:
            boxes[..., 0:4] /= stddevs
        return boxes

    def decode(self, rel_boxes, reference_boxes, copy=True):
        """Decode relative codes to boxes.

        Args:
          rel_boxes: Array whose last axis holds (tx, ty, tw, th), as
            produced by ``encode``.
          reference_boxes: Array broadcastable against ``rel_boxes`` whose
            last axis holds (cx, cy, w, h). It is not modified.
          copy: If True (default), work on a copy of ``rel_boxes``;
            otherwise ``rel_boxes`` itself is overwritten with the boxes.

        Returns:
          Array of the same shape as ``rel_boxes`` holding (cx, cy, w, h).
        """
        if copy:
            rel_boxes = rel_boxes.copy()

        stddevs = self._stddev_array()
        if stddevs is not None:
            rel_boxes[..., 0:4] *= stddevs

        ref_centers = reference_boxes[..., 0:2]
        ref_sizes = reference_boxes[..., 2:4]

        centers = rel_boxes[..., 0:2]
        sizes = rel_boxes[..., 2:4]

        centers *= ref_sizes
        centers += ref_centers
        np.exp(sizes, out=sizes)
        sizes *= ref_sizes
        return rel_boxes

    def decode_points(self, rel_points, reference_boxes, copy=True):
        """Decode relative codes to points.

        Args:
          rel_points: Array whose last axis interleaves x and y offsets
            (x0, y0, x1, y1, ...).
          reference_boxes: Array whose last axis holds (cx, cy, w, h). It
            is not modified.
          copy: If True (default), work on a copy of ``rel_points``;
            otherwise ``rel_points`` itself is overwritten.

        Returns:
          Array of the same shape as ``rel_points`` holding point
          coordinates: each x is scaled by the reference width and shifted
          by the reference cx, each y by the reference height and cy.
        """
        if copy:
            rel_points = rel_points.copy()

        xs = rel_points[..., 0::2]
        ys = rel_points[..., 1::2]

        stddevs = self._stddev_array()
        if stddevs is not None:
            # Points only have centre-like offsets, so only the x and y
            # factors (first two entries) apply.
            xs *= stddevs[0]
            ys *= stddevs[1]

        # Width/height slices keep a trailing axis of length 1 so they
        # broadcast over however many points each row carries.
        xs *= reference_boxes[..., 2:3]
        ys *= reference_boxes[..., 3:4]
        xs += reference_boxes[..., 0:1]
        ys += reference_boxes[..., 1:2]
        return rel_points