boxes_coder.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import numpy as np
  2. class FasterRcnnBoxCoder:
  3. """Faster RCNN box coder.
  4. Notes:
  5. boxes should be in cxcywh format.
  6. """
  7. def __init__(self, stddevs=None):
  8. """Constructor for FasterRcnnBoxCoder.
  9. Args:
  10. stddevs: List of 4 positive scalars to scale ty, tx, th and tw.
  11. If set to None, does not perform scaling. For Faster RCNN,
  12. the open-source implementation recommends using [0.1, 0.1, 0.2, 0.2].
  13. """
  14. if stddevs:
  15. assert len(stddevs) == 4
  16. for scalar in stddevs:
  17. assert scalar > 0
  18. self.stddevs = stddevs
  19. def encode(self, boxes, reference_boxes, copy=True):
  20. """Encode boxes with respect to reference boxes.
  21. """
  22. if copy:
  23. boxes = boxes.copy()
  24. boxes[..., 2:4] += 1e-8
  25. reference_boxes[..., 2:4] += 1e-8
  26. boxes[..., 0:2] -= reference_boxes[..., 0:2]
  27. boxes[..., 0:2] /= reference_boxes[..., 2:4]
  28. boxes[..., 2:4] /= reference_boxes[..., 2:4]
  29. boxes[..., 2:4] = np.log(boxes[..., 2:4], boxes[..., 2:4])
  30. if self.stddevs:
  31. boxes[..., 0:4] /= self.stddevs
  32. return boxes
  33. def decode(self, rel_boxes, reference_boxes, copy=True):
  34. """Decode relative codes to boxes.
  35. """
  36. if copy:
  37. rel_boxes = rel_boxes.copy()
  38. if self.stddevs:
  39. rel_boxes[..., 0:4] *= self.stddevs
  40. rel_boxes[..., 0:2] *= reference_boxes[..., 2:4]
  41. rel_boxes[..., 0:2] += reference_boxes[..., 0:2]
  42. rel_boxes[..., 2:4] = np.exp(rel_boxes[..., 2:4], rel_boxes[..., 2:4])
  43. rel_boxes[..., 2:4] *= reference_boxes[..., 2:4]
  44. return rel_boxes
  45. def decode_points(self, rel_points, reference_boxes, copy=True):
  46. """Decode relative codes to points.
  47. """
  48. if copy:
  49. rel_points = rel_points.copy()
  50. if self.stddevs:
  51. rel_points[..., 0::2] *= self.stddevs[0]
  52. rel_points[..., 1::2] *= self.stddevs[1]
  53. rel_points[..., 0::2] *= reference_boxes[..., 2:3]
  54. rel_points[..., 1::2] *= reference_boxes[..., 3:4]
  55. rel_points[..., 0::2] += reference_boxes[..., 0:1]
  56. rel_points[..., 1::2] += reference_boxes[..., 1:2]
  57. return rel_points