# utils_split.py: utilities for splitting a Sequence or np.ndarray into
# consecutive pieces (by count, by sizes, by split points, or by ratios).
import itertools
import numbers
from collections.abc import Sequence

import numpy as np
  4. def split_by_num(x, num_splits, strict=True):
  5. """
  6. Args:
  7. num_splits: an integer indicating the number of splits
  8. References:
  9. numpy.split and numpy.array_split
  10. """
  11. # NB: np.ndarray is not Sequence
  12. assert isinstance(x, (Sequence, np.ndarray))
  13. assert isinstance(num_splits, numbers.Integral)
  14. if strict:
  15. assert len(x) % num_splits == 0
  16. split_size = (len(x) + num_splits - 1) // num_splits
  17. out_list = []
  18. for i in range(0, len(x), split_size):
  19. out_list.append(x[i: i + split_size])
  20. return out_list
  21. def split_by_size(x, sizes):
  22. """
  23. References:
  24. tf.split
  25. https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
  26. """
  27. # NB: np.ndarray is not Sequence
  28. assert isinstance(x, (Sequence, np.ndarray))
  29. assert isinstance(sizes, (list, tuple))
  30. assert sum(sizes) == len(x)
  31. out_list = []
  32. start_index = 0
  33. for size in sizes:
  34. out_list.append(x[start_index: start_index + size])
  35. start_index += size
  36. return out_list
  37. def split_by_slice(x, slices):
  38. """
  39. References:
  40. SliceLayer in Caffe, and numpy.split
  41. """
  42. # NB: np.ndarray is not Sequence
  43. assert isinstance(x, (Sequence, np.ndarray))
  44. assert isinstance(slices, (list, tuple))
  45. out_list = []
  46. indices = [0] + list(slices) + [len(x)]
  47. for i in range(len(slices) + 1):
  48. out_list.append(x[indices[i]: indices[i + 1]])
  49. return out_list
  50. def split_by_ratio(x, ratios):
  51. # NB: np.ndarray is not Sequence
  52. assert isinstance(x, (Sequence, np.ndarray))
  53. assert isinstance(ratios, (list, tuple))
  54. pdf = [k / sum(ratios) for k in ratios]
  55. cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
  56. indices = [int(round(len(x) * k)) for k in cdf]
  57. return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]