瀏覽代碼

rename find_topk to top_k, and refactor it

quarrying 4 年之前
父節點
當前提交
7684b4418d
共有 1 個文件被更改,包括 14 次插入7 次删除
  1. 14 7
      khandy/utils_numpy.py

+ 14 - 7
khandy/utils_numpy.py

@@ -89,7 +89,7 @@ def get_order_of_magnitude(number):
     return oom.astype(np.int32)
     
     
-def find_topk(x, k, axis=-1, largest=True, sorted=True):
+def top_k(x, k, axis=-1, largest=True, sorted=True):
     """Finds values and indices of the k largest/smallest 
     elements along a given axis.
 
@@ -111,17 +111,24 @@ def find_topk(x, k, axis=-1, largest=True, sorted=True):
         topk_indices: 
             The indices of the k largest/smallest elements along the given axis.
     """
+    if axis is None:
+        axis_size = x.size
+    else:
+        axis_size = x.shape[axis]
+    assert 1 <= k <= axis_size
+
+    x = np.asanyarray(x)
     if largest:
-        index_array = np.argpartition(-x, k-1, axis=axis, order=None)
+        index_array = np.argpartition(x, axis_size-k, axis=axis)
+        topk_indices = np.take(index_array, -np.arange(k)-1, axis=axis)
     else:
-        index_array = np.argpartition(x, k-1, axis=axis, order=None)
-    topk_indices = np.take(index_array, range(k), axis=axis)
+        index_array = np.argpartition(x, k-1, axis=axis)
+        topk_indices = np.take(index_array, np.arange(k), axis=axis)
     topk_values = np.take_along_axis(x, topk_indices, axis=axis)
     if sorted:
+        sorted_indices_in_topk = np.argsort(topk_values, axis=axis)
         if largest:
-            sorted_indices_in_topk = np.argsort(-topk_values, axis=axis, order=None)
-        else:
-            sorted_indices_in_topk = np.argsort(topk_values, axis=axis, order=None)
+            sorted_indices_in_topk = np.flip(sorted_indices_in_topk, axis=axis)
         sorted_topk_values = np.take_along_axis(
             topk_values, sorted_indices_in_topk, axis=axis)
         sorted_topk_indices = np.take_along_axis(