|
@@ -88,4 +88,45 @@ def get_order_of_magnitude(number):
|
|
|
oom = np.floor(np.log10(np.abs(number)))
|
|
|
return oom.astype(np.int32)
|
|
|
|
|
|
+
|
|
|
+def find_topk(x, k, axis=-1, largest=True, sorted=True):
|
|
|
+ """Finds values and indices of the k largest/smallest
|
|
|
+ elements along a given axis.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ x: numpy ndarray
|
|
|
+ 1-D or higher with given axis at least k.
|
|
|
+ k: int
|
|
|
+ Number of top elements to look for along the given axis.
|
|
|
+ axis: int
|
|
|
+ The axis to sort along.
|
|
|
+ largest: bool
|
|
|
+ Controls whether to return largest or smallest elements
|
|
|
+ sorted: bool
|
|
|
+ If true the resulting k elements will be sorted by the values.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ topk_values:
|
|
|
+ The k largest/smallest elements along the given axis.
|
|
|
+ topk_indices:
|
|
|
+ The indices of the k largest/smallest elements along the given axis.
|
|
|
+ """
|
|
|
+ if largest:
|
|
|
+ index_array = np.argpartition(-x, k-1, axis=axis, order=None)
|
|
|
+ else:
|
|
|
+ index_array = np.argpartition(x, k-1, axis=axis, order=None)
|
|
|
+ topk_indices = np.take(index_array, range(k), axis=axis)
|
|
|
+ topk_values = np.take_along_axis(x, topk_indices, axis=axis)
|
|
|
+ if sorted:
|
|
|
+ if largest:
|
|
|
+ sorted_indices_in_topk = np.argsort(-topk_values, axis=axis, order=None)
|
|
|
+ else:
|
|
|
+ sorted_indices_in_topk = np.argsort(topk_values, axis=axis, order=None)
|
|
|
+ sorted_topk_values = np.take_along_axis(
|
|
|
+ topk_values, sorted_indices_in_topk, axis=axis)
|
|
|
+ sorted_topk_indices = np.take_along_axis(
|
|
|
+ topk_indices, sorted_indices_in_topk, axis=axis)
|
|
|
+ return sorted_topk_values, sorted_topk_indices
|
|
|
+ return topk_values, topk_indices
|
|
|
+
|
|
|
|