Browse Source

add is_path_in_extensions and refactor get_all_filenames

quarrying 3 years ago
parent
commit
60ebbd4c1f
1 changed files with 17 additions and 13 deletions
  1. 17 13
      khandy/utils_fs.py

+ 17 - 13
khandy/utils_fs.py

@@ -50,6 +50,22 @@ def replace_path_extension(path, new_extension=None):
         return '.'.join([filename_wo_ext, new_extension])
 
 
+def normalize_extension(extension):
+    if extension.startswith('.'):
+        new_extension = extension.lower()
+    else:
+        new_extension =  '.' + extension.lower()
+    return new_extension
+
+
+def is_path_in_extensions(path, extensions):
+    if isinstance(extensions, str):
+        extensions = [extensions]
+    extensions = [normalize_extension(item) for item in extensions]
+    extension = get_path_extension(path)
+    return extension.lower() in extensions
+
+
 def makedirs(name, mode=0o755):
     """
     References:
@@ -79,26 +95,14 @@ def listdirs(paths, path_sep=None, full_path=True):
     return all_filenames
 
 
-def normalize_extension(extension):
-    if extension.startswith('.'):
-        new_extension = extension.lower()
-    else:
-        new_extension =  '.' + extension.lower()
-    return new_extension
-
-
 def get_all_filenames(path, extensions=None, is_valid_file=None):
     if (extensions is not None) and (is_valid_file is not None):
         raise ValueError("Both extensions and is_valid_file cannot "
                          "be not None at the same time")
     if is_valid_file is None:
         if extensions is not None:
-            if isinstance(extensions, str):
-                extensions = [extensions]
-            if extensions is not None:
-                extensions = tuple([normalize_extension(item) for item in extensions])
             def is_valid_file(filename):
-                return filename.lower().endswith(extensions)
+                return is_path_in_extensions(filename, extensions)
         else:
             def is_valid_file(filename):
                 return True