123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582 |
- import os
- import copy
- import json
- import dataclasses
- from dataclasses import dataclass, field
- from collections import OrderedDict
- from typing import Optional, List
- import xml.etree.ElementTree as ET
- import khandy
- import lxml
- import lxml.builder
- import numpy as np
- __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
- 'save_detect', 'convert_detect', 'replace_detect_label',
- 'load_coco_class_names']
@dataclass
class DetectIrObject:
    """A single labelled box in the intermediate representation (IR).

    The box is described by two opposite corners: ``(x_min, y_min)`` and
    ``(x_max, y_max)``.
    """

    # Class name attached to the box.
    label: str
    # First corner of the box.
    x_min: float
    y_min: float
    # Opposite corner of the box.
    x_max: float
    y_max: float
-
-
@dataclass
class DetectIrRecord:
    """All boxes of one image in the intermediate representation (IR).

    Every format handler in this module converts to and from this type.
    """

    # Name of the image the boxes belong to.
    filename: str
    # Image size, as reported by the source annotation.
    width: int
    height: int
    # Boxes found in the image; each record gets its own fresh list.
    objects: List[DetectIrObject] = field(default_factory=list)
-
-
@dataclass
class PascalVocSource:
    """Contents of the ``<source>`` element of a Pascal VOC annotation.

    Every field defaults to the empty string.
    """

    database: str = ''
    annotation: str = ''
    image: str = ''
-
-
@dataclass
class PascalVocSize:
    """Contents of the ``<size>`` element of a Pascal VOC annotation.

    Note the positional order: height first, then width, then depth.
    """

    height: int
    width: int
    depth: int
-
-
@dataclass
class PascalVocBndbox:
    """Contents of the ``<bndbox>`` element of a Pascal VOC object.

    Holds the two corners ``(xmin, ymin)`` and ``(xmax, ymax)`` of the box.
    """

    xmin: float
    ymin: float
    xmax: float
    ymax: float
-
-
@dataclass
class PascalVocObject:
    """Contents of one ``<object>`` element of a Pascal VOC annotation.

    Only ``name`` is mandatory; the remaining fields carry the defaults
    shown below.
    """

    # Class name of the object.
    name: str
    pose: str = 'Unspecified'
    truncated: int = 0
    difficult: int = 0
    # Bounding box of the object, or None when not set.
    bndbox: Optional[PascalVocBndbox] = None
-
-
@dataclass
class PascalVocRecord:
    """In-memory form of one Pascal VOC XML annotation file.

    All fields are optional at construction time so that a record can be
    created empty and filled in step by step.
    """

    folder: str = ''
    filename: str = ''
    path: str = ''
    # A default_factory is required here: a plain ``PascalVocSource()``
    # default would be one instance shared by every record, and Python 3.11+
    # rejects such an unhashable default with ValueError at class creation.
    source: PascalVocSource = field(default_factory=PascalVocSource)
    # Image size; None until it has been set.
    size: Optional[PascalVocSize] = None
    segmented: int = 0
    # Annotated objects; each record gets its own fresh list.
    objects: List[PascalVocObject] = field(default_factory=list)
-
-
class PascalVocHandler:
    """Load, save and convert Pascal VOC style XML annotations."""

    @staticmethod
    def _child_text(parent, tag, default=''):
        """Return the text of ``parent``'s child element ``tag``.

        ``default`` is returned when ``parent`` is None, when the child is
        missing, or when the child has no text.
        """
        if parent is None:
            return default
        text = parent.findtext(tag)
        return text if text else default

    @staticmethod
    def _child_int(parent, tag, default=0):
        """Return the text of child ``tag`` converted to ``int``.

        ``default`` is returned when the child is missing or empty. Text that
        cannot be parsed as an integer is returned unchanged (as a string)
        rather than raising, so unusual files still load.
        """
        text = PascalVocHandler._child_text(parent, tag, None)
        if text is None:
            return default
        try:
            return int(text)
        except ValueError:
            return text

    @staticmethod
    def load(filename, **kwargs) -> PascalVocRecord:
        """Parse the XML annotation ``filename`` into a PascalVocRecord.

        The ``folder``, ``filename``, ``path``, ``segmented`` and
        ``source/database`` tags, and the per-object ``pose``, ``truncated``
        and ``difficult`` tags, are optional and fall back to the dataclass
        defaults when absent. The ``size`` element and every object's
        ``name`` and ``bndbox`` are required; a missing one raises
        AttributeError.

        Args:
            filename: Path (or file object) accepted by ``ET.parse``.
            **kwargs: Accepted for signature parity with the other handlers;
                not used.

        Returns:
            The parsed PascalVocRecord.
        """
        root = ET.parse(filename).getroot()
        child_text = PascalVocHandler._child_text
        child_int = PascalVocHandler._child_int

        size_tag = root.find('size')
        pascal_voc_record = PascalVocRecord(
            folder=child_text(root, 'folder'),
            filename=child_text(root, 'filename'),
            path=child_text(root, 'path'),
            # Only <database> is read; <annotation> and <image> are ignored.
            source=PascalVocSource(
                database=child_text(root.find('source'), 'database'),
            ),
            size=PascalVocSize(
                width=int(size_tag.find('width').text),
                height=int(size_tag.find('height').text),
                depth=int(size_tag.find('depth').text),
            ),
            segmented=child_int(root, 'segmented'),
        )

        for object_tag in root.findall('object'):
            bndbox_tag = object_tag.find('bndbox')
            # Every coordinate is shifted by -1 on load (presumably to turn
            # VOC's 1-based pixel coordinates into 0-based ones).
            # NOTE(review): save() writes coordinates unchanged, so a
            # load -> save round trip moves each box by -1; confirm intent.
            bndbox = PascalVocBndbox(
                xmin=float(bndbox_tag.find('xmin').text) - 1,
                ymin=float(bndbox_tag.find('ymin').text) - 1,
                xmax=float(bndbox_tag.find('xmax').text) - 1,
                ymax=float(bndbox_tag.find('ymax').text) - 1,
            )
            pascal_voc_record.objects.append(PascalVocObject(
                name=object_tag.find('name').text,
                pose=child_text(object_tag, 'pose', 'Unspecified'),
                truncated=child_int(object_tag, 'truncated'),
                difficult=child_int(object_tag, 'difficult'),
                bndbox=bndbox,
            ))
        return pascal_voc_record

    @staticmethod
    def save(filename, pascal_voc_record: PascalVocRecord):
        """Write ``pascal_voc_record`` to ``filename`` as pretty-printed XML.

        ``.xml`` is appended to ``filename`` when it does not already end
        with it. ``pascal_voc_record.size`` and every object's ``bndbox``
        must be set.
        """
        maker = lxml.builder.ElementMaker()
        xml = maker.annotation(
            maker.folder(pascal_voc_record.folder),
            maker.filename(pascal_voc_record.filename),
            maker.path(pascal_voc_record.path),
            maker.source(
                maker.database(pascal_voc_record.source.database),
            ),
            maker.size(
                maker.width(str(pascal_voc_record.size.width)),
                maker.height(str(pascal_voc_record.size.height)),
                maker.depth(str(pascal_voc_record.size.depth)),
            ),
            maker.segmented(str(pascal_voc_record.segmented)),
        )

        for pascal_voc_object in pascal_voc_record.objects:
            bndbox = pascal_voc_object.bndbox
            xml.append(maker.object(
                maker.name(pascal_voc_object.name),
                maker.pose(pascal_voc_object.pose),
                maker.truncated(str(pascal_voc_object.truncated)),
                maker.difficult(str(pascal_voc_object.difficult)),
                maker.bndbox(
                    maker.xmin(str(float(bndbox.xmin))),
                    maker.ymin(str(float(bndbox.ymin))),
                    maker.xmax(str(float(bndbox.xmax))),
                    maker.ymax(str(float(bndbox.ymax))),
                ),
            ))

        if not filename.endswith('.xml'):
            filename = filename + '.xml'
        with open(filename, 'wb') as f:
            f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))

    @staticmethod
    def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
        """Convert a PascalVocRecord into the intermediate representation."""
        ir_record = DetectIrRecord(
            filename=pascal_voc_record.filename,
            width=pascal_voc_record.size.width,
            height=pascal_voc_record.size.height,
        )
        for pascal_voc_object in pascal_voc_record.objects:
            bndbox = pascal_voc_object.bndbox
            ir_record.objects.append(DetectIrObject(
                label=pascal_voc_object.name,
                x_min=bndbox.xmin,
                y_min=bndbox.ymin,
                x_max=bndbox.xmax,
                y_max=bndbox.ymax,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
        """Build a PascalVocRecord from the intermediate representation.

        The IR carries no channel count, so ``size.depth`` is fixed to 3.
        Fields the IR does not provide keep their dataclass defaults.
        """
        pascal_voc_record = PascalVocRecord(
            filename=ir_record.filename,
            size=PascalVocSize(
                width=ir_record.width,
                height=ir_record.height,
                depth=3,
            ),
        )
        for ir_object in ir_record.objects:
            pascal_voc_record.objects.append(PascalVocObject(
                name=ir_object.label,
                bndbox=PascalVocBndbox(
                    xmin=ir_object.x_min,
                    ymin=ir_object.y_min,
                    xmax=ir_object.x_max,
                    ymax=ir_object.y_max,
                ),
            ))
        return pascal_voc_record
-
-
class _NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy scalars and arrays.

    NumPy booleans, integers and floating-point scalars become the matching
    Python ``bool``/``int``/``float``; arrays become (nested) lists. Anything
    else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.bool_):
            return bool(obj)
        # The abstract scalar bases cover every concrete width and do not
        # depend on aliases such as ``np.float_`` that NumPy 2.0 removed.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
@dataclass
class LabelmeShape:
    """One annotated shape of a labelme file."""

    label: str
    # Stored as an ndarray; any array-like (e.g. nested lists decoded from
    # JSON) is accepted and converted in __post_init__.
    points: np.ndarray
    shape_type: str
    flags: dict = field(default_factory=dict)
    group_id: Optional[int] = None

    def __post_init__(self):
        self.points = np.asarray(self.points)


@dataclass
class LabelmeRecord:
    """In-memory form of one labelme JSON file.

    ``shapes`` may be given either as ``LabelmeShape`` instances or as plain
    mappings (as decoded from JSON); mappings are converted to
    ``LabelmeShape`` on construction.
    """

    version: str = '4.5.6'
    flags: dict = field(default_factory=dict)
    shapes: List[LabelmeShape] = field(default_factory=list)
    imagePath: Optional[str] = None
    imageData: Optional[str] = None
    imageHeight: Optional[int] = None
    imageWidth: Optional[int] = None

    def __post_init__(self):
        # Build a new list instead of rewriting the caller's list in place,
        # and pass through shapes that are already LabelmeShape instances
        # (e.g. when a record is rebuilt via dataclasses.replace).
        # NOTE(review): a mapping with keys that are not LabelmeShape fields
        # still raises TypeError here.
        self.shapes = [
            shape if isinstance(shape, LabelmeShape) else LabelmeShape(**shape)
            for shape in self.shapes
        ]
class LabelmeHandler:
    """Load, save and convert labelme JSON annotations."""

    @staticmethod
    def load(filename, **kwargs) -> LabelmeRecord:
        """Read the labelme JSON file ``filename`` into a LabelmeRecord.

        The decoded top-level JSON object is passed to ``LabelmeRecord`` as
        keyword arguments. ``**kwargs`` is accepted for signature parity with
        the other handlers and is not used.
        """
        json_content = khandy.load_json(filename)
        return LabelmeRecord(**json_content)

    @staticmethod
    def save(filename, labelme_record: LabelmeRecord):
        """Write ``labelme_record`` to ``filename`` as JSON.

        NumPy values inside the record (such as shape points) are serialised
        through ``_NumpyEncoder``.
        """
        json_content = dataclasses.asdict(labelme_record)
        khandy.save_json(filename, json_content, cls=_NumpyEncoder)

    @staticmethod
    def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
        """Convert a LabelmeRecord into the intermediate representation.

        Only shapes whose ``shape_type`` is ``'rectangle'`` are converted;
        all other shapes are skipped.
        """
        ir_record = DetectIrRecord(
            filename=labelme_record.imagePath,
            width=labelme_record.imageWidth,
            height=labelme_record.imageHeight,
        )
        for labelme_shape in labelme_record.shapes:
            if labelme_shape.shape_type != 'rectangle':
                continue
            # The two corner points are not guaranteed to be ordered
            # (presumably they follow the order in which the corners were
            # drawn), so normalise them instead of trusting points[0] to be
            # the minimum corner.
            first, second = labelme_shape.points[0], labelme_shape.points[1]
            ir_record.objects.append(DetectIrObject(
                label=labelme_shape.label,
                x_min=min(first[0], second[0]),
                y_min=min(first[1], second[1]),
                x_max=max(first[0], second[0]),
                y_max=max(first[1], second[1]),
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
        """Build a LabelmeRecord from the intermediate representation.

        Every IR object becomes a ``'rectangle'`` shape whose points are
        ``[[x_min, y_min], [x_max, y_max]]``.
        """
        labelme_record = LabelmeRecord(
            imagePath=ir_record.filename,
            imageWidth=ir_record.width,
            imageHeight=ir_record.height,
        )
        for ir_object in ir_record.objects:
            labelme_record.shapes.append(LabelmeShape(
                label=ir_object.label,
                shape_type='rectangle',
                points=[[ir_object.x_min, ir_object.y_min],
                        [ir_object.x_max, ir_object.y_max]],
            ))
        return labelme_record
-
@dataclass
class YoloObject:
    """One box of a YOLO label file.

    The box is given by its centre and its size. ``YoloHandler`` scales
    these values by the image width/height, i.e. they are stored relative
    to the image size.
    """

    # First token of the label line.
    label: str
    # Box centre.
    x_center: float
    y_center: float
    # Box size.
    width: float
    height: float
-
-
@dataclass
class YoloRecord:
    """All boxes of one YOLO label file plus the image they refer to.

    The image name and size are not part of the label file itself, so they
    default to None until supplied.
    """

    filename: Optional[str] = None
    width: Optional[int] = None
    height: Optional[int] = None
    # Boxes of the image; each record gets its own fresh list.
    objects: List[YoloObject] = field(default_factory=list)
-
-
class YoloHandler:
    """Load, save and convert YOLO text label files."""

    @staticmethod
    def load(filename, **kwargs) -> YoloRecord:
        """Read the YOLO label file ``filename`` into a YoloRecord.

        Each non-blank line is split on whitespace into
        ``label x_center y_center width height``; extra tokens are ignored.

        Args:
            filename: Path of the label file.
            **kwargs: Must contain ``image_filename``, ``width`` and
                ``height``, because the label file does not store them.

        Returns:
            The parsed YoloRecord.
        """
        assert 'image_filename' in kwargs, "`image_filename` is required"
        assert 'width' in kwargs and 'height' in kwargs, \
            "`width` and `height` are required"
        yolo_record = YoloRecord(
            filename=kwargs['image_filename'],
            width=kwargs['width'],
            height=kwargs['height'],
        )
        for line in khandy.load_list(filename):
            parts = line.split()
            if not parts:
                # Blank (e.g. trailing) lines carry no box; indexing them
                # would raise IndexError.
                continue
            yolo_record.objects.append(YoloObject(
                label=parts[0],
                x_center=float(parts[1]),
                y_center=float(parts[2]),
                width=float(parts[3]),
                height=float(parts[4]),
            ))
        return yolo_record

    @staticmethod
    def save(filename, yolo_record: YoloRecord):
        """Write ``yolo_record`` to ``filename``, one box per line.

        ``.txt`` is appended to ``filename`` when it does not already end
        with it.
        """
        records = [
            f'{yolo_object.label} {yolo_object.x_center} {yolo_object.y_center} '
            f'{yolo_object.width} {yolo_object.height}'
            for yolo_object in yolo_record.objects
        ]
        if not filename.endswith('.txt'):
            filename = filename + '.txt'
        khandy.save_list(filename, records)

    @staticmethod
    def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
        """Convert a YoloRecord into the intermediate representation.

        Centre/size values are turned into corner coordinates and scaled by
        the record's image width and height.
        """
        ir_record = DetectIrRecord(
            filename=yolo_record.filename,
            width=yolo_record.width,
            height=yolo_record.height,
        )
        for yolo_object in yolo_record.objects:
            x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
            y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
            x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
            y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
            ir_record.objects.append(DetectIrObject(
                label=yolo_object.label,
                x_min=x_min,
                y_min=y_min,
                x_max=x_max,
                y_max=y_max,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
        """Build a YoloRecord from the intermediate representation.

        Corner coordinates are turned into centre/size values divided by the
        image width and height. ``abs`` keeps sizes non-negative even when a
        box's corners are swapped.
        """
        yolo_record = YoloRecord(
            filename=ir_record.filename,
            width=ir_record.width,
            height=ir_record.height,
        )
        for ir_object in ir_record.objects:
            x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
            y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
            width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
            height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
            yolo_record.objects.append(YoloObject(
                label=ir_object.label,
                x_center=x_center,
                y_center=y_center,
                width=width,
                height=height,
            ))
        return yolo_record
-
-
@dataclass
class CocoObject:
    """One COCO annotation box.

    The box is given by one corner ``(x_min, y_min)`` together with its
    ``width`` and ``height``.
    """

    # Category name of the box.
    label: str
    x_min: float
    y_min: float
    width: float
    height: float
-
-
@dataclass
class CocoRecord:
    """All COCO boxes that belong to one image."""

    # Image file name.
    filename: str
    # Image size.
    width: int
    height: int
    # Boxes of the image; each record gets its own fresh list.
    objects: List[CocoObject] = field(default_factory=list)
-
class CocoHandler:
    """Load and convert COCO JSON annotations (there is no ``save``)."""

    @staticmethod
    def load(filename, **kwargs) -> List[CocoRecord]:
        """Read the COCO JSON file ``filename`` into one record per image.

        Records are returned in the order of the file's ``images`` list.
        Each annotation is attached to the record of its ``image_id`` and
        labelled with the name of its ``category_id``. ``**kwargs`` is
        accepted for signature parity with the other handlers and not used.
        """
        json_data = khandy.load_json(filename)

        images = json_data['images']
        annotations = json_data['annotations']
        categories = json_data['categories']

        # category id -> category name
        label_map = {category['id']: category['name'] for category in categories}

        # image id -> record, kept in file order
        records_by_image = OrderedDict()
        for image in images:
            records_by_image[image['id']] = CocoRecord(
                filename=image['file_name'],
                width=image['width'],
                height=image['height'],
                objects=[])

        for annotation in annotations:
            box = CocoObject(
                label=label_map[annotation['category_id']],
                x_min=annotation['bbox'][0],
                y_min=annotation['bbox'][1],
                width=annotation['bbox'][2],
                height=annotation['bbox'][3])
            owner = records_by_image[annotation['image_id']]
            owner.objects.append(box)
        return list(records_by_image.values())

    @staticmethod
    def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
        """Convert a CocoRecord into the intermediate representation.

        The corner-plus-size boxes become two-corner boxes.
        """
        ir_record = DetectIrRecord(
            filename=coco_record.filename,
            width=coco_record.width,
            height=coco_record.height,
        )
        for box in coco_record.objects:
            ir_record.objects.append(DetectIrObject(
                label=box.label,
                x_min=box.x_min,
                y_min=box.y_min,
                x_max=box.x_min + box.width,
                y_max=box.y_min + box.height,
            ))
        return ir_record

    @staticmethod
    def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
        """Build a CocoRecord from the intermediate representation.

        The two-corner boxes become corner-plus-size boxes.
        """
        coco_record = CocoRecord(
            filename=ir_record.filename,
            width=ir_record.width,
            height=ir_record.height,
        )
        for item in ir_record.objects:
            coco_record.objects.append(CocoObject(
                label=item.label,
                x_min=item.x_min,
                y_min=item.y_min,
                width=item.x_max - item.x_min,
                height=item.y_max - item.y_min,
            ))
        return coco_record
-
-
def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
    """Load a detection label file and return it in the IR format.

    Args:
        filename: Path of the label file.
        fmt: One of ``'labelme'``, ``'yolo'``, ``'voc'``/``'pascal'``/
            ``'pascal_voc'`` or ``'coco'``.
        **kwargs: Forwarded to the selected handler's ``load``.

    Returns:
        A DetectIrRecord. For ``'coco'`` a list of DetectIrRecord is
        returned instead, one per image.

    Raises:
        ValueError: If ``fmt`` is not one of the supported formats.
    """
    if fmt == 'labelme':
        return LabelmeHandler.to_ir(LabelmeHandler.load(filename, **kwargs))
    if fmt == 'yolo':
        return YoloHandler.to_ir(YoloHandler.load(filename, **kwargs))
    if fmt in ('voc', 'pascal', 'pascal_voc'):
        return PascalVocHandler.to_ir(PascalVocHandler.load(filename, **kwargs))
    if fmt == 'coco':
        loaded = CocoHandler.load(filename, **kwargs)
        return [CocoHandler.to_ir(item) for item in loaded]
    raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
-
-
def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
    """Save an IR record to ``filename`` in the requested label format.

    The parent directory of ``filename`` is created when missing. The format
    is validated first, so an unsupported ``out_fmt`` leaves the file system
    untouched.

    Args:
        filename: Destination path; the handlers may append an extension.
        ir_record: The record to write.
        out_fmt: One of ``'labelme'``, ``'yolo'`` or ``'voc'``/``'pascal'``/
            ``'pascal_voc'``.

    Raises:
        ValueError: If ``out_fmt`` is ``'coco'`` (saving is not implemented)
            or is not a known format.
    """
    if out_fmt == 'coco':
        raise ValueError("Unsupported for `coco` now!")
    if out_fmt not in ('labelme', 'yolo', 'voc', 'pascal', 'pascal_voc'):
        raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")

    os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
    if out_fmt == 'labelme':
        LabelmeHandler.save(filename, LabelmeHandler.from_ir(ir_record))
    elif out_fmt == 'yolo':
        YoloHandler.save(filename, YoloHandler.from_ir(ir_record))
    else:
        # Only the Pascal VOC aliases remain after the validation above.
        PascalVocHandler.save(filename, PascalVocHandler.from_ir(ir_record))
def _get_format(record):
    """Return the format names that describe the type of ``record``.

    The result is a tuple of aliases (for example
    ``('voc', 'pascal', 'pascal_voc')``), or an empty tuple when ``record``
    is not one of the record types of this module.
    """
    aliases_by_type = (
        (LabelmeRecord, ('labelme',)),
        (YoloRecord, ('yolo',)),
        (PascalVocRecord, ('voc', 'pascal', 'pascal_voc')),
        (CocoRecord, ('coco',)),
        (DetectIrRecord, ('ir', 'detect_ir')),
    )
    for record_type, aliases in aliases_by_type:
        if isinstance(record, record_type):
            return aliases
    return ()
def convert_detect(record, out_fmt):
    """Convert ``record`` to the record type named by ``out_fmt``.

    Args:
        record: A LabelmeRecord, YoloRecord, PascalVocRecord, CocoRecord or
            DetectIrRecord.
        out_fmt: One of ``'labelme'``, ``'yolo'``, ``'voc'``, ``'coco'``,
            ``'pascal'``, ``'pascal_voc'``, ``'ir'`` or ``'detect_ir'``.

    Returns:
        ``record`` itself when it already has the requested format,
        otherwise a new record obtained by going through the IR format.

    Raises:
        ValueError: If ``out_fmt`` is not supported, or ``record`` is not
            one of the supported record types.
    """
    if out_fmt not in ('labelme', 'yolo', 'voc', 'coco',
                       'pascal', 'pascal_voc', 'ir', 'detect_ir'):
        raise ValueError("Unsupported label format conversions for given out_fmt")
    if out_fmt in _get_format(record):
        return record

    # Step 1: bring the input into the intermediate representation.
    source_handlers = (
        (LabelmeRecord, LabelmeHandler),
        (YoloRecord, YoloHandler),
        (PascalVocRecord, PascalVocHandler),
        (CocoRecord, CocoHandler),
    )
    for record_type, handler in source_handlers:
        if isinstance(record, record_type):
            ir_record = handler.to_ir(record)
            break
    else:
        if not isinstance(record, DetectIrRecord):
            raise ValueError('Unsupported type for record')
        ir_record = record

    # Step 2: build the requested output from the intermediate representation.
    if out_fmt in ('ir', 'detect_ir'):
        return ir_record
    if out_fmt == 'labelme':
        return LabelmeHandler.from_ir(ir_record)
    if out_fmt == 'yolo':
        return YoloHandler.from_ir(ir_record)
    if out_fmt in ('voc', 'pascal', 'pascal_voc'):
        return PascalVocHandler.from_ir(ir_record)
    # After the validation above only 'coco' is left.
    return CocoHandler.from_ir(ir_record)
-
def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
    """Return a deep copy of ``record`` with labels renamed via ``label_map``.

    Args:
        record: The record to relabel; it is not modified.
        label_map: Mapping from old label to new label.
        ignore: When true (the default), objects whose label is not a key of
            ``label_map`` are dropped from the result. When false, they are
            kept with their label unchanged.

    Returns:
        The relabelled copy.
    """
    result = copy.deepcopy(record)
    kept = []
    for item in result.objects:
        is_mapped = item.label in label_map
        if is_mapped:
            item.label = label_map[item.label]
        if is_mapped or not ignore:
            kept.append(item)
    result.objects = kept
    return result
def load_coco_class_names(filename):
    """Return the category names of the COCO JSON file ``filename``.

    The names are listed in the order of the file's ``categories`` entries.
    """
    categories = khandy.load_json(filename)['categories']
    return [category['name'] for category in categories]
|