Browse Source

refactor khandy.label.detect

quarrying 2 years ago
parent
commit
1f56aa8778
2 changed files with 298 additions and 193 deletions
  1. 296 192
      khandy/label/detect.py
  2. 2 1
      requirements.txt

+ 296 - 192
khandy/label/detect.py

@@ -2,6 +2,7 @@ import copy
 import json
 import dataclasses
 from dataclasses import dataclass, field
+from collections import OrderedDict
 from typing import Optional, List
 import xml.etree.ElementTree as ET
 
@@ -10,14 +11,36 @@ import lxml
 import lxml.builder
 import numpy as np
 
-__all__ = ['PascalVocSource', 'PascalVocSize', 'PascalVocBndbox', 
-           'PascalVocObject', 'PascalVocInfo', 'PascalVocHandler',
-           'LabelmeShape', 'LabelmeInfo', 'LabelmeHandler',
-           'YoloObject', 'YoloInfo', 'YoloHandler', 
-           'convert_pascal_voc_to_labelme', 'convert_labelme_to_pascal_voc',
-           'convert_labelme_to_yolo', 'convert_yolo_to_labelme',
-           'convert_pascal_voc_to_yolo', 'convert_yolo_to_pascal_voc']
 
+__all__ = ['DetectIrObject', 'DetectIrRecord',
+           'PascalVocSource', 'PascalVocSize', 'PascalVocBndbox', 
+           'PascalVocObject', 'PascalVocRecord', 'PascalVocHandler',
+           'LabelmeShape', 'LabelmeRecord', 'LabelmeHandler',
+           'YoloObject', 'YoloRecord', 'YoloHandler',
+           'CocoObject', 'CocoRecord', 'CocoDetectHandler']
+
+
+@dataclass
+class DetectIrObject:
+    """Intermediate Representation Format of Object
+    """
+    label: str
+    x_min: float
+    y_min: float
+    x_max: float
+    y_max: float
+    
+    
+@dataclass
+class DetectIrRecord:
+    """Intermediate Representation Format of Record
+    """
+    filename: str
+    width: int
+    height: int
+    objects: List[DetectIrObject] = field(default_factory=list)
+    
+    
 @dataclass
 class PascalVocSource:
     database: str = ''
@@ -50,98 +73,139 @@ class PascalVocObject:
     
     
 @dataclass
-class PascalVocInfo:
+class PascalVocRecord:
     folder: str = ''
     filename: str = ''
     path: str = ''
     source: PascalVocSource = PascalVocSource()
     size: Optional[PascalVocSize] = None
     segmented: int = 0
-    object: List[PascalVocObject] = field(default_factory=list)
+    objects: List[PascalVocObject] = field(default_factory=list)
     
     
 class PascalVocHandler:
     @staticmethod
-    def load(filename) -> PascalVocInfo:
-        pascal_voc_info = PascalVocInfo()
+    def load(filename) -> PascalVocRecord:
+        pascal_voc_record = PascalVocRecord()
         
         xml_tree = ET.parse(filename)
-        pascal_voc_info.folder = xml_tree.find('folder').text
-        pascal_voc_info.filename = xml_tree.find('filename').text
-        pascal_voc_info.path = xml_tree.find('path').text
-        pascal_voc_info.segmented = xml_tree.find('segmented').text
+        pascal_voc_record.folder = xml_tree.find('folder').text
+        pascal_voc_record.filename = xml_tree.find('filename').text
+        pascal_voc_record.path = xml_tree.find('path').text
+        pascal_voc_record.segmented = xml_tree.find('segmented').text
         
         source_tag = xml_tree.find('source')
-        pascal_voc_info.source = PascalVocSource(
+        pascal_voc_record.source = PascalVocSource(
             database=source_tag.find('database').text,
             # annotation=source_tag.find('annotation').text,
             # image=source_tag.find('image').text
         )
         
         size_tag = xml_tree.find('size')
-        pascal_voc_info.size = PascalVocSize(
+        pascal_voc_record.size = PascalVocSize(
             width=int(size_tag.find('width').text),
             height=int(size_tag.find('height').text),
             depth=int(size_tag.find('depth').text)
         )
         
         object_tags = xml_tree.findall('object')
-        for index, obj in enumerate(object_tags):
-            bndbox_tag = obj.find('bndbox')
+        for index, object_tag in enumerate(object_tags):
+            bndbox_tag = object_tag.find('bndbox')
             bndbox = PascalVocBndbox(
                 xmin=float(bndbox_tag.find('xmin').text) - 1,
                 ymin=float(bndbox_tag.find('ymin').text) - 1,
                 xmax=float(bndbox_tag.find('xmax').text) - 1,
                 ymax=float(bndbox_tag.find('ymax').text) - 1
             )
-            one_object = PascalVocObject(
-                name=obj.find('name').text,
-                pose=obj.find('pose').text,
-                truncated=obj.find('truncated').text,
-                difficult=obj.find('difficult').text,
+            pascal_voc_object = PascalVocObject(
+                name=object_tag.find('name').text,
+                pose=object_tag.find('pose').text,
+                truncated=object_tag.find('truncated').text,
+                difficult=object_tag.find('difficult').text,
                 bndbox=bndbox
             )
-            pascal_voc_info.object.append(one_object)
-        return pascal_voc_info
+            pascal_voc_record.objects.append(pascal_voc_object)
+        return pascal_voc_record
         
     @staticmethod
-    def save(filename, pascal_voc_info: PascalVocInfo):
+    def save(filename, pascal_voc_record: PascalVocRecord):
         maker = lxml.builder.ElementMaker()
         xml = maker.annotation(
-            maker.folder(pascal_voc_info.folder),
-            maker.filename(pascal_voc_info.filename),
-            maker.path(pascal_voc_info.path),
+            maker.folder(pascal_voc_record.folder),
+            maker.filename(pascal_voc_record.filename),
+            maker.path(pascal_voc_record.path),
             maker.source(
-                maker.database(pascal_voc_info.source.database),
+                maker.database(pascal_voc_record.source.database),
             ),
             maker.size( 
-                maker.width(str(pascal_voc_info.size.width)),
-                maker.height(str(pascal_voc_info.size.height)),
-                maker.depth(str(pascal_voc_info.size.depth)),
+                maker.width(str(pascal_voc_record.size.width)),
+                maker.height(str(pascal_voc_record.size.height)),
+                maker.depth(str(pascal_voc_record.size.depth)),
             ),
-            maker.segmented(str(pascal_voc_info.segmented)),
+            maker.segmented(str(pascal_voc_record.segmented)),
         )
         
-        for one_object in pascal_voc_info.object:
+        for pascal_voc_object in pascal_voc_record.objects:
             object_tag = maker.object(
-                maker.name(one_object.name),
-                maker.pose(one_object.pose),
-                maker.truncated(str(one_object.truncated)),
-                maker.difficult(str(one_object.difficult)),
+                maker.name(pascal_voc_object.name),
+                maker.pose(pascal_voc_object.pose),
+                maker.truncated(str(pascal_voc_object.truncated)),
+                maker.difficult(str(pascal_voc_object.difficult)),
                 maker.bndbox(
-                    maker.xmin(str(float(one_object.bndbox.xmin))),
-                    maker.ymin(str(float(one_object.bndbox.ymin))),
-                    maker.xmax(str(float(one_object.bndbox.xmax))),
-                    maker.ymax(str(float(one_object.bndbox.ymax))),
+                    maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
+                    maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
+                    maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
+                    maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
                 ),
             )
             xml.append(object_tag)
             
         with open(filename, 'wb') as f:
             f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
+            
+    @staticmethod
+    def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
+        ir_record = DetectIrRecord(
+            filename=pascal_voc_record.filename,
+            width=pascal_voc_record.size.width,
+            height=pascal_voc_record.size.height
+        )
+        for pascal_voc_object in pascal_voc_record.objects:
+            ir_object = DetectIrObject(
+                label=pascal_voc_object.name,
+                x_min=pascal_voc_object.bndbox.xmin,
+                y_min=pascal_voc_object.bndbox.ymin,
+                x_max=pascal_voc_object.bndbox.xmax,
+                y_max=pascal_voc_object.bndbox.ymax
+            )
+            ir_record.objects.append(ir_object)
+        return ir_record
         
-
-class NumpyEncoder(json.JSONEncoder):
+    @staticmethod
+    def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
+        pascal_voc_record = PascalVocRecord(
+            filename=ir_record.filename,
+            size=PascalVocSize(
+                width=ir_record.width,
+                height=ir_record.height,
+                depth=3
+            )
+        )
+        for ir_object in ir_record.objects:
+            pascal_voc_object = PascalVocObject(
+                name=ir_object.label,
+                bndbox=PascalVocBndbox(
+                    xmin=ir_object.x_min,
+                    ymin=ir_object.y_min,
+                    xmax=ir_object.x_max,
+                    ymax=ir_object.y_max,
+                )
+            )
+            pascal_voc_record.objects.append(pascal_voc_object)
+        return pascal_voc_record
+        
+        
+class _NumpyEncoder(json.JSONEncoder):
     """ Special json encoder for numpy types """
     def default(self, obj):
         if isinstance(obj, (np.bool_,)):
@@ -171,7 +235,7 @@ class LabelmeShape:
 
 
 @dataclass
-class LabelmeInfo:
+class LabelmeRecord:
     version: str = '4.5.6'
     flags: dict = field(default_factory=dict)
     shapes: List[LabelmeShape] = field(default_factory=list)
@@ -187,15 +251,51 @@ class LabelmeInfo:
 
 class LabelmeHandler:
     @staticmethod
-    def load(filename) -> LabelmeInfo:
+    def load(filename) -> LabelmeRecord:
         json_content = khandy.load_json(filename)
-        return LabelmeInfo(**json_content)
+        return LabelmeRecord(**json_content)
 
     @staticmethod
-    def save(filename, labelme_info: LabelmeInfo):
-        json_content = dataclasses.asdict(labelme_info)
-        khandy.save_json(filename, json_content, cls=NumpyEncoder)
+    def save(filename, labelme_record: LabelmeRecord):
+        json_content = dataclasses.asdict(labelme_record)
+        khandy.save_json(filename, json_content, cls=_NumpyEncoder)
 
+    def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
+        ir_record = DetectIrRecord(
+            filename=labelme_record.imagePath,
+            width=labelme_record.imageWidth,
+            height=labelme_record.imageHeight
+        )
+        for labelme_shape in labelme_record.shapes:
+            if labelme_shape.shape_type != 'rectangle':
+                continue
+            ir_object = DetectIrObject(
+                label=labelme_shape.label,
+                x_min=labelme_shape.points[0][0],
+                y_min=labelme_shape.points[0][1],
+                x_max=labelme_shape.points[1][0],
+                y_max=labelme_shape.points[1][1],
+            )
+            ir_record.objects.append(ir_object)
+        return ir_record
+        
+    @staticmethod
+    def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
+        labelme_record = LabelmeRecord(
+            imagePath=ir_record.filename,
+            imageWidth=ir_record.width,
+            imageHeight=ir_record.height
+        )
+        for ir_object in ir_record.objects:
+            labelme_shape = LabelmeShape(
+                label=ir_object.label,
+                shape_type='rectangle',
+                points=[[ir_object.x_min, ir_object.y_min], 
+                        [ir_object.x_max, ir_object.y_max]]
+            )
+            labelme_record.shapes.append(labelme_shape)
+        return labelme_record
+        
 
 @dataclass
 class YoloObject:
@@ -207,8 +307,8 @@ class YoloObject:
     
     
 @dataclass
-class YoloInfo:
-    image_filename: Optional[str] = None
+class YoloRecord:
+    filename: Optional[str] = None
     width: Optional[int] = None
     height: Optional[int] = None
     objects: List[YoloObject] = field(default_factory=list)
@@ -216,170 +316,174 @@ class YoloInfo:
     
 class YoloHandler:
     @staticmethod
-    def load(filename, **kwargs) -> YoloInfo:
+    def load(filename, **kwargs) -> YoloRecord:
         records = khandy.load_list(filename)
 
-        yolo_info = YoloInfo(
-            image_filename=kwargs.get('image_filename'),
+        yolo_record = YoloRecord(
+            filename=kwargs.get('filename'),
             width=kwargs.get('width'),
             height=kwargs.get('height'))
         for record in records:
             record_parts = record.split()
-            
-            yolo_info.objects.append(YoloObject(
+            yolo_record.objects.append(YoloObject(
                 label=record_parts[0],
                 x_center=float(record_parts[1]),
                 y_center=float(record_parts[2]),
                 width=float(record_parts[3]),
                 height=float(record_parts[4]),
             ))
-        return yolo_info
+        return yolo_record
 
     @staticmethod
-    def save(filemame, yolo_info: YoloInfo):
+    def save(filemame, yolo_record: YoloRecord):
         records = []
-        for object in yolo_info.objects:
+        for object in yolo_record.objects:
             records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
         khandy.save_list(filemame, records)
 
     @staticmethod
-    def replace_label(yolo_info: YoloInfo, label_map):
-        dst_yolo_info = copy.deepcopy(yolo_info)
-        for object in dst_yolo_info.objects:
-            object.label= label_map[object.label]
-        return dst_yolo_info
-
+    def replace_label(yolo_record: YoloRecord, label_map, ignore=False):
+        dst_yolo_record = copy.deepcopy(yolo_record)
+        dst_objects = []
+        for yolo_object in dst_yolo_record.objects:
+            if not ignore:
+                if yolo_object.label in label_map:
+                    yolo_object.label = label_map[yolo_object.label]
+                dst_objects.append(yolo_object)
+            else:
+                if yolo_object.label in label_map:
+                    yolo_object.label = label_map[yolo_object.label]
+                    dst_objects.append(yolo_object)
+        dst_yolo_record.objects = dst_objects
+        return dst_yolo_record
 
-def convert_pascal_voc_to_labelme(pascal_voc_info: PascalVocInfo) -> LabelmeInfo:
-    labelme_info = LabelmeInfo(
-        imagePath=pascal_voc_info.filename,
-        imageWidth=pascal_voc_info.size.width,
-        imageHeight=pascal_voc_info.size.height
-    )
-    for object in pascal_voc_info.object:
-        labelme_shape = LabelmeShape(
-            label=object.name,
-            shape_type='rectangle',
-            points=[[object.bndbox.xmin, object.bndbox.ymin], 
-                    [object.bndbox.xmax, object.bndbox.ymax]]
-        )
-        labelme_info.shapes.append(labelme_shape)
-    return labelme_info
-
-
-def convert_labelme_to_pascal_voc(labelme_info: LabelmeInfo) -> PascalVocInfo:
-    pascal_voc_info = PascalVocInfo(
-        filename=labelme_info.imagePath,
-        size=PascalVocSize(
-            width=labelme_info.imageWidth,
-            height=labelme_info.imageHeight,
-            depth=3
+    @staticmethod
+    def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
+        ir_record = DetectIrRecord(
+            filename=yolo_record.filename,
+            width=yolo_record.width,
+            height=yolo_record.height
         )
-    )
-    for shape in labelme_info.shapes:
-        if shape.shape_type != 'rectangle':
-            continue
-        pascal_voc_object = PascalVocObject(
-            name=shape.label,
-            bndbox=PascalVocBndbox(
-                xmin=shape.points[0][0],
-                ymin=shape.points[0][1],
-                xmax=shape.points[1][0],
-                ymax=shape.points[1][1],
+        for yolo_object in yolo_record.objects:
+            x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
+            y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
+            x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
+            y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
+            ir_object = DetectIrObject(
+                label=yolo_object.label,
+                x_min=x_min,
+                y_min=y_min,
+                x_max=x_max,
+                y_max=y_max
             )
+            ir_record.objects.append(ir_object)
+        return ir_record
+        
+    @staticmethod
+    def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
+        yolo_record = YoloRecord(
+            filename=ir_record.filename,
+            width=ir_record.width,
+            height=ir_record.height
         )
-        pascal_voc_info.object.append(pascal_voc_object)
-    return pascal_voc_info
-
-
-def convert_labelme_to_yolo(labelme_info: LabelmeInfo) -> YoloInfo:
-    yolo_info = YoloInfo(
-        image_filename=labelme_info.imagePath,
-        width=labelme_info.imageWidth,
-        height=labelme_info.imageHeight
-    )
-    for shape in labelme_info.shapes:
-        if shape.shape_type != 'rectangle':
-            continue
-        x_center = (shape.points[0][0] + shape.points[1][0]) / (2 * labelme_info.imageWidth)
-        y_center = (shape.points[0][1] + shape.points[1][1]) / (2 * labelme_info.imageHeight)
-        width = abs(shape.points[0][0] - shape.points[1][0]) / labelme_info.imageWidth
-        height = abs(shape.points[0][1] - shape.points[1][1]) / labelme_info.imageHeight
-        yolo_object = YoloObject(
-            label=shape.label,
-            x_center=x_center,
-            y_center=y_center,
-            width=width,
-            height=height,
-        )
-        yolo_info.objects.append(yolo_object)
-    return yolo_info
-    
-    
-def convert_yolo_to_labelme(yolo_info: YoloInfo) -> LabelmeInfo:
-    assert (yolo_info.width is not None) and (yolo_info.height is not None)
-
-    labelme_info = LabelmeInfo(
-        imagePath=yolo_info.image_filename,
-        imageHeight=yolo_info.height,
-        imageWidth=yolo_info.width,
-    )
-    for object in yolo_info.objects:
-        x_min = (object.x_center - 0.5 * object.width) * yolo_info.width
-        y_min = (object.y_center - 0.5 * object.height) * yolo_info.height
-        x_max = (object.x_center + 0.5 * object.width) * yolo_info.width
-        y_max = (object.y_center + 0.5 * object.height) * yolo_info.height
-        labelme_shape = LabelmeShape(
-            label=object.label,
-            shape_type='rectangle',
-            points=[[x_min, y_min], [x_max, y_max]]
-        )
-        labelme_info.shapes.append(labelme_shape)
-    return labelme_info
-
+        for ir_object in ir_record.objects:
+            x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
+            y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
+            width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
+            height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
+            yolo_object = YoloObject(
+                label=ir_object.label,
+                x_center=x_center,
+                y_center=y_center,
+                width=width,
+                height=height,
+            )
+            yolo_record.objects.append(yolo_object)
+        return yolo_record
     
-def convert_pascal_voc_to_yolo(pascal_voc_info: PascalVocInfo) -> YoloInfo:
-    yolo_info = YoloInfo(
-        image_filename=pascal_voc_info.filename,
-        width=pascal_voc_info.size.width,
-        height=pascal_voc_info.size.height
-    )
-    for object in pascal_voc_info.object:
-        x_center = (object.bndbox.xmax + object.bndbox.xmin) / (2 * pascal_voc_info.size.width)
-        y_center = (object.bndbox.ymax + object.bndbox.ymin) / (2 * pascal_voc_info.size.height)
-        width = abs(object.bndbox.xmax - object.bndbox.xmin) / pascal_voc_info.size.width
-        height = abs(object.bndbox.ymax - object.bndbox.ymin) / pascal_voc_info.size.height
-        yolo_object = YoloObject(
-            label=object.name,
-            x_center=x_center,
-            y_center=y_center,
-            width=width,
-            height=height,
-        )
-        yolo_info.objects.append(yolo_object)
-    return yolo_info
+        
+@dataclass
+class CocoObject:
+    label: str
+    x_min: float
+    y_min: float
+    width: float
+    height: float
     
     
-def convert_yolo_to_pascal_voc(yolo_info: YoloInfo) -> PascalVocInfo:
-    pascal_voc_info = PascalVocInfo(
-        filename=yolo_info.image_filename,
-        size=PascalVocSize(
-            width=yolo_info.width,
-            height=yolo_info.height,
-            depth=3
-        )
-    )
-    for object in yolo_info.objects:
-        x_min = (object.x_center - 0.5 * object.width) * yolo_info.width
-        y_min = (object.y_center - 0.5 * object.height) * yolo_info.height
-        x_max = (object.x_center + 0.5 * object.width) * yolo_info.width
-        y_max = (object.y_center + 0.5 * object.height) * yolo_info.height
-        voc_object = PascalVocObject(
-            name=object.label,
-            bndbox=PascalVocBndbox(xmin=x_min,ymin=y_min,xmax=x_max,ymax=y_max)
-        )
-        pascal_voc_info.object.append(voc_object)
-    return pascal_voc_info
+@dataclass
+class CocoRecord:
+    filename: str
+    width: int
+    height: int
+    objects: List[CocoObject] = field(default_factory=list)
     
 
+class CocoDetectHandler:
+    @staticmethod
+    def load(filename) -> List[CocoRecord]:
+        json_data = khandy.load_json(filename)
+        
+        images = json_data['images']
+        annotations = json_data['annotations']
+        categories = json_data['categories']
+        
+        label_map = {}
+        for cat_item in categories:
+            label_map[cat_item['id']] = cat_item['name']
+        
+        coco_records = OrderedDict()
+        for image_item in images:
+            coco_records[image_item['id']] = CocoRecord(
+                filename=image_item['file_name'],
+                width=image_item['width'],
+                height=image_item['height'],
+                objects=[])
+                
+        for annotation_item in annotations:
+            coco_object = CocoObject(
+                label=label_map[annotation_item['category_id']],
+                x_min=annotation_item['bbox'][0],
+                y_min=annotation_item['bbox'][1],
+                width=annotation_item['bbox'][2],
+                height=annotation_item['bbox'][3])
+            coco_records[annotation_item['image_id']].objects.append(coco_object)
+        return list(coco_records.values())
+        
+    @staticmethod
+    def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
+        ir_record = DetectIrRecord(
+            filename=coco_record.filename,
+            width=coco_record.width,
+            height=coco_record.height,
+        )
+        for coco_object in coco_record.objects:
+            ir_object = DetectIrObject(
+                label=coco_object.label,
+                x_min=coco_object.x_min,
+                y_min=coco_object.y_min,
+                x_max=coco_object.x_min + coco_object.width,
+                y_max=coco_object.y_min + coco_object.height
+            )
+            ir_record.objects.append(ir_object)
+        return ir_record
 
+    @staticmethod
+    def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
+        coco_record = CocoRecord(
+            filename=ir_record.filename,
+            width=ir_record.width,
+            height=ir_record.height
+        )
+        for ir_object in ir_record.objects:
+            coco_object = CocoObject(
+                label=ir_object.label,
+                x_min=ir_object.x_min,
+                y_min=ir_object.y_min,
+                width=ir_object.x_max - ir_object.x_min,
+                height=ir_object.y_max - ir_object.y_min
+            )
+            coco_record.objects.append(coco_object)
+        return coco_record
+        
+        

+ 2 - 1
requirements.txt

@@ -1,3 +1,4 @@
 numpy>=1.11.1
 opencv-python
-pillow
+pillow
+lxml