|
@@ -1,3 +1,4 @@
|
|
|
+import os
|
|
|
import copy
|
|
|
import json
|
|
|
import dataclasses
|
|
@@ -12,12 +13,9 @@ import lxml.builder
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
-__all__ = ['DetectIrObject', 'DetectIrRecord',
|
|
|
- 'PascalVocSource', 'PascalVocSize', 'PascalVocBndbox',
|
|
|
- 'PascalVocObject', 'PascalVocRecord', 'PascalVocHandler',
|
|
|
- 'LabelmeShape', 'LabelmeRecord', 'LabelmeHandler',
|
|
|
- 'YoloObject', 'YoloRecord', 'YoloHandler',
|
|
|
- 'CocoObject', 'CocoRecord', 'CocoDetectHandler']
|
|
|
+__all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
|
|
|
+ 'save_detect', 'convert_detect', 'replace_detect_label',
|
|
|
+ 'load_coco_class_names']
|
|
|
|
|
|
|
|
|
@dataclass
|
|
@@ -85,7 +83,7 @@ class PascalVocRecord:
|
|
|
|
|
|
class PascalVocHandler:
|
|
|
@staticmethod
|
|
|
- def load(filename) -> PascalVocRecord:
|
|
|
+ def load(filename, **kwargs) -> PascalVocRecord:
|
|
|
pascal_voc_record = PascalVocRecord()
|
|
|
|
|
|
xml_tree = ET.parse(filename)
|
|
@@ -160,6 +158,8 @@ class PascalVocHandler:
|
|
|
)
|
|
|
xml.append(object_tag)
|
|
|
|
|
|
+ if not filename.endswith('.xml'):
|
|
|
+ filename = filename + '.xml'
|
|
|
with open(filename, 'wb') as f:
|
|
|
f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
|
|
|
|
|
@@ -251,7 +251,7 @@ class LabelmeRecord:
|
|
|
|
|
|
class LabelmeHandler:
|
|
|
@staticmethod
|
|
|
- def load(filename) -> LabelmeRecord:
|
|
|
+ def load(filename, **kwargs) -> LabelmeRecord:
|
|
|
json_content = khandy.load_json(filename)
|
|
|
return LabelmeRecord(**json_content)
|
|
|
|
|
@@ -260,6 +260,7 @@ class LabelmeHandler:
|
|
|
json_content = dataclasses.asdict(labelme_record)
|
|
|
khandy.save_json(filename, json_content, cls=_NumpyEncoder)
|
|
|
|
|
|
+ @staticmethod
|
|
|
def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
|
|
|
ir_record = DetectIrRecord(
|
|
|
filename=labelme_record.imagePath,
|
|
@@ -317,10 +318,12 @@ class YoloRecord:
|
|
|
class YoloHandler:
|
|
|
@staticmethod
|
|
|
def load(filename, **kwargs) -> YoloRecord:
|
|
|
- records = khandy.load_list(filename)
|
|
|
+ assert 'image_filename' in kwargs
|
|
|
+ assert 'width' in kwargs and 'height' in kwargs
|
|
|
|
|
|
+ records = khandy.load_list(filename)
|
|
|
yolo_record = YoloRecord(
|
|
|
- filename=kwargs.get('filename'),
|
|
|
+ filename=kwargs.get('image_filename'),
|
|
|
width=kwargs.get('width'),
|
|
|
height=kwargs.get('height'))
|
|
|
for record in records:
|
|
@@ -335,27 +338,13 @@ class YoloHandler:
|
|
|
return yolo_record
|
|
|
|
|
|
@staticmethod
|
|
|
- def save(filemame, yolo_record: YoloRecord):
|
|
|
+ def save(filename, yolo_record: YoloRecord):
|
|
|
records = []
|
|
|
for object in yolo_record.objects:
|
|
|
records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
|
|
|
- khandy.save_list(filemame, records)
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def replace_label(yolo_record: YoloRecord, label_map, ignore=False):
|
|
|
- dst_yolo_record = copy.deepcopy(yolo_record)
|
|
|
- dst_objects = []
|
|
|
- for yolo_object in dst_yolo_record.objects:
|
|
|
- if not ignore:
|
|
|
- if yolo_object.label in label_map:
|
|
|
- yolo_object.label = label_map[yolo_object.label]
|
|
|
- dst_objects.append(yolo_object)
|
|
|
- else:
|
|
|
- if yolo_object.label in label_map:
|
|
|
- yolo_object.label = label_map[yolo_object.label]
|
|
|
- dst_objects.append(yolo_object)
|
|
|
- dst_yolo_record.objects = dst_objects
|
|
|
- return dst_yolo_record
|
|
|
+ if not filename.endswith('.txt'):
|
|
|
+ filename = filename + '.txt'
|
|
|
+ khandy.save_list(filename, records)
|
|
|
|
|
|
@staticmethod
|
|
|
def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
|
|
@@ -419,9 +408,9 @@ class CocoRecord:
|
|
|
objects: List[CocoObject] = field(default_factory=list)
|
|
|
|
|
|
|
|
|
-class CocoDetectHandler:
|
|
|
+class CocoHandler:
|
|
|
@staticmethod
|
|
|
- def load(filename) -> List[CocoRecord]:
|
|
|
+ def load(filename, **kwargs) -> List[CocoRecord]:
|
|
|
json_data = khandy.load_json(filename)
|
|
|
|
|
|
images = json_data['images']
|
|
@@ -487,3 +476,107 @@ class CocoDetectHandler:
|
|
|
return coco_record
|
|
|
|
|
|
|
|
|
+def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
|
|
|
+ if fmt == 'labelme':
|
|
|
+ labelme_record = LabelmeHandler.load(filename, **kwargs)
|
|
|
+ ir_record = LabelmeHandler.to_ir(labelme_record)
|
|
|
+ elif fmt == 'yolo':
|
|
|
+ yolo_record = YoloHandler.load(filename, **kwargs)
|
|
|
+ ir_record = YoloHandler.to_ir(yolo_record)
|
|
|
+ elif fmt in ('voc', 'pascal', 'pascal_voc'):
|
|
|
+ pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
|
|
|
+ ir_record = PascalVocHandler.to_ir(pascal_voc_record)
|
|
|
+ elif fmt == 'coco':
|
|
|
+ coco_records = CocoHandler.load(filename, **kwargs)
|
|
|
+ ir_record = [CocoHandler.to_ir(coco_record) for coco_record in coco_records]
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
|
|
|
+ return ir_record
|
|
|
+
|
|
|
+
|
|
|
+def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
|
|
|
+ os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
|
|
|
+ if out_fmt == 'labelme':
|
|
|
+ labelme_record = LabelmeHandler.from_ir(ir_record)
|
|
|
+ LabelmeHandler.save(filename, labelme_record)
|
|
|
+ elif out_fmt == 'yolo':
|
|
|
+ yolo_record = YoloHandler.from_ir(ir_record)
|
|
|
+ YoloHandler.save(filename, yolo_record)
|
|
|
+ elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
|
|
|
+ pascal_voc_record = PascalVocHandler.from_ir(ir_record)
|
|
|
+ PascalVocHandler.save(filename, pascal_voc_record)
|
|
|
+ elif out_fmt == 'coco':
|
|
|
+ raise ValueError("Unsupported for `coco` now!")
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
|
|
|
+
|
|
|
+
|
|
|
+def _get_format(record):
|
|
|
+ if isinstance(record, LabelmeRecord):
|
|
|
+ return ('labelme',)
|
|
|
+ elif isinstance(record, YoloRecord):
|
|
|
+ return ('yolo',)
|
|
|
+ elif isinstance(record, PascalVocRecord):
|
|
|
+ return ('voc', 'pascal', 'pascal_voc')
|
|
|
+ elif isinstance(record, CocoRecord):
|
|
|
+ return ('coco',)
|
|
|
+ elif isinstance(record, DetectIrRecord):
|
|
|
+ return ('ir', 'detect_ir')
|
|
|
+ else:
|
|
|
+ return ()
|
|
|
+
|
|
|
+
|
|
|
+def convert_detect(record, out_fmt):
|
|
|
+ allowed_fmts = ('labelme', 'yolo', 'voc', 'coco', 'pascal', 'pascal_voc', 'ir', 'detect_ir')
|
|
|
+ if out_fmt not in allowed_fmts:
|
|
|
+ raise ValueError("Unsupported label format conversions for given out_fmt")
|
|
|
+ if out_fmt in _get_format(record):
|
|
|
+ return record
|
|
|
+
|
|
|
+ if isinstance(record, LabelmeRecord):
|
|
|
+ ir_record = LabelmeHandler.to_ir(record)
|
|
|
+ elif isinstance(record, YoloRecord):
|
|
|
+ ir_record = YoloHandler.to_ir(record)
|
|
|
+ elif isinstance(record, PascalVocRecord):
|
|
|
+ ir_record = PascalVocHandler.to_ir(record)
|
|
|
+ elif isinstance(record, CocoRecord):
|
|
|
+ ir_record = CocoHandler.to_ir(record)
|
|
|
+ elif isinstance(record, DetectIrRecord):
|
|
|
+ ir_record = record
|
|
|
+ else:
|
|
|
+ raise ValueError('Unsupported type for record')
|
|
|
+
|
|
|
+ if out_fmt in ('ir', 'detect_ir'):
|
|
|
+ dst_record = ir_record
|
|
|
+ elif out_fmt == 'labelme':
|
|
|
+ dst_record = LabelmeHandler.from_ir(ir_record)
|
|
|
+ elif out_fmt == 'yolo':
|
|
|
+ dst_record = YoloHandler.from_ir(ir_record)
|
|
|
+ elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
|
|
|
+ dst_record = PascalVocHandler.from_ir(ir_record)
|
|
|
+ elif out_fmt == 'coco':
|
|
|
+ dst_record = CocoHandler.from_ir(ir_record)
|
|
|
+ return dst_record
|
|
|
+
|
|
|
+
|
|
|
+def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
|
|
|
+ dst_record = copy.deepcopy(record)
|
|
|
+ dst_objects = []
|
|
|
+ for ir_object in dst_record.objects:
|
|
|
+ if not ignore:
|
|
|
+ if ir_object.label in label_map:
|
|
|
+ ir_object.label = label_map[ir_object.label]
|
|
|
+ dst_objects.append(ir_object)
|
|
|
+ else:
|
|
|
+ if ir_object.label in label_map:
|
|
|
+ ir_object.label = label_map[ir_object.label]
|
|
|
+ dst_objects.append(ir_object)
|
|
|
+ dst_record.objects = dst_objects
|
|
|
+ return dst_record
|
|
|
+
|
|
|
+
|
|
|
+def load_coco_class_names(filename):
|
|
|
+ json_data = khandy.load_json(filename)
|
|
|
+ categories = json_data['categories']
|
|
|
+ return [cat_item['name'] for cat_item in categories]
|
|
|
+
|