|
@@ -89,7 +89,7 @@ def get_order_of_magnitude(number):
|
|
return oom.astype(np.int32)
|
|
return oom.astype(np.int32)
|
|
|
|
|
|
|
|
|
|
-def find_topk(x, k, axis=-1, largest=True, sorted=True):
|
|
|
|
|
|
+def top_k(x, k, axis=-1, largest=True, sorted=True):
|
|
"""Finds values and indices of the k largest/smallest
|
|
"""Finds values and indices of the k largest/smallest
|
|
elements along a given axis.
|
|
elements along a given axis.
|
|
|
|
|
|
@@ -111,17 +111,24 @@ def find_topk(x, k, axis=-1, largest=True, sorted=True):
|
|
topk_indices:
|
|
topk_indices:
|
|
The indices of the k largest/smallest elements along the given axis.
|
|
The indices of the k largest/smallest elements along the given axis.
|
|
"""
|
|
"""
|
|
|
|
+ if axis is None:
|
|
|
|
+ axis_size = x.size
|
|
|
|
+ else:
|
|
|
|
+ axis_size = x.shape[axis]
|
|
|
|
+ assert 1 <= k <= axis_size
|
|
|
|
+
|
|
|
|
+ x = np.asanyarray(x)
|
|
if largest:
|
|
if largest:
|
|
- index_array = np.argpartition(-x, k-1, axis=axis, order=None)
|
|
|
|
|
|
+ index_array = np.argpartition(x, axis_size-k, axis=axis)
|
|
|
|
+ topk_indices = np.take(index_array, -np.arange(k)-1, axis=axis)
|
|
else:
|
|
else:
|
|
- index_array = np.argpartition(x, k-1, axis=axis, order=None)
|
|
|
|
- topk_indices = np.take(index_array, range(k), axis=axis)
|
|
|
|
|
|
+ index_array = np.argpartition(x, k-1, axis=axis)
|
|
|
|
+ topk_indices = np.take(index_array, np.arange(k), axis=axis)
|
|
topk_values = np.take_along_axis(x, topk_indices, axis=axis)
|
|
topk_values = np.take_along_axis(x, topk_indices, axis=axis)
|
|
if sorted:
|
|
if sorted:
|
|
|
|
+ sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
|
|
if largest:
|
|
if largest:
|
|
- sorted_indices_in_topk = np.argsort(-topk_values, axis=axis, order=None)
|
|
|
|
- else:
|
|
|
|
- sorted_indices_in_topk = np.argsort(topk_values, axis=axis, order=None)
|
|
|
|
|
|
+ sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
|
|
sorted_topk_values = np.take_along_axis(
|
|
sorted_topk_values = np.take_along_axis(
|
|
topk_values, sorted_indices_in_topk, axis=axis)
|
|
topk_values, sorted_indices_in_topk, axis=axis)
|
|
sorted_topk_indices = np.take_along_axis(
|
|
sorted_topk_indices = np.take_along_axis(
|