Browse Source

update khandy.label

quarrying 2 years ago
parent
commit
9ef4ac0915
2 changed files with 123 additions and 99 deletions
  1. 0 69
      khandy/label/__init__.py
  2. 123 30
      khandy/label/detect.py

+ 0 - 69
khandy/label/__init__.py

@@ -1,71 +1,2 @@
 from .detect import *
 
-
-def _get_format(record):
-    if isinstance(record, LabelmeRecord):
-        return ('labelme',)
-    elif isinstance(record, YoloRecord):
-        return ('yolo',)
-    elif isinstance(record, PascalVocRecord):
-        return ('voc', 'pascal', 'pascal_voc')
-    elif isinstance(record, CocoRecord):
-        return ('coco',)
-    else:
-        return ()
-        
-    
-def load(filename, fmt, **kwargs):
-    if fmt == 'labelme':
-        record = LabelmeHandler.load(filename)
-    elif fmt == 'yolo':
-        record = YoloHandler.load(filename)
-    elif fmt in ('voc', 'pascal', 'pascal_voc'):
-        record = PascalVocHandler.load(filename)
-    elif fmt == 'coco':
-        record = CocoDetectHandler.load(filename, **kwargs)
-    else:
-        raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
-    return record
-    
-    
-def save(filename, record):
-    if isinstance(record, LabelmeRecord):
-        LabelmeHandler.save(filename, record)
-    elif isinstance(record, YoloRecord):
-        YoloHandler.save(filename, record)
-    elif isinstance(record, PascalVocRecord):
-        PascalVocHandler.save(filename, record)
-    elif isinstance(record, CocoRecord):
-        raise ValueError("Unsupported for CocoRecord now!")
-    else:
-        raise ValueError("Unsupported type!")
-        
-        
-def convert(record, out_fmt):
-    allowed_fmts = ("labelme", "yolo", "voc", "coco", 'pascal', 'pascal_voc')
-    if out_fmt not in allowed_fmts:
-        raise ValueError("Unsupported label format conversions for given out_fmt")
-    if out_fmt in _get_format(record):
-        return record
-
-    if isinstance(record, LabelmeRecord):
-        ir_record = LabelmeHandler.to_ir(record)
-    elif isinstance(record, YoloRecord):
-        ir_record = YoloHandler.to_ir(record)
-    elif isinstance(record, PascalVocRecord):
-        ir_record = PascalVocHandler.to_ir(record)
-    elif isinstance(record, CocoRecord):
-        ir_record = CocoDetectHandler.to_ir(record)
-    else:
-        raise ValueError('Unsupported type for record')
-        
-    if out_fmt == 'labelme':
-        dst_record = LabelmeHandler.from_ir(ir_record)
-    elif out_fmt == 'yolo':
-        dst_record = YoloHandler.from_ir(ir_record)
-    elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
-        dst_record = PascalVocHandler.from_ir(ir_record)
-    elif out_fmt == 'coco':
-        dst_record = CocoDetectHandler.from_ir(ir_record)
-    return dst_record
-    

+ 123 - 30
khandy/label/detect.py

@@ -1,3 +1,4 @@
+import os
 import copy
 import json
 import dataclasses
@@ -12,12 +13,9 @@ import lxml.builder
 import numpy as np
 
 
-__all__ = ['DetectIrObject', 'DetectIrRecord',
-           'PascalVocSource', 'PascalVocSize', 'PascalVocBndbox', 
-           'PascalVocObject', 'PascalVocRecord', 'PascalVocHandler',
-           'LabelmeShape', 'LabelmeRecord', 'LabelmeHandler',
-           'YoloObject', 'YoloRecord', 'YoloHandler',
-           'CocoObject', 'CocoRecord', 'CocoDetectHandler']
+__all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect', 
+           'save_detect', 'convert_detect', 'replace_detect_label',
+           'load_coco_class_names']
 
 
 @dataclass
@@ -85,7 +83,7 @@ class PascalVocRecord:
     
 class PascalVocHandler:
     @staticmethod
-    def load(filename) -> PascalVocRecord:
+    def load(filename, **kwargs) -> PascalVocRecord:
         pascal_voc_record = PascalVocRecord()
         
         xml_tree = ET.parse(filename)
@@ -160,6 +158,8 @@ class PascalVocHandler:
             )
             xml.append(object_tag)
             
+        if not filename.endswith('.xml'):
+            filename = filename + '.xml'
         with open(filename, 'wb') as f:
             f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
             
@@ -251,7 +251,7 @@ class LabelmeRecord:
 
 class LabelmeHandler:
     @staticmethod
-    def load(filename) -> LabelmeRecord:
+    def load(filename, **kwargs) -> LabelmeRecord:
         json_content = khandy.load_json(filename)
         return LabelmeRecord(**json_content)
 
@@ -260,6 +260,7 @@ class LabelmeHandler:
         json_content = dataclasses.asdict(labelme_record)
         khandy.save_json(filename, json_content, cls=_NumpyEncoder)
 
+    @staticmethod
     def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
         ir_record = DetectIrRecord(
             filename=labelme_record.imagePath,
@@ -317,10 +318,12 @@ class YoloRecord:
 class YoloHandler:
     @staticmethod
     def load(filename, **kwargs) -> YoloRecord:
-        records = khandy.load_list(filename)
+        assert 'image_filename' in kwargs
+        assert 'width' in kwargs and 'height' in kwargs
 
+        records = khandy.load_list(filename)
         yolo_record = YoloRecord(
-            filename=kwargs.get('filename'),
+            filename=kwargs.get('image_filename'),
             width=kwargs.get('width'),
             height=kwargs.get('height'))
         for record in records:
@@ -335,27 +338,13 @@ class YoloHandler:
         return yolo_record
 
     @staticmethod
-    def save(filemame, yolo_record: YoloRecord):
+    def save(filename, yolo_record: YoloRecord):
         records = []
         for object in yolo_record.objects:
             records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
-        khandy.save_list(filemame, records)
-
-    @staticmethod
-    def replace_label(yolo_record: YoloRecord, label_map, ignore=False):
-        dst_yolo_record = copy.deepcopy(yolo_record)
-        dst_objects = []
-        for yolo_object in dst_yolo_record.objects:
-            if not ignore:
-                if yolo_object.label in label_map:
-                    yolo_object.label = label_map[yolo_object.label]
-                dst_objects.append(yolo_object)
-            else:
-                if yolo_object.label in label_map:
-                    yolo_object.label = label_map[yolo_object.label]
-                    dst_objects.append(yolo_object)
-        dst_yolo_record.objects = dst_objects
-        return dst_yolo_record
+        if not filename.endswith('.txt'):
+            filename = filename + '.txt'
+        khandy.save_list(filename, records)
 
     @staticmethod
     def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
@@ -419,9 +408,9 @@ class CocoRecord:
     objects: List[CocoObject] = field(default_factory=list)
     
 
-class CocoDetectHandler:
+class CocoHandler:
     @staticmethod
-    def load(filename) -> List[CocoRecord]:
+    def load(filename, **kwargs) -> List[CocoRecord]:
         json_data = khandy.load_json(filename)
         
         images = json_data['images']
@@ -487,3 +476,107 @@ class CocoDetectHandler:
         return coco_record
         
         
+def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
+    if fmt == 'labelme':
+        labelme_record = LabelmeHandler.load(filename, **kwargs)
+        ir_record = LabelmeHandler.to_ir(labelme_record)
+    elif fmt == 'yolo':
+        yolo_record = YoloHandler.load(filename, **kwargs)
+        ir_record = YoloHandler.to_ir(yolo_record)
+    elif fmt in ('voc', 'pascal', 'pascal_voc'):
+        pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
+        ir_record = PascalVocHandler.to_ir(pascal_voc_record)
+    elif fmt == 'coco':
+        coco_records = CocoHandler.load(filename, **kwargs)
+        ir_record = [CocoHandler.to_ir(coco_record) for coco_record in coco_records]
+    else:
+        raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
+    return ir_record
+    
+    
+def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
+    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
+    if out_fmt == 'labelme':
+        labelme_record = LabelmeHandler.from_ir(ir_record)
+        LabelmeHandler.save(filename, labelme_record)
+    elif out_fmt == 'yolo':
+        yolo_record = YoloHandler.from_ir(ir_record)
+        YoloHandler.save(filename, yolo_record)
+    elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
+        pascal_voc_record = PascalVocHandler.from_ir(ir_record)
+        PascalVocHandler.save(filename, pascal_voc_record)
+    elif out_fmt == 'coco':
+        raise ValueError("Unsupported for `coco` now!")
+    else:
+        raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
+
+
+def _get_format(record):
+    if isinstance(record, LabelmeRecord):
+        return ('labelme',)
+    elif isinstance(record, YoloRecord):
+        return ('yolo',)
+    elif isinstance(record, PascalVocRecord):
+        return ('voc', 'pascal', 'pascal_voc')
+    elif isinstance(record, CocoRecord):
+        return ('coco',)
+    elif isinstance(record, DetectIrRecord):
+        return ('ir', 'detect_ir')
+    else:
+        return ()
+
+
+def convert_detect(record, out_fmt):
+    allowed_fmts = ('labelme', 'yolo', 'voc', 'coco', 'pascal', 'pascal_voc', 'ir', 'detect_ir')
+    if out_fmt not in allowed_fmts:
+        raise ValueError("Unsupported label format conversions for given out_fmt")
+    if out_fmt in _get_format(record):
+        return record
+
+    if isinstance(record, LabelmeRecord):
+        ir_record = LabelmeHandler.to_ir(record)
+    elif isinstance(record, YoloRecord):
+        ir_record = YoloHandler.to_ir(record)
+    elif isinstance(record, PascalVocRecord):
+        ir_record = PascalVocHandler.to_ir(record)
+    elif isinstance(record, CocoRecord):
+        ir_record = CocoHandler.to_ir(record)
+    elif isinstance(record, DetectIrRecord):
+        ir_record = record
+    else:
+        raise ValueError('Unsupported type for record')
+        
+    if out_fmt in ('ir', 'detect_ir'):
+        dst_record = ir_record
+    elif out_fmt == 'labelme':
+        dst_record = LabelmeHandler.from_ir(ir_record)
+    elif out_fmt == 'yolo':
+        dst_record = YoloHandler.from_ir(ir_record)
+    elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
+        dst_record = PascalVocHandler.from_ir(ir_record)
+    elif out_fmt == 'coco':
+        dst_record = CocoHandler.from_ir(ir_record)
+    return dst_record
+    
+
+def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
+    dst_record = copy.deepcopy(record)
+    dst_objects = []
+    for ir_object in dst_record.objects:
+        if not ignore:
+            if ir_object.label in label_map:
+                ir_object.label = label_map[ir_object.label]
+            dst_objects.append(ir_object)
+        else:
+            if ir_object.label in label_map:
+                ir_object.label = label_map[ir_object.label]
+                dst_objects.append(ir_object)
+    dst_record.objects = dst_objects
+    return dst_record
+
+
+def load_coco_class_names(filename):
+    json_data = khandy.load_json(filename)
+    categories = json_data['categories']
+    return [cat_item['name'] for cat_item in categories]
+