Explorar el Código

add find_topk

quarrying hace 4 años
padre
commit
3c79a379bf
Se ha modificado 1 fichero con 41 adiciones y 0 borrados
  1. 41 0
      khandy/utils_numpy.py

+ 41 - 0
khandy/utils_numpy.py

@@ -88,4 +88,45 @@ def get_order_of_magnitude(number):
     oom = np.floor(np.log10(np.abs(number)))
     return oom.astype(np.int32)
     
+    
def find_topk(x, k, axis=-1, largest=True, sorted=True):
    """Find values and indices of the k largest/smallest elements along an axis.

    Args:
        x: numpy ndarray (or array-like)
            1-D or higher, with at least k elements along the given axis.
        k: int
            Number of top elements to look for along the given axis.
        axis: int
            The axis to search along.
        largest: bool
            If True return the k largest elements, otherwise the k smallest.
        sorted: bool
            If True the resulting k elements are ordered by value
            (descending when ``largest`` is True, ascending otherwise).
            If False their order is unspecified.

    Returns:
        topk_values:
            The k largest/smallest elements along the given axis.
        topk_indices:
            The indices of those elements along the given axis.
    """
    x = np.asarray(x)

    if largest:
        # Selecting the largest elements is done by selecting the smallest
        # elements of an order-reversing key. Plain negation is not a valid
        # key for every dtype: it wraps around for unsigned integers, is
        # rejected for booleans and overflows at the signed minimum.
        # Bitwise inversion reverses the order of all bool/integer dtypes
        # without overflow (~v == -v - 1 for signed, max - v for unsigned).
        if x.dtype.kind in 'biu':
            key = np.invert(x)
        else:
            key = -x
    else:
        key = x

    # Partial sort: the first k positions hold the k smallest keys.
    partitioned = np.argpartition(key, k - 1, axis=axis)
    # np.arange keeps an integer dtype even when k == 0.
    topk_indices = np.take(partitioned, np.arange(k), axis=axis)
    topk_values = np.take_along_axis(x, topk_indices, axis=axis)
    if not sorted:
        return topk_values, topk_indices

    # Order the k selected elements by the same key used for selection.
    topk_keys = np.take_along_axis(key, topk_indices, axis=axis)
    order = np.argsort(topk_keys, axis=axis)
    sorted_topk_values = np.take_along_axis(topk_values, order, axis=axis)
    sorted_topk_indices = np.take_along_axis(topk_indices, order, axis=axis)
    return sorted_topk_values, sorted_topk_indices
+