Bladeren bron

refactor convert_feature_dict_to_array and convert_feature_array_to_dict

quarrying 4 jaar geleden
bovenliggende
commit
330027c465
1 gewijzigd bestand met 17 toevoegingen en 11 verwijderingen
  1. 17 11
      khandy/utils_feature.py

+ 17 - 11
khandy/utils_feature.py

@@ -6,23 +6,29 @@ from .utils_dict import get_dict_first_item as _get_dict_first_item
 
 
 def convert_feature_dict_to_array(feature_dict):
-    key_list = []
     one_feature = _get_dict_first_item(feature_dict)[1]
-    feature_array = np.empty((len(feature_dict), len(one_feature)), one_feature.dtype)
-    for k, (key, value) in enumerate(feature_dict.items()):
-        key_list.append(key)
-        feature_array[k] = value
-    return key_list, feature_array
-    
+    num_features = sum([len(item) for item in feature_dict.values()])
     
+    key_list = []
+    start_index = 0
+    feature_array = np.empty((num_features, len(one_feature)), one_feature.dtype)
+    for key, value in feature_dict.items():
+        feature_array[start_index: start_index + len(value)]= value
+        key_list += [key] * len(value)
+        start_index += len(value)
+    return key_list, feature_array
+
+
 def convert_feature_array_to_dict(key_list, feature_array):
-    assert len(feature_array) == len(key_list)
+    assert len(key_list) == len(feature_array)
     feature_dict = OrderedDict()
-    for k, key in enumerate(key_list):
-        feature_dict[key] = feature_array[k]
+    for key, feat in zip(key_list, feature_array):
+        feature_dict.setdefault(key, []).append(feat)
+    for label in feature_dict.keys():
+        feature_dict[label] = np.vstack(feature_dict[label])
     return feature_dict
     
-
+    
 def pairwise_distances(x, y, squared=True):
     """Compute pairwise (squared) Euclidean distances.