detect.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
  1. import copy
  2. import json
  3. import dataclasses
  4. from dataclasses import dataclass, field
  5. from collections import OrderedDict
  6. from typing import Optional, List
  7. import xml.etree.ElementTree as ET
  8. import khandy
  9. import lxml
  10. import lxml.builder
  11. import numpy as np
  12. __all__ = ['DetectIrObject', 'DetectIrRecord',
  13. 'PascalVocSource', 'PascalVocSize', 'PascalVocBndbox',
  14. 'PascalVocObject', 'PascalVocRecord', 'PascalVocHandler',
  15. 'LabelmeShape', 'LabelmeRecord', 'LabelmeHandler',
  16. 'YoloObject', 'YoloRecord', 'YoloHandler',
  17. 'CocoObject', 'CocoRecord', 'CocoDetectHandler',
  18. 'load', 'save', 'convert']
  19. @dataclass
  20. class DetectIrObject:
  21. """Intermediate Representation Format of Object
  22. """
  23. label: str
  24. x_min: float
  25. y_min: float
  26. x_max: float
  27. y_max: float
  28. @dataclass
  29. class DetectIrRecord:
  30. """Intermediate Representation Format of Record
  31. """
  32. filename: str
  33. width: int
  34. height: int
  35. objects: List[DetectIrObject] = field(default_factory=list)
  36. @dataclass
  37. class PascalVocSource:
  38. database: str = ''
  39. annotation: str = ''
  40. image: str = ''
  41. @dataclass
  42. class PascalVocSize:
  43. height: int
  44. width: int
  45. depth: int
  46. @dataclass
  47. class PascalVocBndbox:
  48. xmin: float
  49. ymin: float
  50. xmax: float
  51. ymax: float
  52. @dataclass
  53. class PascalVocObject:
  54. name: str
  55. pose: str = 'Unspecified'
  56. truncated: int = 0
  57. difficult: int = 0
  58. bndbox: Optional[PascalVocBndbox] = None
  59. @dataclass
  60. class PascalVocRecord:
  61. folder: str = ''
  62. filename: str = ''
  63. path: str = ''
  64. source: PascalVocSource = PascalVocSource()
  65. size: Optional[PascalVocSize] = None
  66. segmented: int = 0
  67. objects: List[PascalVocObject] = field(default_factory=list)
  68. class PascalVocHandler:
  69. @staticmethod
  70. def load(filename) -> PascalVocRecord:
  71. pascal_voc_record = PascalVocRecord()
  72. xml_tree = ET.parse(filename)
  73. pascal_voc_record.folder = xml_tree.find('folder').text
  74. pascal_voc_record.filename = xml_tree.find('filename').text
  75. pascal_voc_record.path = xml_tree.find('path').text
  76. pascal_voc_record.segmented = xml_tree.find('segmented').text
  77. source_tag = xml_tree.find('source')
  78. pascal_voc_record.source = PascalVocSource(
  79. database=source_tag.find('database').text,
  80. # annotation=source_tag.find('annotation').text,
  81. # image=source_tag.find('image').text
  82. )
  83. size_tag = xml_tree.find('size')
  84. pascal_voc_record.size = PascalVocSize(
  85. width=int(size_tag.find('width').text),
  86. height=int(size_tag.find('height').text),
  87. depth=int(size_tag.find('depth').text)
  88. )
  89. object_tags = xml_tree.findall('object')
  90. for index, object_tag in enumerate(object_tags):
  91. bndbox_tag = object_tag.find('bndbox')
  92. bndbox = PascalVocBndbox(
  93. xmin=float(bndbox_tag.find('xmin').text) - 1,
  94. ymin=float(bndbox_tag.find('ymin').text) - 1,
  95. xmax=float(bndbox_tag.find('xmax').text) - 1,
  96. ymax=float(bndbox_tag.find('ymax').text) - 1
  97. )
  98. pascal_voc_object = PascalVocObject(
  99. name=object_tag.find('name').text,
  100. pose=object_tag.find('pose').text,
  101. truncated=object_tag.find('truncated').text,
  102. difficult=object_tag.find('difficult').text,
  103. bndbox=bndbox
  104. )
  105. pascal_voc_record.objects.append(pascal_voc_object)
  106. return pascal_voc_record
  107. @staticmethod
  108. def save(filename, pascal_voc_record: PascalVocRecord):
  109. maker = lxml.builder.ElementMaker()
  110. xml = maker.annotation(
  111. maker.folder(pascal_voc_record.folder),
  112. maker.filename(pascal_voc_record.filename),
  113. maker.path(pascal_voc_record.path),
  114. maker.source(
  115. maker.database(pascal_voc_record.source.database),
  116. ),
  117. maker.size(
  118. maker.width(str(pascal_voc_record.size.width)),
  119. maker.height(str(pascal_voc_record.size.height)),
  120. maker.depth(str(pascal_voc_record.size.depth)),
  121. ),
  122. maker.segmented(str(pascal_voc_record.segmented)),
  123. )
  124. for pascal_voc_object in pascal_voc_record.objects:
  125. object_tag = maker.object(
  126. maker.name(pascal_voc_object.name),
  127. maker.pose(pascal_voc_object.pose),
  128. maker.truncated(str(pascal_voc_object.truncated)),
  129. maker.difficult(str(pascal_voc_object.difficult)),
  130. maker.bndbox(
  131. maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
  132. maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
  133. maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
  134. maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
  135. ),
  136. )
  137. xml.append(object_tag)
  138. with open(filename, 'wb') as f:
  139. f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
  140. @staticmethod
  141. def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
  142. ir_record = DetectIrRecord(
  143. filename=pascal_voc_record.filename,
  144. width=pascal_voc_record.size.width,
  145. height=pascal_voc_record.size.height
  146. )
  147. for pascal_voc_object in pascal_voc_record.objects:
  148. ir_object = DetectIrObject(
  149. label=pascal_voc_object.name,
  150. x_min=pascal_voc_object.bndbox.xmin,
  151. y_min=pascal_voc_object.bndbox.ymin,
  152. x_max=pascal_voc_object.bndbox.xmax,
  153. y_max=pascal_voc_object.bndbox.ymax
  154. )
  155. ir_record.objects.append(ir_object)
  156. return ir_record
  157. @staticmethod
  158. def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
  159. pascal_voc_record = PascalVocRecord(
  160. filename=ir_record.filename,
  161. size=PascalVocSize(
  162. width=ir_record.width,
  163. height=ir_record.height,
  164. depth=3
  165. )
  166. )
  167. for ir_object in ir_record.objects:
  168. pascal_voc_object = PascalVocObject(
  169. name=ir_object.label,
  170. bndbox=PascalVocBndbox(
  171. xmin=ir_object.x_min,
  172. ymin=ir_object.y_min,
  173. xmax=ir_object.x_max,
  174. ymax=ir_object.y_max,
  175. )
  176. )
  177. pascal_voc_record.objects.append(pascal_voc_object)
  178. return pascal_voc_record
  179. class _NumpyEncoder(json.JSONEncoder):
  180. """ Special json encoder for numpy types """
  181. def default(self, obj):
  182. if isinstance(obj, (np.bool_,)):
  183. return bool(obj)
  184. elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
  185. np.int16, np.int32, np.int64, np.uint8,
  186. np.uint16, np.uint32, np.uint64)):
  187. return int(obj)
  188. elif isinstance(obj, (np.float_, np.float16, np.float32,
  189. np.float64)):
  190. return float(obj)
  191. elif isinstance(obj, (np.ndarray,)):
  192. return obj.tolist()
  193. return json.JSONEncoder.default(self, obj)
  194. @dataclass
  195. class LabelmeShape:
  196. label: str
  197. points: np.ndarray
  198. shape_type: str
  199. flags: dict = field(default_factory=dict)
  200. group_id: Optional[int] = None
  201. def __post_init__(self):
  202. self.points = np.asarray(self.points)
  203. @dataclass
  204. class LabelmeRecord:
  205. version: str = '4.5.6'
  206. flags: dict = field(default_factory=dict)
  207. shapes: List[LabelmeShape] = field(default_factory=list)
  208. imagePath: Optional[str] = None
  209. imageData: Optional[str] = None
  210. imageHeight: Optional[int] = None
  211. imageWidth: Optional[int] = None
  212. def __post_init__(self):
  213. for k, shape in enumerate(self.shapes):
  214. self.shapes[k] = LabelmeShape(**shape)
  215. class LabelmeHandler:
  216. @staticmethod
  217. def load(filename) -> LabelmeRecord:
  218. json_content = khandy.load_json(filename)
  219. return LabelmeRecord(**json_content)
  220. @staticmethod
  221. def save(filename, labelme_record: LabelmeRecord):
  222. json_content = dataclasses.asdict(labelme_record)
  223. khandy.save_json(filename, json_content, cls=_NumpyEncoder)
  224. def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
  225. ir_record = DetectIrRecord(
  226. filename=labelme_record.imagePath,
  227. width=labelme_record.imageWidth,
  228. height=labelme_record.imageHeight
  229. )
  230. for labelme_shape in labelme_record.shapes:
  231. if labelme_shape.shape_type != 'rectangle':
  232. continue
  233. ir_object = DetectIrObject(
  234. label=labelme_shape.label,
  235. x_min=labelme_shape.points[0][0],
  236. y_min=labelme_shape.points[0][1],
  237. x_max=labelme_shape.points[1][0],
  238. y_max=labelme_shape.points[1][1],
  239. )
  240. ir_record.objects.append(ir_object)
  241. return ir_record
  242. @staticmethod
  243. def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
  244. labelme_record = LabelmeRecord(
  245. imagePath=ir_record.filename,
  246. imageWidth=ir_record.width,
  247. imageHeight=ir_record.height
  248. )
  249. for ir_object in ir_record.objects:
  250. labelme_shape = LabelmeShape(
  251. label=ir_object.label,
  252. shape_type='rectangle',
  253. points=[[ir_object.x_min, ir_object.y_min],
  254. [ir_object.x_max, ir_object.y_max]]
  255. )
  256. labelme_record.shapes.append(labelme_shape)
  257. return labelme_record
  258. @dataclass
  259. class YoloObject:
  260. label: str
  261. x_center: float
  262. y_center: float
  263. width: float
  264. height: float
  265. @dataclass
  266. class YoloRecord:
  267. filename: Optional[str] = None
  268. width: Optional[int] = None
  269. height: Optional[int] = None
  270. objects: List[YoloObject] = field(default_factory=list)
  271. class YoloHandler:
  272. @staticmethod
  273. def load(filename, **kwargs) -> YoloRecord:
  274. records = khandy.load_list(filename)
  275. yolo_record = YoloRecord(
  276. filename=kwargs.get('filename'),
  277. width=kwargs.get('width'),
  278. height=kwargs.get('height'))
  279. for record in records:
  280. record_parts = record.split()
  281. yolo_record.objects.append(YoloObject(
  282. label=record_parts[0],
  283. x_center=float(record_parts[1]),
  284. y_center=float(record_parts[2]),
  285. width=float(record_parts[3]),
  286. height=float(record_parts[4]),
  287. ))
  288. return yolo_record
  289. @staticmethod
  290. def save(filemame, yolo_record: YoloRecord):
  291. records = []
  292. for object in yolo_record.objects:
  293. records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
  294. khandy.save_list(filemame, records)
  295. @staticmethod
  296. def replace_label(yolo_record: YoloRecord, label_map, ignore=False):
  297. dst_yolo_record = copy.deepcopy(yolo_record)
  298. dst_objects = []
  299. for yolo_object in dst_yolo_record.objects:
  300. if not ignore:
  301. if yolo_object.label in label_map:
  302. yolo_object.label = label_map[yolo_object.label]
  303. dst_objects.append(yolo_object)
  304. else:
  305. if yolo_object.label in label_map:
  306. yolo_object.label = label_map[yolo_object.label]
  307. dst_objects.append(yolo_object)
  308. dst_yolo_record.objects = dst_objects
  309. return dst_yolo_record
  310. @staticmethod
  311. def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
  312. ir_record = DetectIrRecord(
  313. filename=yolo_record.filename,
  314. width=yolo_record.width,
  315. height=yolo_record.height
  316. )
  317. for yolo_object in yolo_record.objects:
  318. x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
  319. y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
  320. x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
  321. y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
  322. ir_object = DetectIrObject(
  323. label=yolo_object.label,
  324. x_min=x_min,
  325. y_min=y_min,
  326. x_max=x_max,
  327. y_max=y_max
  328. )
  329. ir_record.objects.append(ir_object)
  330. return ir_record
  331. @staticmethod
  332. def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
  333. yolo_record = YoloRecord(
  334. filename=ir_record.filename,
  335. width=ir_record.width,
  336. height=ir_record.height
  337. )
  338. for ir_object in ir_record.objects:
  339. x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
  340. y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
  341. width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
  342. height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
  343. yolo_object = YoloObject(
  344. label=ir_object.label,
  345. x_center=x_center,
  346. y_center=y_center,
  347. width=width,
  348. height=height,
  349. )
  350. yolo_record.objects.append(yolo_object)
  351. return yolo_record
  352. @dataclass
  353. class CocoObject:
  354. label: str
  355. x_min: float
  356. y_min: float
  357. width: float
  358. height: float
  359. @dataclass
  360. class CocoRecord:
  361. filename: str
  362. width: int
  363. height: int
  364. objects: List[CocoObject] = field(default_factory=list)
  365. class CocoDetectHandler:
  366. @staticmethod
  367. def load(filename) -> List[CocoRecord]:
  368. json_data = khandy.load_json(filename)
  369. images = json_data['images']
  370. annotations = json_data['annotations']
  371. categories = json_data['categories']
  372. label_map = {}
  373. for cat_item in categories:
  374. label_map[cat_item['id']] = cat_item['name']
  375. coco_records = OrderedDict()
  376. for image_item in images:
  377. coco_records[image_item['id']] = CocoRecord(
  378. filename=image_item['file_name'],
  379. width=image_item['width'],
  380. height=image_item['height'],
  381. objects=[])
  382. for annotation_item in annotations:
  383. coco_object = CocoObject(
  384. label=label_map[annotation_item['category_id']],
  385. x_min=annotation_item['bbox'][0],
  386. y_min=annotation_item['bbox'][1],
  387. width=annotation_item['bbox'][2],
  388. height=annotation_item['bbox'][3])
  389. coco_records[annotation_item['image_id']].objects.append(coco_object)
  390. return list(coco_records.values())
  391. @staticmethod
  392. def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
  393. ir_record = DetectIrRecord(
  394. filename=coco_record.filename,
  395. width=coco_record.width,
  396. height=coco_record.height,
  397. )
  398. for coco_object in coco_record.objects:
  399. ir_object = DetectIrObject(
  400. label=coco_object.label,
  401. x_min=coco_object.x_min,
  402. y_min=coco_object.y_min,
  403. x_max=coco_object.x_min + coco_object.width,
  404. y_max=coco_object.y_min + coco_object.height
  405. )
  406. ir_record.objects.append(ir_object)
  407. return ir_record
  408. @staticmethod
  409. def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
  410. coco_record = CocoRecord(
  411. filename=ir_record.filename,
  412. width=ir_record.width,
  413. height=ir_record.height
  414. )
  415. for ir_object in ir_record.objects:
  416. coco_object = CocoObject(
  417. label=ir_object.label,
  418. x_min=ir_object.x_min,
  419. y_min=ir_object.y_min,
  420. width=ir_object.x_max - ir_object.x_min,
  421. height=ir_object.y_max - ir_object.y_min
  422. )
  423. coco_record.objects.append(coco_object)
  424. return coco_record
  425. def _get_format(record):
  426. if isinstance(record, LabelmeRecord):
  427. return ('labelme',)
  428. elif isinstance(record, YoloRecord):
  429. return ('yolo',)
  430. elif isinstance(record, PascalVocRecord):
  431. return ('voc', 'pascal', 'pascal_voc')
  432. elif isinstance(record, CocoRecord):
  433. return ('coco',)
  434. else:
  435. return ()
  436. def load(filename, fmt, **kwargs):
  437. if fmt == 'labelme':
  438. record = LabelmeHandler.load(filename)
  439. elif fmt == 'yolo':
  440. record = YoloHandler.load(filename)
  441. elif fmt in ('voc', 'pascal', 'pascal_voc'):
  442. record = PascalVocHandler.load(filename)
  443. elif fmt == 'coco':
  444. record = CocoDetectHandler.load(filename, **kwargs)
  445. else:
  446. raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
  447. return record
  448. def save(filename, record):
  449. if isinstance(record, LabelmeRecord):
  450. LabelmeHandler.save(filename, record)
  451. elif isinstance(record, YoloRecord):
  452. YoloHandler.save(filename, record)
  453. elif isinstance(record, PascalVocRecord):
  454. PascalVocHandler.save(filename, record)
  455. elif isinstance(record, CocoRecord):
  456. raise ValueError("Unsupported for CocoRecord now!")
  457. else:
  458. raise ValueError("Unsupported type!")
  459. def convert(record, out_fmt):
  460. allowed_fmts = ("labelme", "yolo", "voc", "coco", 'pascal', 'pascal_voc')
  461. if out_fmt not in allowed_fmts:
  462. raise ValueError("Unsupported label format conversions for given out_fmt")
  463. if out_fmt in _get_format(record):
  464. return record
  465. if isinstance(record, LabelmeRecord):
  466. ir_record = LabelmeHandler.to_ir(record)
  467. elif isinstance(record, YoloRecord):
  468. ir_record = YoloHandler.to_ir(record)
  469. elif isinstance(record, PascalVocRecord):
  470. ir_record = PascalVocHandler.to_ir(record)
  471. elif isinstance(record, CocoRecord):
  472. ir_record = CocoDetectHandler.to_ir(record)
  473. else:
  474. raise ValueError('Unsupported type for record')
  475. if out_fmt == 'labelme':
  476. dst_record = LabelmeHandler.from_ir(ir_record)
  477. elif out_fmt == 'yolo':
  478. dst_record = YoloHandler.from_ir(ir_record)
  479. elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
  480. dst_record = PascalVocHandler.from_ir(ir_record)
  481. elif out_fmt == 'coco':
  482. dst_record = CocoDetectHandler.from_ir(ir_record)
  483. return dst_record