Răsfoiți Sursa

refactor imwrite_bytes

quarrying 2 ani în urmă
părinte
comite
5031d70f58
2 a modificat fișierele cu 22 adăugiri și 8 ștergeri
  1. 19 5
      khandy/image/misc.py
  2. 3 3
      khandy/utils_file_io.py

+ 19 - 5
khandy/image/misc.py

@@ -67,18 +67,32 @@ def imread_pil(file_or_buffer, to_mode=None):
         return None
         
         
-def imwrite_bytes(filename, image_bytes, update_extension=True):
+def imwrite_bytes(filename, image_bytes: bytes, update_extension: bool = True):
+    """Write image bytes to file.
+    
+    Args:
+        filename: str
+            filename which image_bytes is written into.
+        image_bytes: bytes
+            image content to be written.
+        update_extension: bool
+            whether update extension according to image_bytes or not.
+            the cost of update extension is smaller than update image format.
+    """
     extension = imghdr.what('', image_bytes)
-    if extension is None:
-        raise ValueError('image_bytes is not image')
-    extension = '.' + extension
     file_extension = khandy.get_path_extension(filename)
-    if extension.lower() != file_extension.lower():
+    # imghdr.what fails to determine image format sometimes!
+    # so when its return value is None, never update extension.
+    if extension is None:
+        image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
+        image_bytes = cv2.imencode(file_extension, image)[1]
+    elif (extension.lower() != file_extension.lower()[1:]):
         if update_extension:
             filename = khandy.replace_path_extension(filename, extension)
         else:
             image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), -1)
             image_bytes = cv2.imencode(file_extension, image)[1]
+    
     with open(filename, "wb") as f:
         f.write(image_bytes)
     return filename

+ 3 - 3
khandy/utils_file_io.py

@@ -44,7 +44,7 @@ def save_json(filename, data, encoding='utf-8', indent=4, cls=None, sort_keys=Fa
                   ensure_ascii=False, cls=cls, sort_keys=sort_keys)
 
 
-def load_bytes(filename: str, use_base64: bool = False) -> bytes:
+def load_bytes(filename, use_base64: bool = False) -> bytes:
     """Open the file in bytes mode, read it, and close the file.
     
     References:
@@ -57,7 +57,7 @@ def load_bytes(filename: str, use_base64: bool = False) -> bytes:
     return data
 
 
-def save_bytes(filename: str, data: bytes, use_base64: bool = False) -> int:
+def save_bytes(filename, data: bytes, use_base64: bool = False) -> int:
     """Open the file in bytes mode, write to it, and close the file.
     
     References:
@@ -70,7 +70,7 @@ def save_bytes(filename: str, data: bytes, use_base64: bool = False) -> int:
     return ret
 
 
-def load_as_base64(filename: str) -> bytes:
+def load_as_base64(filename) -> bytes:
     warnings.warn('khandy.load_as_base64 will be deprecated, use khandy.load_bytes instead!')
     return load_bytes(filename, True)