detect.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. import os
  2. import copy
  3. import json
  4. import dataclasses
  5. from dataclasses import dataclass, field
  6. from collections import OrderedDict
  7. from typing import Optional, List
  8. import xml.etree.ElementTree as ET
  9. import khandy
  10. import lxml
  11. import lxml.builder
  12. import numpy as np
  13. __all__ = ['DetectIrObject', 'DetectIrRecord', 'load_detect',
  14. 'save_detect', 'convert_detect', 'replace_detect_label',
  15. 'load_coco_class_names']
  16. @dataclass
  17. class DetectIrObject:
  18. """Intermediate Representation Format of Object
  19. """
  20. label: str
  21. x_min: float
  22. y_min: float
  23. x_max: float
  24. y_max: float
  25. @dataclass
  26. class DetectIrRecord:
  27. """Intermediate Representation Format of Record
  28. """
  29. filename: str
  30. width: int
  31. height: int
  32. objects: List[DetectIrObject] = field(default_factory=list)
  33. @dataclass
  34. class PascalVocSource:
  35. database: str = ''
  36. annotation: str = ''
  37. image: str = ''
  38. @dataclass
  39. class PascalVocSize:
  40. height: int
  41. width: int
  42. depth: int
  43. @dataclass
  44. class PascalVocBndbox:
  45. xmin: float
  46. ymin: float
  47. xmax: float
  48. ymax: float
  49. @dataclass
  50. class PascalVocObject:
  51. name: str
  52. pose: str = 'Unspecified'
  53. truncated: int = 0
  54. difficult: int = 0
  55. bndbox: Optional[PascalVocBndbox] = None
  56. @dataclass
  57. class PascalVocRecord:
  58. folder: str = ''
  59. filename: str = ''
  60. path: str = ''
  61. source: PascalVocSource = PascalVocSource()
  62. size: Optional[PascalVocSize] = None
  63. segmented: int = 0
  64. objects: List[PascalVocObject] = field(default_factory=list)
  65. class PascalVocHandler:
  66. @staticmethod
  67. def load(filename, **kwargs) -> PascalVocRecord:
  68. pascal_voc_record = PascalVocRecord()
  69. xml_tree = ET.parse(filename)
  70. pascal_voc_record.folder = xml_tree.find('folder').text
  71. pascal_voc_record.filename = xml_tree.find('filename').text
  72. pascal_voc_record.path = xml_tree.find('path').text
  73. pascal_voc_record.segmented = xml_tree.find('segmented').text
  74. source_tag = xml_tree.find('source')
  75. pascal_voc_record.source = PascalVocSource(
  76. database=source_tag.find('database').text,
  77. # annotation=source_tag.find('annotation').text,
  78. # image=source_tag.find('image').text
  79. )
  80. size_tag = xml_tree.find('size')
  81. pascal_voc_record.size = PascalVocSize(
  82. width=int(size_tag.find('width').text),
  83. height=int(size_tag.find('height').text),
  84. depth=int(size_tag.find('depth').text)
  85. )
  86. object_tags = xml_tree.findall('object')
  87. for index, object_tag in enumerate(object_tags):
  88. bndbox_tag = object_tag.find('bndbox')
  89. bndbox = PascalVocBndbox(
  90. xmin=float(bndbox_tag.find('xmin').text) - 1,
  91. ymin=float(bndbox_tag.find('ymin').text) - 1,
  92. xmax=float(bndbox_tag.find('xmax').text) - 1,
  93. ymax=float(bndbox_tag.find('ymax').text) - 1
  94. )
  95. pascal_voc_object = PascalVocObject(
  96. name=object_tag.find('name').text,
  97. pose=object_tag.find('pose').text,
  98. truncated=object_tag.find('truncated').text,
  99. difficult=object_tag.find('difficult').text,
  100. bndbox=bndbox
  101. )
  102. pascal_voc_record.objects.append(pascal_voc_object)
  103. return pascal_voc_record
  104. @staticmethod
  105. def save(filename, pascal_voc_record: PascalVocRecord):
  106. maker = lxml.builder.ElementMaker()
  107. xml = maker.annotation(
  108. maker.folder(pascal_voc_record.folder),
  109. maker.filename(pascal_voc_record.filename),
  110. maker.path(pascal_voc_record.path),
  111. maker.source(
  112. maker.database(pascal_voc_record.source.database),
  113. ),
  114. maker.size(
  115. maker.width(str(pascal_voc_record.size.width)),
  116. maker.height(str(pascal_voc_record.size.height)),
  117. maker.depth(str(pascal_voc_record.size.depth)),
  118. ),
  119. maker.segmented(str(pascal_voc_record.segmented)),
  120. )
  121. for pascal_voc_object in pascal_voc_record.objects:
  122. object_tag = maker.object(
  123. maker.name(pascal_voc_object.name),
  124. maker.pose(pascal_voc_object.pose),
  125. maker.truncated(str(pascal_voc_object.truncated)),
  126. maker.difficult(str(pascal_voc_object.difficult)),
  127. maker.bndbox(
  128. maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
  129. maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
  130. maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
  131. maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
  132. ),
  133. )
  134. xml.append(object_tag)
  135. if not filename.endswith('.xml'):
  136. filename = filename + '.xml'
  137. with open(filename, 'wb') as f:
  138. f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
  139. @staticmethod
  140. def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
  141. ir_record = DetectIrRecord(
  142. filename=pascal_voc_record.filename,
  143. width=pascal_voc_record.size.width,
  144. height=pascal_voc_record.size.height
  145. )
  146. for pascal_voc_object in pascal_voc_record.objects:
  147. ir_object = DetectIrObject(
  148. label=pascal_voc_object.name,
  149. x_min=pascal_voc_object.bndbox.xmin,
  150. y_min=pascal_voc_object.bndbox.ymin,
  151. x_max=pascal_voc_object.bndbox.xmax,
  152. y_max=pascal_voc_object.bndbox.ymax
  153. )
  154. ir_record.objects.append(ir_object)
  155. return ir_record
  156. @staticmethod
  157. def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
  158. pascal_voc_record = PascalVocRecord(
  159. filename=ir_record.filename,
  160. size=PascalVocSize(
  161. width=ir_record.width,
  162. height=ir_record.height,
  163. depth=3
  164. )
  165. )
  166. for ir_object in ir_record.objects:
  167. pascal_voc_object = PascalVocObject(
  168. name=ir_object.label,
  169. bndbox=PascalVocBndbox(
  170. xmin=ir_object.x_min,
  171. ymin=ir_object.y_min,
  172. xmax=ir_object.x_max,
  173. ymax=ir_object.y_max,
  174. )
  175. )
  176. pascal_voc_record.objects.append(pascal_voc_object)
  177. return pascal_voc_record
  178. class _NumpyEncoder(json.JSONEncoder):
  179. """ Special json encoder for numpy types """
  180. def default(self, obj):
  181. if isinstance(obj, (np.bool_,)):
  182. return bool(obj)
  183. elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
  184. np.int16, np.int32, np.int64, np.uint8,
  185. np.uint16, np.uint32, np.uint64)):
  186. return int(obj)
  187. elif isinstance(obj, (np.float_, np.float16, np.float32,
  188. np.float64)):
  189. return float(obj)
  190. elif isinstance(obj, (np.ndarray,)):
  191. return obj.tolist()
  192. return json.JSONEncoder.default(self, obj)
  193. @dataclass
  194. class LabelmeShape:
  195. label: str
  196. points: np.ndarray
  197. shape_type: str
  198. flags: dict = field(default_factory=dict)
  199. group_id: Optional[int] = None
  200. def __post_init__(self):
  201. self.points = np.asarray(self.points)
  202. @dataclass
  203. class LabelmeRecord:
  204. version: str = '4.5.6'
  205. flags: dict = field(default_factory=dict)
  206. shapes: List[LabelmeShape] = field(default_factory=list)
  207. imagePath: Optional[str] = None
  208. imageData: Optional[str] = None
  209. imageHeight: Optional[int] = None
  210. imageWidth: Optional[int] = None
  211. def __post_init__(self):
  212. for k, shape in enumerate(self.shapes):
  213. self.shapes[k] = LabelmeShape(**shape)
  214. class LabelmeHandler:
  215. @staticmethod
  216. def load(filename, **kwargs) -> LabelmeRecord:
  217. json_content = khandy.load_json(filename)
  218. return LabelmeRecord(**json_content)
  219. @staticmethod
  220. def save(filename, labelme_record: LabelmeRecord):
  221. json_content = dataclasses.asdict(labelme_record)
  222. khandy.save_json(filename, json_content, cls=_NumpyEncoder)
  223. @staticmethod
  224. def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
  225. ir_record = DetectIrRecord(
  226. filename=labelme_record.imagePath,
  227. width=labelme_record.imageWidth,
  228. height=labelme_record.imageHeight
  229. )
  230. for labelme_shape in labelme_record.shapes:
  231. if labelme_shape.shape_type != 'rectangle':
  232. continue
  233. ir_object = DetectIrObject(
  234. label=labelme_shape.label,
  235. x_min=labelme_shape.points[0][0],
  236. y_min=labelme_shape.points[0][1],
  237. x_max=labelme_shape.points[1][0],
  238. y_max=labelme_shape.points[1][1],
  239. )
  240. ir_record.objects.append(ir_object)
  241. return ir_record
  242. @staticmethod
  243. def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
  244. labelme_record = LabelmeRecord(
  245. imagePath=ir_record.filename,
  246. imageWidth=ir_record.width,
  247. imageHeight=ir_record.height
  248. )
  249. for ir_object in ir_record.objects:
  250. labelme_shape = LabelmeShape(
  251. label=ir_object.label,
  252. shape_type='rectangle',
  253. points=[[ir_object.x_min, ir_object.y_min],
  254. [ir_object.x_max, ir_object.y_max]]
  255. )
  256. labelme_record.shapes.append(labelme_shape)
  257. return labelme_record
  258. @dataclass
  259. class YoloObject:
  260. label: str
  261. x_center: float
  262. y_center: float
  263. width: float
  264. height: float
  265. @dataclass
  266. class YoloRecord:
  267. filename: Optional[str] = None
  268. width: Optional[int] = None
  269. height: Optional[int] = None
  270. objects: List[YoloObject] = field(default_factory=list)
  271. class YoloHandler:
  272. @staticmethod
  273. def load(filename, **kwargs) -> YoloRecord:
  274. assert 'image_filename' in kwargs
  275. assert 'width' in kwargs and 'height' in kwargs
  276. records = khandy.load_list(filename)
  277. yolo_record = YoloRecord(
  278. filename=kwargs.get('image_filename'),
  279. width=kwargs.get('width'),
  280. height=kwargs.get('height'))
  281. for record in records:
  282. record_parts = record.split()
  283. yolo_record.objects.append(YoloObject(
  284. label=record_parts[0],
  285. x_center=float(record_parts[1]),
  286. y_center=float(record_parts[2]),
  287. width=float(record_parts[3]),
  288. height=float(record_parts[4]),
  289. ))
  290. return yolo_record
  291. @staticmethod
  292. def save(filename, yolo_record: YoloRecord):
  293. records = []
  294. for object in yolo_record.objects:
  295. records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
  296. if not filename.endswith('.txt'):
  297. filename = filename + '.txt'
  298. khandy.save_list(filename, records)
  299. @staticmethod
  300. def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
  301. ir_record = DetectIrRecord(
  302. filename=yolo_record.filename,
  303. width=yolo_record.width,
  304. height=yolo_record.height
  305. )
  306. for yolo_object in yolo_record.objects:
  307. x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
  308. y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
  309. x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
  310. y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
  311. ir_object = DetectIrObject(
  312. label=yolo_object.label,
  313. x_min=x_min,
  314. y_min=y_min,
  315. x_max=x_max,
  316. y_max=y_max
  317. )
  318. ir_record.objects.append(ir_object)
  319. return ir_record
  320. @staticmethod
  321. def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
  322. yolo_record = YoloRecord(
  323. filename=ir_record.filename,
  324. width=ir_record.width,
  325. height=ir_record.height
  326. )
  327. for ir_object in ir_record.objects:
  328. x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
  329. y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
  330. width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
  331. height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
  332. yolo_object = YoloObject(
  333. label=ir_object.label,
  334. x_center=x_center,
  335. y_center=y_center,
  336. width=width,
  337. height=height,
  338. )
  339. yolo_record.objects.append(yolo_object)
  340. return yolo_record
  341. @dataclass
  342. class CocoObject:
  343. label: str
  344. x_min: float
  345. y_min: float
  346. width: float
  347. height: float
  348. @dataclass
  349. class CocoRecord:
  350. filename: str
  351. width: int
  352. height: int
  353. objects: List[CocoObject] = field(default_factory=list)
  354. class CocoHandler:
  355. @staticmethod
  356. def load(filename, **kwargs) -> List[CocoRecord]:
  357. json_data = khandy.load_json(filename)
  358. images = json_data['images']
  359. annotations = json_data['annotations']
  360. categories = json_data['categories']
  361. label_map = {}
  362. for cat_item in categories:
  363. label_map[cat_item['id']] = cat_item['name']
  364. coco_records = OrderedDict()
  365. for image_item in images:
  366. coco_records[image_item['id']] = CocoRecord(
  367. filename=image_item['file_name'],
  368. width=image_item['width'],
  369. height=image_item['height'],
  370. objects=[])
  371. for annotation_item in annotations:
  372. coco_object = CocoObject(
  373. label=label_map[annotation_item['category_id']],
  374. x_min=annotation_item['bbox'][0],
  375. y_min=annotation_item['bbox'][1],
  376. width=annotation_item['bbox'][2],
  377. height=annotation_item['bbox'][3])
  378. coco_records[annotation_item['image_id']].objects.append(coco_object)
  379. return list(coco_records.values())
  380. @staticmethod
  381. def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
  382. ir_record = DetectIrRecord(
  383. filename=coco_record.filename,
  384. width=coco_record.width,
  385. height=coco_record.height,
  386. )
  387. for coco_object in coco_record.objects:
  388. ir_object = DetectIrObject(
  389. label=coco_object.label,
  390. x_min=coco_object.x_min,
  391. y_min=coco_object.y_min,
  392. x_max=coco_object.x_min + coco_object.width,
  393. y_max=coco_object.y_min + coco_object.height
  394. )
  395. ir_record.objects.append(ir_object)
  396. return ir_record
  397. @staticmethod
  398. def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
  399. coco_record = CocoRecord(
  400. filename=ir_record.filename,
  401. width=ir_record.width,
  402. height=ir_record.height
  403. )
  404. for ir_object in ir_record.objects:
  405. coco_object = CocoObject(
  406. label=ir_object.label,
  407. x_min=ir_object.x_min,
  408. y_min=ir_object.y_min,
  409. width=ir_object.x_max - ir_object.x_min,
  410. height=ir_object.y_max - ir_object.y_min
  411. )
  412. coco_record.objects.append(coco_object)
  413. return coco_record
  414. def load_detect(filename, fmt, **kwargs) -> DetectIrRecord:
  415. if fmt == 'labelme':
  416. labelme_record = LabelmeHandler.load(filename, **kwargs)
  417. ir_record = LabelmeHandler.to_ir(labelme_record)
  418. elif fmt == 'yolo':
  419. yolo_record = YoloHandler.load(filename, **kwargs)
  420. ir_record = YoloHandler.to_ir(yolo_record)
  421. elif fmt in ('voc', 'pascal', 'pascal_voc'):
  422. pascal_voc_record = PascalVocHandler.load(filename, **kwargs)
  423. ir_record = PascalVocHandler.to_ir(pascal_voc_record)
  424. elif fmt == 'coco':
  425. coco_records = CocoHandler.load(filename, **kwargs)
  426. ir_record = [CocoHandler.to_ir(coco_record) for coco_record in coco_records]
  427. else:
  428. raise ValueError(f"Unsupported detect label fmt. Got {fmt}")
  429. return ir_record
  430. def save_detect(filename, ir_record: DetectIrRecord, out_fmt):
  431. os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)
  432. if out_fmt == 'labelme':
  433. labelme_record = LabelmeHandler.from_ir(ir_record)
  434. LabelmeHandler.save(filename, labelme_record)
  435. elif out_fmt == 'yolo':
  436. yolo_record = YoloHandler.from_ir(ir_record)
  437. YoloHandler.save(filename, yolo_record)
  438. elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
  439. pascal_voc_record = PascalVocHandler.from_ir(ir_record)
  440. PascalVocHandler.save(filename, pascal_voc_record)
  441. elif out_fmt == 'coco':
  442. raise ValueError("Unsupported for `coco` now!")
  443. else:
  444. raise ValueError(f"Unsupported detect label fmt. Got {out_fmt}")
  445. def _get_format(record):
  446. if isinstance(record, LabelmeRecord):
  447. return ('labelme',)
  448. elif isinstance(record, YoloRecord):
  449. return ('yolo',)
  450. elif isinstance(record, PascalVocRecord):
  451. return ('voc', 'pascal', 'pascal_voc')
  452. elif isinstance(record, CocoRecord):
  453. return ('coco',)
  454. elif isinstance(record, DetectIrRecord):
  455. return ('ir', 'detect_ir')
  456. else:
  457. return ()
  458. def convert_detect(record, out_fmt):
  459. allowed_fmts = ('labelme', 'yolo', 'voc', 'coco', 'pascal', 'pascal_voc', 'ir', 'detect_ir')
  460. if out_fmt not in allowed_fmts:
  461. raise ValueError("Unsupported label format conversions for given out_fmt")
  462. if out_fmt in _get_format(record):
  463. return record
  464. if isinstance(record, LabelmeRecord):
  465. ir_record = LabelmeHandler.to_ir(record)
  466. elif isinstance(record, YoloRecord):
  467. ir_record = YoloHandler.to_ir(record)
  468. elif isinstance(record, PascalVocRecord):
  469. ir_record = PascalVocHandler.to_ir(record)
  470. elif isinstance(record, CocoRecord):
  471. ir_record = CocoHandler.to_ir(record)
  472. elif isinstance(record, DetectIrRecord):
  473. ir_record = record
  474. else:
  475. raise ValueError('Unsupported type for record')
  476. if out_fmt in ('ir', 'detect_ir'):
  477. dst_record = ir_record
  478. elif out_fmt == 'labelme':
  479. dst_record = LabelmeHandler.from_ir(ir_record)
  480. elif out_fmt == 'yolo':
  481. dst_record = YoloHandler.from_ir(ir_record)
  482. elif out_fmt in ('voc', 'pascal', 'pascal_voc'):
  483. dst_record = PascalVocHandler.from_ir(ir_record)
  484. elif out_fmt == 'coco':
  485. dst_record = CocoHandler.from_ir(ir_record)
  486. return dst_record
  487. def replace_detect_label(record: DetectIrRecord, label_map, ignore=True):
  488. dst_record = copy.deepcopy(record)
  489. dst_objects = []
  490. for ir_object in dst_record.objects:
  491. if not ignore:
  492. if ir_object.label in label_map:
  493. ir_object.label = label_map[ir_object.label]
  494. dst_objects.append(ir_object)
  495. else:
  496. if ir_object.label in label_map:
  497. ir_object.label = label_map[ir_object.label]
  498. dst_objects.append(ir_object)
  499. dst_record.objects = dst_objects
  500. return dst_record
  501. def load_coco_class_names(filename):
  502. json_data = khandy.load_json(filename)
  503. categories = json_data['categories']
  504. return [cat_item['name'] for cat_item in categories]