utils_dict.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. import random
  2. from collections import OrderedDict
  3. def get_dict_first_item(dict_obj):
  4. for key in dict_obj:
  5. return key, dict_obj[key]
  6. def sort_dict(dict_obj, key=None, reverse=False):
  7. return OrderedDict(sorted(dict_obj.items(), key=key, reverse=reverse))
  8. def create_multidict(key_list, value_list):
  9. assert len(key_list) == len(value_list)
  10. multidict_obj = {}
  11. for key, value in zip(key_list, value_list):
  12. multidict_obj.setdefault(key, []).append(value)
  13. return multidict_obj
  14. def convert_multidict_to_list(multidict_obj):
  15. key_list, value_list = [], []
  16. for key, value in multidict_obj.items():
  17. key_list += [key] * len(value)
  18. value_list += value
  19. return key_list, value_list
  20. def convert_multidict_to_records(multidict_obj, key_map=None, raise_if_key_error=True):
  21. records = []
  22. if key_map is None:
  23. for key in multidict_obj:
  24. for value in multidict_obj[key]:
  25. records.append('{},{}'.format(value, key))
  26. else:
  27. for key in multidict_obj:
  28. if raise_if_key_error:
  29. mapped_key = key_map[key]
  30. else:
  31. mapped_key = key_map.get(key, key)
  32. for value in multidict_obj[key]:
  33. records.append('{},{}'.format(value, mapped_key))
  34. return records
  35. def sample_multidict(multidict_obj, num_keys, num_per_key=None):
  36. num_keys = min(num_keys, len(multidict_obj))
  37. sub_keys = random.sample(list(multidict_obj), num_keys)
  38. if num_per_key is None:
  39. sub_mdict = {key: multidict_obj[key] for key in sub_keys}
  40. else:
  41. sub_mdict = {}
  42. for key in sub_keys:
  43. num_examples_inner = min(num_per_key, len(multidict_obj[key]))
  44. sub_mdict[key] = random.sample(multidict_obj[key], num_examples_inner)
  45. return sub_mdict
  46. def split_multidict_on_key(multidict_obj, split_ratio, use_shuffle=False):
  47. """Split multidict_obj on its key.
  48. """
  49. assert isinstance(multidict_obj, dict)
  50. assert isinstance(split_ratio, (list, tuple))
  51. pdf = [k / float(sum(split_ratio)) for k in split_ratio]
  52. cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
  53. indices = [int(round(len(multidict_obj) * k)) for k in cdf]
  54. dict_keys = list(multidict_obj)
  55. if use_shuffle:
  56. random.shuffle(dict_keys)
  57. be_split_list = []
  58. for i in range(len(split_ratio)):
  59. part_keys = dict_keys[indices[i]: indices[i + 1]]
  60. part_dict = dict([(key, multidict_obj[key]) for key in part_keys])
  61. be_split_list.append(part_dict)
  62. return be_split_list
  63. def split_multidict_on_value(multidict_obj, split_ratio, use_shuffle=False):
  64. """Split multidict_obj on its value.
  65. """
  66. assert isinstance(multidict_obj, dict)
  67. assert isinstance(split_ratio, (list, tuple))
  68. pdf = [k / float(sum(split_ratio)) for k in split_ratio]
  69. cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
  70. be_split_list = [dict() for k in range(len(split_ratio))]
  71. for key, value in multidict_obj.items():
  72. indices = [int(round(len(value) * k)) for k in cdf]
  73. cloned = value[:]
  74. if use_shuffle:
  75. random.shuffle(cloned)
  76. for i in range(len(split_ratio)):
  77. be_split_list[i][key] = cloned[indices[i]: indices[i + 1]]
  78. return be_split_list
  79. def get_multidict_info(multidict_obj, with_print=False, desc=None):
  80. num_list = [len(val) for val in multidict_obj.values()]
  81. num_keys = len(num_list)
  82. num_values = sum(num_list)
  83. max_values_per_key = max(num_list)
  84. min_values_per_key = min(num_list)
  85. if num_keys == 0:
  86. avg_values_per_key = 0
  87. else:
  88. avg_values_per_key = num_values / num_keys
  89. info = {
  90. 'num_keys': num_keys,
  91. 'num_values': num_values,
  92. 'max_values_per_key': max_values_per_key,
  93. 'min_values_per_key': min_values_per_key,
  94. 'avg_values_per_key': avg_values_per_key,
  95. }
  96. if with_print:
  97. desc = desc or '<unknown>'
  98. print('{} key number: {}'.format(desc, info['num_keys']))
  99. print('{} value number: {}'.format(desc, info['num_values']))
  100. print('{} max number per-key: {}'.format(desc, info['max_values_per_key']))
  101. print('{} min number per-key: {}'.format(desc, info['min_values_per_key']))
  102. print('{} avg number per-key: {:.2f}'.format(desc, info['avg_values_per_key']))
  103. return info
  104. def filter_multidict_by_number(multidict_obj, lower, upper=None):
  105. if upper is None:
  106. return {key: value for key, value in multidict_obj.items()
  107. if lower <= len(value) }
  108. else:
  109. assert lower <= upper, 'lower must not be greater than upper'
  110. return {key: value for key, value in multidict_obj.items()
  111. if lower <= len(value) <= upper }
  112. def sort_multidict_by_number(multidict_obj, num_keys_to_keep=None, reverse=True):
  113. """
  114. Args:
  115. reverse: sort in ascending order when is True.
  116. """
  117. if num_keys_to_keep is None:
  118. num_keys_to_keep = len(multidict_obj)
  119. else:
  120. num_keys_to_keep = min(num_keys_to_keep, len(multidict_obj))
  121. sorted_items = sorted(multidict_obj.items(), key=lambda x: len(x[1]), reverse=reverse)
  122. filtered_dict = OrderedDict()
  123. for i in range(num_keys_to_keep):
  124. filtered_dict[sorted_items[i][0]] = sorted_items[i][1]
  125. return filtered_dict
  126. def merge_multidict(*mdicts):
  127. merged_multidict = {}
  128. for item in mdicts:
  129. for key, value in item.items():
  130. merged_multidict.setdefault(key, []).extend(value)
  131. return merged_multidict
  132. def invert_multidict(multidict_obj):
  133. inverted_dict = {}
  134. for key, value in multidict_obj.items():
  135. for item in value:
  136. inverted_dict.setdefault(item, []).append(key)
  137. return inverted_dict