utils_feature.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from collections import OrderedDict
  2. import numpy as np
  3. from .utils_dict import get_dict_first_item as _get_dict_first_item
  4. def convert_feature_dict_to_array(feature_dict):
  5. one_feature = _get_dict_first_item(feature_dict)[1]
  6. num_features = sum([len(item) for item in feature_dict.values()])
  7. key_list = []
  8. start_index = 0
  9. feature_array = np.empty((num_features, one_feature.shape[-1]), one_feature.dtype)
  10. for key, value in feature_dict.items():
  11. feature_array[start_index: start_index + len(value)]= value
  12. key_list += [key] * len(value)
  13. start_index += len(value)
  14. return key_list, feature_array
  15. def convert_feature_array_to_dict(key_list, feature_array):
  16. assert len(key_list) == len(feature_array)
  17. feature_dict = OrderedDict()
  18. for key, feat in zip(key_list, feature_array):
  19. feature_dict.setdefault(key, []).append(feat)
  20. for label in feature_dict.keys():
  21. feature_dict[label] = np.vstack(feature_dict[label])
  22. return feature_dict
  23. def pairwise_distances(x, y, squared=True):
  24. """Compute pairwise (squared) Euclidean distances.
  25. References:
  26. [2016 CVPR] Deep Metric Learning via Lifted Structured Feature Embedding
  27. `euclidean_distances` from sklearn
  28. """
  29. assert isinstance(x, np.ndarray) and x.ndim == 2
  30. assert isinstance(y, np.ndarray) and y.ndim == 2
  31. assert x.shape[1] == y.shape[1]
  32. x_square = np.expand_dims(np.einsum('ij,ij->i', x, x), axis=1)
  33. if x is y:
  34. y_square = x_square.T
  35. else:
  36. y_square = np.expand_dims(np.einsum('ij,ij->i', y, y), axis=0)
  37. distances = np.dot(x, y.T)
  38. # use inplace operation to accelerate
  39. distances *= -2
  40. distances += x_square
  41. distances += y_square
  42. # result maybe less than 0 due to floating point rounding errors.
  43. np.maximum(distances, 0, distances)
  44. if x is y:
  45. # Ensure that distances between vectors and themselves are set to 0.0.
  46. # This may not be the case due to floating point rounding errors.
  47. distances.flat[::distances.shape[0] + 1] = 0.0
  48. if not squared:
  49. np.sqrt(distances, distances)
  50. return distances