utils_dict.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. import random
  2. from collections import OrderedDict
  3. def get_dict_first_item(dict_obj):
  4. for key in dict_obj:
  5. return key, dict_obj[key]
  6. def sort_dict(dict_obj, key=None, reverse=False):
  7. return OrderedDict(sorted(dict_obj.items(), key=key, reverse=reverse))
  8. def create_class_dict(name_list, label_list):
  9. assert len(name_list) == len(label_list)
  10. class_dict = {}
  11. for name, label in zip(name_list, label_list):
  12. class_dict.setdefault(label, []).append(name)
  13. return class_dict
  14. def convert_class_dict_to_list(class_dict):
  15. name_list, label_list = [], []
  16. for key, value in class_dict.items():
  17. name_list += value
  18. label_list += [key] * len(value)
  19. return name_list, label_list
  20. def convert_class_dict_to_records(class_dict, label_map=None, raise_if_key_error=True):
  21. records = []
  22. if label_map is None:
  23. for label in class_dict:
  24. for name in class_dict[label]:
  25. records.append('{},{}'.format(name, label))
  26. else:
  27. for label in class_dict:
  28. if raise_if_key_error:
  29. mapped_label = label_map[label]
  30. else:
  31. mapped_label = label_map.get(label, label)
  32. for name in class_dict[label]:
  33. records.append('{},{}'.format(name, mapped_label))
  34. return records
  35. def sample_class_dict(class_dict, num_classes, num_examples_per_class=None):
  36. num_classes = min(num_classes, len(class_dict))
  37. sub_keys = random.sample(list(class_dict), num_classes)
  38. if num_examples_per_class is None:
  39. sub_class_dict = {key: class_dict[key] for key in sub_keys}
  40. else:
  41. sub_class_dict = {}
  42. for key in sub_keys:
  43. num_examples_inner = min(num_examples_per_class, len(class_dict[key]))
  44. sub_class_dict[key] = random.sample(class_dict[key], num_examples_inner)
  45. return sub_class_dict
  46. def split_class_dict_on_key(class_dict, split_ratio, use_shuffle=False):
  47. """Split class_dict on its key.
  48. """
  49. assert isinstance(class_dict, dict)
  50. assert isinstance(split_ratio, (list, tuple))
  51. pdf = [k / float(sum(split_ratio)) for k in split_ratio]
  52. cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
  53. indices = [int(round(len(class_dict) * k)) for k in cdf]
  54. dict_keys = list(class_dict)
  55. if use_shuffle:
  56. random.shuffle(dict_keys)
  57. be_split_list = []
  58. for i in range(len(split_ratio)):
  59. #if indices[i] != indices[i + 1]:
  60. part_keys = dict_keys[indices[i]: indices[i + 1]]
  61. part_dict = dict([(key, class_dict[key]) for key in part_keys])
  62. be_split_list.append(part_dict)
  63. return be_split_list
  64. def split_class_dict_on_value(class_dict, split_ratio, use_shuffle=False):
  65. """Split class_dict on its value.
  66. """
  67. assert isinstance(class_dict, dict)
  68. assert isinstance(split_ratio, (list, tuple))
  69. pdf = [k / float(sum(split_ratio)) for k in split_ratio]
  70. cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
  71. be_split_list = [dict() for k in range(len(split_ratio))]
  72. for key, value in class_dict.items():
  73. indices = [int(round(len(value) * k)) for k in cdf]
  74. cloned = value[:]
  75. if use_shuffle:
  76. random.shuffle(cloned)
  77. for i in range(len(split_ratio)):
  78. #if indices[i] != indices[i + 1]:
  79. be_split_list[i][key] = cloned[indices[i]: indices[i + 1]]
  80. return be_split_list
  81. def get_class_dict_info(class_dict, with_print=False, desc=None):
  82. num_list = [len(val) for val in class_dict.values()]
  83. num_classes = len(num_list)
  84. num_examples = sum(num_list)
  85. max_examples_per_class = max(num_list)
  86. min_examples_per_class = min(num_list)
  87. if num_classes == 0:
  88. avg_examples_per_class = 0
  89. else:
  90. avg_examples_per_class = num_examples / num_classes
  91. info = {
  92. 'num_classes': num_classes,
  93. 'num_examples': num_examples,
  94. 'max_examples_per_class': max_examples_per_class,
  95. 'min_examples_per_class': min_examples_per_class,
  96. 'avg_examples_per_class': avg_examples_per_class,
  97. }
  98. if with_print:
  99. desc = desc or '<unknown>'
  100. print('{} subject number: {}'.format(desc, info['num_classes']))
  101. print('{} example number: {}'.format(desc, info['num_examples']))
  102. print('{} max number per-id: {}'.format(desc, info['max_examples_per_class']))
  103. print('{} min number per-id: {}'.format(desc, info['min_examples_per_class']))
  104. print('{} avg number per-id: {:.2f}'.format(desc, info['avg_examples_per_class']))
  105. return info
  106. def filter_class_dict_by_number(class_dict, lower, upper=None):
  107. if upper is None:
  108. return {key: value for key, value in class_dict.items()
  109. if lower <= len(value) }
  110. else:
  111. assert lower <= upper, 'lower must not be greater than upper'
  112. return {key: value for key, value in class_dict.items()
  113. if lower <= len(value) <= upper }
  114. def sort_class_dict_by_number(class_dict, num_classes_to_keep=None, reverse=True):
  115. """
  116. Args:
  117. reverse: sort in ascending order when is True.
  118. """
  119. if num_classes_to_keep is None:
  120. num_classes_to_keep = len(class_dict)
  121. else:
  122. num_classes_to_keep = min(num_classes_to_keep, len(class_dict))
  123. sorted_items = sorted(class_dict.items(), key=lambda x: len(x[1]), reverse=reverse)
  124. filtered_dict = OrderedDict()
  125. for i in range(num_classes_to_keep):
  126. filtered_dict[sorted_items[i][0]] = sorted_items[i][1]
  127. return filtered_dict
  128. def merge_class_dict(*class_dicts):
  129. merged_class_dict = {}
  130. for item in class_dicts:
  131. for key, value in item.items():
  132. merged_class_dict.setdefault(key, []).extend(value)
  133. return merged_class_dict