align_and_crop.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import cv2
  2. import numpy as np
  3. def get_similarity_transform(src_pts, dst_pts):
  4. """Get similarity transform matrix from src_pts to dst_pts
  5. Args:
  6. src_pts: Kx2 np.array
  7. source points matrix, each row is a pair of coordinates (x, y)
  8. dst_pts: Kx2 np.array
  9. destination points matrix, each row is a pair of coordinates (x, y)
  10. Returns:
  11. xform_matrix: 3x3 np.array
  12. transform matrix from src_pts to dst_pts
  13. """
  14. src_pts = np.asarray(src_pts)
  15. dst_pts = np.asarray(dst_pts)
  16. assert src_pts.shape == dst_pts.shape
  17. assert (src_pts.ndim == 2) and (src_pts.shape[-1] == 2)
  18. npts = src_pts.shape[0]
  19. src_x = src_pts[:, 0].reshape((-1, 1))
  20. src_y = src_pts[:, 1].reshape((-1, 1))
  21. tmp1 = np.hstack((src_x, -src_y, np.ones((npts, 1)), np.zeros((npts, 1))))
  22. tmp2 = np.hstack((src_y, src_x, np.zeros((npts, 1)), np.ones((npts, 1))))
  23. A = np.vstack((tmp1, tmp2))
  24. dst_x = dst_pts[:, 0].reshape((-1, 1))
  25. dst_y = dst_pts[:, 1].reshape((-1, 1))
  26. b = np.vstack((dst_x, dst_y))
  27. x = np.linalg.lstsq(A, b, rcond=-1)[0]
  28. x = np.squeeze(x)
  29. sc, ss, tx, ty = x[0], x[1], x[2], x[3]
  30. xform_matrix = np.array([
  31. [sc, -ss, tx],
  32. [ss, sc, ty],
  33. [ 0, 0, 1]
  34. ])
  35. return xform_matrix
  36. def align_and_crop(image, landmarks, std_landmarks, align_size,
  37. border_value=0, return_transform_matrix=False):
  38. landmarks = np.asarray(landmarks)
  39. std_landmarks = np.asarray(std_landmarks)
  40. xform_matrix = get_similarity_transform(landmarks, std_landmarks)
  41. landmarks_ex = np.pad(landmarks, ((0,0),(0,1)), mode='constant', constant_values=1)
  42. dst_landmarks = np.dot(landmarks_ex, xform_matrix[:2,:].T)
  43. dst_image = cv2.warpAffine(image, xform_matrix[:2,:], dsize=align_size,
  44. borderValue=border_value)
  45. if return_transform_matrix:
  46. return dst_image, dst_landmarks, xform_matrix
  47. else:
  48. return dst_image, dst_landmarks