소스 검색

refactor get_similarity_transform

quarrying 4 년 전
부모
커밋
b00467a3d4
1개의 변경된 파일17개의 추가작업 그리고 12개의 파일을 삭제
  1. 17 12
      khandy/image/align_and_crop.py

+ 17 - 12
khandy/image/align_and_crop.py

@@ -21,19 +21,24 @@ def get_similarity_transform(src_pts, dst_pts):
     assert (src_pts.ndim == 2) and (src_pts.shape[-1] == 2)
     
     npts = src_pts.shape[0]
-    A = np.empty((npts * 2, 4))
-    b = np.empty((npts * 2,))
-    for k in range(npts):
-        A[2 * k + 0] = [src_pts[k, 0], -src_pts[k, 1], 1, 0]
-        A[2 * k + 1] = [src_pts[k, 1], src_pts[k, 0], 0, 1]
-        b[2 * k + 0] = dst_pts[k, 0]
-        b[2 * k + 1] = dst_pts[k, 1]
-        
+    src_x = src_pts[:, 0].reshape((-1, 1))
+    src_y = src_pts[:, 1].reshape((-1, 1))
+    tmp1 = np.hstack((src_x, -src_y, np.ones((npts, 1)), np.zeros((npts, 1))))
+    tmp2 = np.hstack((src_y, src_x, np.zeros((npts, 1)), np.ones((npts, 1))))
+    A = np.vstack((tmp1, tmp2))
+
+    dst_x = dst_pts[:, 0].reshape((-1, 1))
+    dst_y = dst_pts[:, 1].reshape((-1, 1))
+    b = np.vstack((dst_x, dst_y))
+
     x = np.linalg.lstsq(A, b, rcond=-1)[0]
-    xform_matrix = np.empty((3, 3))
-    xform_matrix[0] = [x[0], -x[1], x[2]]
-    xform_matrix[1] = [x[1], x[0], x[3]]
-    xform_matrix[2] = [0, 0, 1]
+    x = np.squeeze(x)
+    sc, ss, tx, ty = x[0], x[1], x[2], x[3]
+    xform_matrix = np.array([
+        [sc, -ss, tx],
+        [ss,  sc, ty],
+        [ 0,   0,  1]
+    ])
     return xform_matrix