|
@@ -11,62 +11,62 @@ def sort_dict(dict_obj, key=None, reverse=False):
|
|
|
return OrderedDict(sorted(dict_obj.items(), key=key, reverse=reverse))
|
|
|
|
|
|
|
|
|
-def create_class_dict(name_list, label_list):
|
|
|
- assert len(name_list) == len(label_list)
|
|
|
- class_dict = {}
|
|
|
- for name, label in zip(name_list, label_list):
|
|
|
- class_dict.setdefault(label, []).append(name)
|
|
|
- return class_dict
|
|
|
-
|
|
|
-
|
|
|
-def convert_class_dict_to_list(class_dict):
|
|
|
- name_list, label_list = [], []
|
|
|
- for key, value in class_dict.items():
|
|
|
- name_list += value
|
|
|
- label_list += [key] * len(value)
|
|
|
- return name_list, label_list
|
|
|
-
|
|
|
-
|
|
|
-def convert_class_dict_to_records(class_dict, label_map=None, raise_if_key_error=True):
|
|
|
+def create_multidict(key_list, value_list):
|
|
|
+ assert len(key_list) == len(value_list)
|
|
|
+ multidict_obj = {}
|
|
|
+ for key, value in zip(key_list, value_list):
|
|
|
+ multidict_obj.setdefault(key, []).append(value)
|
|
|
+ return multidict_obj
|
|
|
+
|
|
|
+
|
|
|
+def convert_multidict_to_list(multidict_obj):
|
|
|
+ key_list, value_list = [], []
|
|
|
+ for key, value in multidict_obj.items():
|
|
|
+ key_list += [key] * len(value)
|
|
|
+ value_list += value
|
|
|
+ return key_list, value_list
|
|
|
+
|
|
|
+
|
|
|
+def convert_multidict_to_records(multidict_obj, key_map=None, raise_if_key_error=True):
|
|
|
records = []
|
|
|
- if label_map is None:
|
|
|
- for label in class_dict:
|
|
|
- for name in class_dict[label]:
|
|
|
- records.append('{},{}'.format(name, label))
|
|
|
+ if key_map is None:
|
|
|
+ for key in multidict_obj:
|
|
|
+ for value in multidict_obj[key]:
|
|
|
+ records.append('{},{}'.format(value, key))
|
|
|
else:
|
|
|
- for label in class_dict:
|
|
|
+ for key in multidict_obj:
|
|
|
if raise_if_key_error:
|
|
|
- mapped_label = label_map[label]
|
|
|
+ mapped_key = key_map[key]
|
|
|
else:
|
|
|
- mapped_label = label_map.get(label, label)
|
|
|
- for name in class_dict[label]:
|
|
|
- records.append('{},{}'.format(name, mapped_label))
|
|
|
+ mapped_key = key_map.get(key, key)
|
|
|
+ for value in multidict_obj[key]:
|
|
|
+ records.append('{},{}'.format(value, mapped_key))
|
|
|
return records
|
|
|
|
|
|
|
|
|
-def sample_class_dict(class_dict, num_classes, num_examples_per_class=None):
|
|
|
- num_classes = min(num_classes, len(class_dict))
|
|
|
- sub_keys = random.sample(list(class_dict), num_classes)
|
|
|
- if num_examples_per_class is None:
|
|
|
- sub_class_dict = {key: class_dict[key] for key in sub_keys}
|
|
|
+def sample_multidict(multidict_obj, num_keys, num_per_key=None):
|
|
|
+ num_keys = min(num_keys, len(multidict_obj))
|
|
|
+ sub_keys = random.sample(list(multidict_obj), num_keys)
|
|
|
+ if num_per_key is None:
|
|
|
+ sub_mdict = {key: multidict_obj[key] for key in sub_keys}
|
|
|
else:
|
|
|
- sub_class_dict = {}
|
|
|
+ sub_mdict = {}
|
|
|
for key in sub_keys:
|
|
|
- num_examples_inner = min(num_examples_per_class, len(class_dict[key]))
|
|
|
- sub_class_dict[key] = random.sample(class_dict[key], num_examples_inner)
|
|
|
- return sub_class_dict
|
|
|
+ num_examples_inner = min(num_per_key, len(multidict_obj[key]))
|
|
|
+ sub_mdict[key] = random.sample(multidict_obj[key], num_examples_inner)
|
|
|
+ return sub_mdict
|
|
|
|
|
|
|
|
|
-def split_class_dict_on_key(class_dict, split_ratio, use_shuffle=False):
|
|
|
- """Split class_dict on its key.
|
|
|
+def split_multidict_on_key(multidict_obj, split_ratio, use_shuffle=False):
|
|
|
+ """Split multidict_obj on its key.
|
|
|
"""
|
|
|
- assert isinstance(class_dict, dict)
|
|
|
+ assert isinstance(multidict_obj, dict)
|
|
|
assert isinstance(split_ratio, (list, tuple))
|
|
|
|
|
|
pdf = [k / float(sum(split_ratio)) for k in split_ratio]
|
|
|
cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
|
|
|
- indices = [int(round(len(class_dict) * k)) for k in cdf]
|
|
|
- dict_keys = list(class_dict)
|
|
|
+ indices = [int(round(len(multidict_obj) * k)) for k in cdf]
|
|
|
+ dict_keys = list(multidict_obj)
|
|
|
if use_shuffle:
|
|
|
random.shuffle(dict_keys)
|
|
|
|
|
@@ -74,21 +74,21 @@ def split_class_dict_on_key(class_dict, split_ratio, use_shuffle=False):
|
|
|
for i in range(len(split_ratio)):
|
|
|
#if indices[i] != indices[i + 1]:
|
|
|
part_keys = dict_keys[indices[i]: indices[i + 1]]
|
|
|
- part_dict = dict([(key, class_dict[key]) for key in part_keys])
|
|
|
+ part_dict = dict([(key, multidict_obj[key]) for key in part_keys])
|
|
|
be_split_list.append(part_dict)
|
|
|
return be_split_list
|
|
|
|
|
|
|
|
|
-def split_class_dict_on_value(class_dict, split_ratio, use_shuffle=False):
|
|
|
- """Split class_dict on its value.
|
|
|
+def split_multidict_on_value(multidict_obj, split_ratio, use_shuffle=False):
|
|
|
+ """Split multidict_obj on its value.
|
|
|
"""
|
|
|
- assert isinstance(class_dict, dict)
|
|
|
+ assert isinstance(multidict_obj, dict)
|
|
|
assert isinstance(split_ratio, (list, tuple))
|
|
|
|
|
|
pdf = [k / float(sum(split_ratio)) for k in split_ratio]
|
|
|
cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
|
|
|
be_split_list = [dict() for k in range(len(split_ratio))]
|
|
|
- for key, value in class_dict.items():
|
|
|
+ for key, value in multidict_obj.items():
|
|
|
indices = [int(round(len(value) * k)) for k in cdf]
|
|
|
cloned = value[:]
|
|
|
if use_shuffle:
|
|
@@ -99,64 +99,66 @@ def split_class_dict_on_value(class_dict, split_ratio, use_shuffle=False):
|
|
|
return be_split_list
|
|
|
|
|
|
|
|
|
-def get_class_dict_info(class_dict, with_print=False, desc=None):
|
|
|
- num_list = [len(val) for val in class_dict.values()]
|
|
|
- num_classes = len(num_list)
|
|
|
- num_examples = sum(num_list)
|
|
|
- max_examples_per_class = max(num_list)
|
|
|
- min_examples_per_class = min(num_list)
|
|
|
- if num_classes == 0:
|
|
|
- avg_examples_per_class = 0
|
|
|
+def get_multidict_info(multidict_obj, with_print=False, desc=None):
|
|
|
+ num_list = [len(val) for val in multidict_obj.values()]
|
|
|
+ num_keys = len(num_list)
|
|
|
+ num_values = sum(num_list)
|
|
|
+ max_values_per_key = max(num_list)
|
|
|
+ min_values_per_key = min(num_list)
|
|
|
+ if num_keys == 0:
|
|
|
+ avg_values_per_key = 0
|
|
|
else:
|
|
|
- avg_examples_per_class = num_examples / num_classes
|
|
|
+ avg_values_per_key = num_values / num_keys
|
|
|
info = {
|
|
|
- 'num_classes': num_classes,
|
|
|
- 'num_examples': num_examples,
|
|
|
- 'max_examples_per_class': max_examples_per_class,
|
|
|
- 'min_examples_per_class': min_examples_per_class,
|
|
|
- 'avg_examples_per_class': avg_examples_per_class,
|
|
|
+ 'num_keys': num_keys,
|
|
|
+ 'num_values': num_values,
|
|
|
+ 'max_values_per_key': max_values_per_key,
|
|
|
+ 'min_values_per_key': min_values_per_key,
|
|
|
+ 'avg_values_per_key': avg_values_per_key,
|
|
|
}
|
|
|
if with_print:
|
|
|
desc = desc or '<unknown>'
|
|
|
- print('{} subject number: {}'.format(desc, info['num_classes']))
|
|
|
- print('{} example number: {}'.format(desc, info['num_examples']))
|
|
|
- print('{} max number per-id: {}'.format(desc, info['max_examples_per_class']))
|
|
|
- print('{} min number per-id: {}'.format(desc, info['min_examples_per_class']))
|
|
|
- print('{} avg number per-id: {:.2f}'.format(desc, info['avg_examples_per_class']))
|
|
|
+ print('{} key number: {}'.format(desc, info['num_keys']))
|
|
|
+ print('{} value number: {}'.format(desc, info['num_values']))
|
|
|
+ print('{} max number per-key: {}'.format(desc, info['max_values_per_key']))
|
|
|
+ print('{} min number per-key: {}'.format(desc, info['min_values_per_key']))
|
|
|
+ print('{} avg number per-key: {:.2f}'.format(desc, info['avg_values_per_key']))
|
|
|
return info
|
|
|
|
|
|
|
|
|
-def filter_class_dict_by_number(class_dict, lower, upper=None):
|
|
|
+def filter_multidict_by_number(multidict_obj, lower, upper=None):
|
|
|
if upper is None:
|
|
|
- return {key: value for key, value in class_dict.items()
|
|
|
+ return {key: value for key, value in multidict_obj.items()
|
|
|
if lower <= len(value) }
|
|
|
else:
|
|
|
assert lower <= upper, 'lower must not be greater than upper'
|
|
|
- return {key: value for key, value in class_dict.items()
|
|
|
+ return {key: value for key, value in multidict_obj.items()
|
|
|
if lower <= len(value) <= upper }
|
|
|
|
|
|
|
|
|
-def sort_class_dict_by_number(class_dict, num_classes_to_keep=None, reverse=True):
|
|
|
+def sort_multidict_by_number(multidict_obj, num_keys_to_keep=None, reverse=True):
|
|
|
"""
|
|
|
Args:
|
|
|
reverse: sort in ascending order when is True.
|
|
|
"""
|
|
|
- if num_classes_to_keep is None:
|
|
|
- num_classes_to_keep = len(class_dict)
|
|
|
+ if num_keys_to_keep is None:
|
|
|
+ num_keys_to_keep = len(multidict_obj)
|
|
|
else:
|
|
|
- num_classes_to_keep = min(num_classes_to_keep, len(class_dict))
|
|
|
- sorted_items = sorted(class_dict.items(), key=lambda x: len(x[1]), reverse=reverse)
|
|
|
+ num_keys_to_keep = min(num_keys_to_keep, len(multidict_obj))
|
|
|
+ sorted_items = sorted(multidict_obj.items(), key=lambda x: len(x[1]), reverse=reverse)
|
|
|
filtered_dict = OrderedDict()
|
|
|
- for i in range(num_classes_to_keep):
|
|
|
+ for i in range(num_keys_to_keep):
|
|
|
filtered_dict[sorted_items[i][0]] = sorted_items[i][1]
|
|
|
return filtered_dict
|
|
|
|
|
|
|
|
|
-def merge_class_dict(*class_dicts):
|
|
|
- merged_class_dict = {}
|
|
|
- for item in class_dicts:
|
|
|
+def merge_multidict(*mdicts):
|
|
|
+ merged_multidict = {}
|
|
|
+ for item in mdicts:
|
|
|
for key, value in item.items():
|
|
|
- merged_class_dict.setdefault(key, []).extend(value)
|
|
|
- return merged_class_dict
|
|
|
+ merged_multidict.setdefault(key, []).extend(value)
|
|
|
+ return merged_multidict
|
|
|
+
|
|
|
+
|
|
|
|
|
|
|