Browse Source

add split related utils

quarrying 2 years ago
parent
commit
96a011aa90
2 changed files with 74 additions and 0 deletions
  1. 1 0
      khandy/__init__.py
  2. 73 0
      khandy/utils_split.py

+ 1 - 0
khandy/__init__.py

@@ -7,6 +7,7 @@ from .utils_list import *
 from .utils_numpy import *
 from .utils_draw import *
 from .utils_others import *
+from .utils_split import *
 
 from .boxes import *
 from .image import *

+ 73 - 0
khandy/utils_split.py

@@ -0,0 +1,73 @@
+import numbers
+from collections import Sequence
+
+import numpy as np
+
+
+def split_by_num(x, num_splits, strict=True):
+    """
+    Args:
+        num_splits: an integer indicating the number of splits
+
+    References:
+        numpy.split and numpy.array_split
+    """
+    # NB: np.ndarray is not Sequence
+    assert isinstance(x, (Sequence, np.ndarray))
+    assert isinstance(num_splits, numbers.Integral)
+    
+    if strict:
+        assert len(x) % num_splits == 0
+    split_size = (len(x) + num_splits - 1) // num_splits
+    out_list = []
+    for i in range(0, len(x), split_size):
+        out_list.append(x[i: i + split_size])
+    return out_list
+    
+    
+def split_by_size(x, sizes):
+    """
+    References:
+        tf.split
+        https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
+    """
+    # NB: np.ndarray is not Sequence
+    assert isinstance(x, (Sequence, np.ndarray))
+    assert isinstance(sizes, (list, tuple))
+     
+    assert sum(sizes) == len(x)
+    out_list = []
+    start_index = 0
+    for size in sizes:
+        out_list.append(x[start_index: start_index + size])
+        start_index += size
+    return out_list
+    
+    
+def split_by_slice(x, slices):
+    """
+    References:
+        SliceLayer in Caffe, and numpy.split
+    """
+    # NB: np.ndarray is not Sequence
+    assert isinstance(x, (Sequence, np.ndarray))
+    assert isinstance(slices, (list, tuple))
+    
+    out_list = []
+    indices = [0] + list(slices) + [len(x)]
+    for i in range(len(slices) + 1):
+        out_list.append(x[indices[i]: indices[i + 1]])
+    return out_list
+
+
+def split_by_ratio(x, ratios):
+    # NB: np.ndarray is not Sequence
+    assert isinstance(x, (Sequence, np.ndarray))
+    assert isinstance(ratios, (list, tuple))
+    
+    pdf = [k / sum(ratios) for k in ratios]
+    cdf = [sum(pdf[:k]) for k in range(len(pdf) + 1)]
+    indices = [int(round(len(x) * k)) for k in cdf]
+    return [x[indices[i]: indices[i + 1]] for i in range(len(ratios))]
+    
+