detect.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. import copy
  2. import json
  3. import dataclasses
  4. from dataclasses import dataclass, field
  5. from collections import OrderedDict
  6. from typing import Optional, List
  7. import xml.etree.ElementTree as ET
  8. import khandy
  9. import lxml
  10. import lxml.builder
  11. import numpy as np
  12. __all__ = ['DetectIrObject', 'DetectIrRecord',
  13. 'PascalVocSource', 'PascalVocSize', 'PascalVocBndbox',
  14. 'PascalVocObject', 'PascalVocRecord', 'PascalVocHandler',
  15. 'LabelmeShape', 'LabelmeRecord', 'LabelmeHandler',
  16. 'YoloObject', 'YoloRecord', 'YoloHandler',
  17. 'CocoObject', 'CocoRecord', 'CocoDetectHandler']
  18. @dataclass
  19. class DetectIrObject:
  20. """Intermediate Representation Format of Object
  21. """
  22. label: str
  23. x_min: float
  24. y_min: float
  25. x_max: float
  26. y_max: float
  27. @dataclass
  28. class DetectIrRecord:
  29. """Intermediate Representation Format of Record
  30. """
  31. filename: str
  32. width: int
  33. height: int
  34. objects: List[DetectIrObject] = field(default_factory=list)
  35. @dataclass
  36. class PascalVocSource:
  37. database: str = ''
  38. annotation: str = ''
  39. image: str = ''
  40. @dataclass
  41. class PascalVocSize:
  42. height: int
  43. width: int
  44. depth: int
  45. @dataclass
  46. class PascalVocBndbox:
  47. xmin: float
  48. ymin: float
  49. xmax: float
  50. ymax: float
  51. @dataclass
  52. class PascalVocObject:
  53. name: str
  54. pose: str = 'Unspecified'
  55. truncated: int = 0
  56. difficult: int = 0
  57. bndbox: Optional[PascalVocBndbox] = None
  58. @dataclass
  59. class PascalVocRecord:
  60. folder: str = ''
  61. filename: str = ''
  62. path: str = ''
  63. source: PascalVocSource = PascalVocSource()
  64. size: Optional[PascalVocSize] = None
  65. segmented: int = 0
  66. objects: List[PascalVocObject] = field(default_factory=list)
  67. class PascalVocHandler:
  68. @staticmethod
  69. def load(filename) -> PascalVocRecord:
  70. pascal_voc_record = PascalVocRecord()
  71. xml_tree = ET.parse(filename)
  72. pascal_voc_record.folder = xml_tree.find('folder').text
  73. pascal_voc_record.filename = xml_tree.find('filename').text
  74. pascal_voc_record.path = xml_tree.find('path').text
  75. pascal_voc_record.segmented = xml_tree.find('segmented').text
  76. source_tag = xml_tree.find('source')
  77. pascal_voc_record.source = PascalVocSource(
  78. database=source_tag.find('database').text,
  79. # annotation=source_tag.find('annotation').text,
  80. # image=source_tag.find('image').text
  81. )
  82. size_tag = xml_tree.find('size')
  83. pascal_voc_record.size = PascalVocSize(
  84. width=int(size_tag.find('width').text),
  85. height=int(size_tag.find('height').text),
  86. depth=int(size_tag.find('depth').text)
  87. )
  88. object_tags = xml_tree.findall('object')
  89. for index, object_tag in enumerate(object_tags):
  90. bndbox_tag = object_tag.find('bndbox')
  91. bndbox = PascalVocBndbox(
  92. xmin=float(bndbox_tag.find('xmin').text) - 1,
  93. ymin=float(bndbox_tag.find('ymin').text) - 1,
  94. xmax=float(bndbox_tag.find('xmax').text) - 1,
  95. ymax=float(bndbox_tag.find('ymax').text) - 1
  96. )
  97. pascal_voc_object = PascalVocObject(
  98. name=object_tag.find('name').text,
  99. pose=object_tag.find('pose').text,
  100. truncated=object_tag.find('truncated').text,
  101. difficult=object_tag.find('difficult').text,
  102. bndbox=bndbox
  103. )
  104. pascal_voc_record.objects.append(pascal_voc_object)
  105. return pascal_voc_record
  106. @staticmethod
  107. def save(filename, pascal_voc_record: PascalVocRecord):
  108. maker = lxml.builder.ElementMaker()
  109. xml = maker.annotation(
  110. maker.folder(pascal_voc_record.folder),
  111. maker.filename(pascal_voc_record.filename),
  112. maker.path(pascal_voc_record.path),
  113. maker.source(
  114. maker.database(pascal_voc_record.source.database),
  115. ),
  116. maker.size(
  117. maker.width(str(pascal_voc_record.size.width)),
  118. maker.height(str(pascal_voc_record.size.height)),
  119. maker.depth(str(pascal_voc_record.size.depth)),
  120. ),
  121. maker.segmented(str(pascal_voc_record.segmented)),
  122. )
  123. for pascal_voc_object in pascal_voc_record.objects:
  124. object_tag = maker.object(
  125. maker.name(pascal_voc_object.name),
  126. maker.pose(pascal_voc_object.pose),
  127. maker.truncated(str(pascal_voc_object.truncated)),
  128. maker.difficult(str(pascal_voc_object.difficult)),
  129. maker.bndbox(
  130. maker.xmin(str(float(pascal_voc_object.bndbox.xmin))),
  131. maker.ymin(str(float(pascal_voc_object.bndbox.ymin))),
  132. maker.xmax(str(float(pascal_voc_object.bndbox.xmax))),
  133. maker.ymax(str(float(pascal_voc_object.bndbox.ymax))),
  134. ),
  135. )
  136. xml.append(object_tag)
  137. with open(filename, 'wb') as f:
  138. f.write(lxml.etree.tostring(xml, pretty_print=True, encoding='utf-8'))
  139. @staticmethod
  140. def to_ir(pascal_voc_record: PascalVocRecord) -> DetectIrRecord:
  141. ir_record = DetectIrRecord(
  142. filename=pascal_voc_record.filename,
  143. width=pascal_voc_record.size.width,
  144. height=pascal_voc_record.size.height
  145. )
  146. for pascal_voc_object in pascal_voc_record.objects:
  147. ir_object = DetectIrObject(
  148. label=pascal_voc_object.name,
  149. x_min=pascal_voc_object.bndbox.xmin,
  150. y_min=pascal_voc_object.bndbox.ymin,
  151. x_max=pascal_voc_object.bndbox.xmax,
  152. y_max=pascal_voc_object.bndbox.ymax
  153. )
  154. ir_record.objects.append(ir_object)
  155. return ir_record
  156. @staticmethod
  157. def from_ir(ir_record: DetectIrRecord) -> PascalVocRecord:
  158. pascal_voc_record = PascalVocRecord(
  159. filename=ir_record.filename,
  160. size=PascalVocSize(
  161. width=ir_record.width,
  162. height=ir_record.height,
  163. depth=3
  164. )
  165. )
  166. for ir_object in ir_record.objects:
  167. pascal_voc_object = PascalVocObject(
  168. name=ir_object.label,
  169. bndbox=PascalVocBndbox(
  170. xmin=ir_object.x_min,
  171. ymin=ir_object.y_min,
  172. xmax=ir_object.x_max,
  173. ymax=ir_object.y_max,
  174. )
  175. )
  176. pascal_voc_record.objects.append(pascal_voc_object)
  177. return pascal_voc_record
  178. class _NumpyEncoder(json.JSONEncoder):
  179. """ Special json encoder for numpy types """
  180. def default(self, obj):
  181. if isinstance(obj, (np.bool_,)):
  182. return bool(obj)
  183. elif isinstance(obj, (np.int_, np.intc, np.intp, np.int8,
  184. np.int16, np.int32, np.int64, np.uint8,
  185. np.uint16, np.uint32, np.uint64)):
  186. return int(obj)
  187. elif isinstance(obj, (np.float_, np.float16, np.float32,
  188. np.float64)):
  189. return float(obj)
  190. elif isinstance(obj, (np.ndarray,)):
  191. return obj.tolist()
  192. return json.JSONEncoder.default(self, obj)
  193. @dataclass
  194. class LabelmeShape:
  195. label: str
  196. points: np.ndarray
  197. shape_type: str
  198. flags: dict = field(default_factory=dict)
  199. group_id: Optional[int] = None
  200. def __post_init__(self):
  201. self.points = np.asarray(self.points)
  202. @dataclass
  203. class LabelmeRecord:
  204. version: str = '4.5.6'
  205. flags: dict = field(default_factory=dict)
  206. shapes: List[LabelmeShape] = field(default_factory=list)
  207. imagePath: Optional[str] = None
  208. imageData: Optional[str] = None
  209. imageHeight: Optional[int] = None
  210. imageWidth: Optional[int] = None
  211. def __post_init__(self):
  212. for k, shape in enumerate(self.shapes):
  213. self.shapes[k] = LabelmeShape(**shape)
  214. class LabelmeHandler:
  215. @staticmethod
  216. def load(filename) -> LabelmeRecord:
  217. json_content = khandy.load_json(filename)
  218. return LabelmeRecord(**json_content)
  219. @staticmethod
  220. def save(filename, labelme_record: LabelmeRecord):
  221. json_content = dataclasses.asdict(labelme_record)
  222. khandy.save_json(filename, json_content, cls=_NumpyEncoder)
  223. def to_ir(labelme_record: LabelmeRecord) -> DetectIrRecord:
  224. ir_record = DetectIrRecord(
  225. filename=labelme_record.imagePath,
  226. width=labelme_record.imageWidth,
  227. height=labelme_record.imageHeight
  228. )
  229. for labelme_shape in labelme_record.shapes:
  230. if labelme_shape.shape_type != 'rectangle':
  231. continue
  232. ir_object = DetectIrObject(
  233. label=labelme_shape.label,
  234. x_min=labelme_shape.points[0][0],
  235. y_min=labelme_shape.points[0][1],
  236. x_max=labelme_shape.points[1][0],
  237. y_max=labelme_shape.points[1][1],
  238. )
  239. ir_record.objects.append(ir_object)
  240. return ir_record
  241. @staticmethod
  242. def from_ir(ir_record: DetectIrRecord) -> LabelmeRecord:
  243. labelme_record = LabelmeRecord(
  244. imagePath=ir_record.filename,
  245. imageWidth=ir_record.width,
  246. imageHeight=ir_record.height
  247. )
  248. for ir_object in ir_record.objects:
  249. labelme_shape = LabelmeShape(
  250. label=ir_object.label,
  251. shape_type='rectangle',
  252. points=[[ir_object.x_min, ir_object.y_min],
  253. [ir_object.x_max, ir_object.y_max]]
  254. )
  255. labelme_record.shapes.append(labelme_shape)
  256. return labelme_record
  257. @dataclass
  258. class YoloObject:
  259. label: str
  260. x_center: float
  261. y_center: float
  262. width: float
  263. height: float
  264. @dataclass
  265. class YoloRecord:
  266. filename: Optional[str] = None
  267. width: Optional[int] = None
  268. height: Optional[int] = None
  269. objects: List[YoloObject] = field(default_factory=list)
  270. class YoloHandler:
  271. @staticmethod
  272. def load(filename, **kwargs) -> YoloRecord:
  273. records = khandy.load_list(filename)
  274. yolo_record = YoloRecord(
  275. filename=kwargs.get('filename'),
  276. width=kwargs.get('width'),
  277. height=kwargs.get('height'))
  278. for record in records:
  279. record_parts = record.split()
  280. yolo_record.objects.append(YoloObject(
  281. label=record_parts[0],
  282. x_center=float(record_parts[1]),
  283. y_center=float(record_parts[2]),
  284. width=float(record_parts[3]),
  285. height=float(record_parts[4]),
  286. ))
  287. return yolo_record
  288. @staticmethod
  289. def save(filemame, yolo_record: YoloRecord):
  290. records = []
  291. for object in yolo_record.objects:
  292. records.append(f'{object.label} {object.x_center} {object.y_center} {object.width} {object.height}')
  293. khandy.save_list(filemame, records)
  294. @staticmethod
  295. def replace_label(yolo_record: YoloRecord, label_map, ignore=False):
  296. dst_yolo_record = copy.deepcopy(yolo_record)
  297. dst_objects = []
  298. for yolo_object in dst_yolo_record.objects:
  299. if not ignore:
  300. if yolo_object.label in label_map:
  301. yolo_object.label = label_map[yolo_object.label]
  302. dst_objects.append(yolo_object)
  303. else:
  304. if yolo_object.label in label_map:
  305. yolo_object.label = label_map[yolo_object.label]
  306. dst_objects.append(yolo_object)
  307. dst_yolo_record.objects = dst_objects
  308. return dst_yolo_record
  309. @staticmethod
  310. def to_ir(yolo_record: YoloRecord) -> DetectIrRecord:
  311. ir_record = DetectIrRecord(
  312. filename=yolo_record.filename,
  313. width=yolo_record.width,
  314. height=yolo_record.height
  315. )
  316. for yolo_object in yolo_record.objects:
  317. x_min = (yolo_object.x_center - 0.5 * yolo_object.width) * yolo_record.width
  318. y_min = (yolo_object.y_center - 0.5 * yolo_object.height) * yolo_record.height
  319. x_max = (yolo_object.x_center + 0.5 * yolo_object.width) * yolo_record.width
  320. y_max = (yolo_object.y_center + 0.5 * yolo_object.height) * yolo_record.height
  321. ir_object = DetectIrObject(
  322. label=yolo_object.label,
  323. x_min=x_min,
  324. y_min=y_min,
  325. x_max=x_max,
  326. y_max=y_max
  327. )
  328. ir_record.objects.append(ir_object)
  329. return ir_record
  330. @staticmethod
  331. def from_ir(ir_record: DetectIrRecord) -> YoloRecord:
  332. yolo_record = YoloRecord(
  333. filename=ir_record.filename,
  334. width=ir_record.width,
  335. height=ir_record.height
  336. )
  337. for ir_object in ir_record.objects:
  338. x_center = (ir_object.x_max + ir_object.x_min) / (2 * ir_record.width)
  339. y_center = (ir_object.y_max + ir_object.y_min) / (2 * ir_record.height)
  340. width = abs(ir_object.x_max - ir_object.x_min) / ir_record.width
  341. height = abs(ir_object.y_max - ir_object.y_min) / ir_record.height
  342. yolo_object = YoloObject(
  343. label=ir_object.label,
  344. x_center=x_center,
  345. y_center=y_center,
  346. width=width,
  347. height=height,
  348. )
  349. yolo_record.objects.append(yolo_object)
  350. return yolo_record
  351. @dataclass
  352. class CocoObject:
  353. label: str
  354. x_min: float
  355. y_min: float
  356. width: float
  357. height: float
  358. @dataclass
  359. class CocoRecord:
  360. filename: str
  361. width: int
  362. height: int
  363. objects: List[CocoObject] = field(default_factory=list)
  364. class CocoDetectHandler:
  365. @staticmethod
  366. def load(filename) -> List[CocoRecord]:
  367. json_data = khandy.load_json(filename)
  368. images = json_data['images']
  369. annotations = json_data['annotations']
  370. categories = json_data['categories']
  371. label_map = {}
  372. for cat_item in categories:
  373. label_map[cat_item['id']] = cat_item['name']
  374. coco_records = OrderedDict()
  375. for image_item in images:
  376. coco_records[image_item['id']] = CocoRecord(
  377. filename=image_item['file_name'],
  378. width=image_item['width'],
  379. height=image_item['height'],
  380. objects=[])
  381. for annotation_item in annotations:
  382. coco_object = CocoObject(
  383. label=label_map[annotation_item['category_id']],
  384. x_min=annotation_item['bbox'][0],
  385. y_min=annotation_item['bbox'][1],
  386. width=annotation_item['bbox'][2],
  387. height=annotation_item['bbox'][3])
  388. coco_records[annotation_item['image_id']].objects.append(coco_object)
  389. return list(coco_records.values())
  390. @staticmethod
  391. def to_ir(coco_record: CocoRecord) -> DetectIrRecord:
  392. ir_record = DetectIrRecord(
  393. filename=coco_record.filename,
  394. width=coco_record.width,
  395. height=coco_record.height,
  396. )
  397. for coco_object in coco_record.objects:
  398. ir_object = DetectIrObject(
  399. label=coco_object.label,
  400. x_min=coco_object.x_min,
  401. y_min=coco_object.y_min,
  402. x_max=coco_object.x_min + coco_object.width,
  403. y_max=coco_object.y_min + coco_object.height
  404. )
  405. ir_record.objects.append(ir_object)
  406. return ir_record
  407. @staticmethod
  408. def from_ir(ir_record: DetectIrRecord) -> CocoRecord:
  409. coco_record = CocoRecord(
  410. filename=ir_record.filename,
  411. width=ir_record.width,
  412. height=ir_record.height
  413. )
  414. for ir_object in ir_record.objects:
  415. coco_object = CocoObject(
  416. label=ir_object.label,
  417. x_min=ir_object.x_min,
  418. y_min=ir_object.y_min,
  419. width=ir_object.x_max - ir_object.x_min,
  420. height=ir_object.y_max - ir_object.y_min
  421. )
  422. coco_record.objects.append(coco_object)
  423. return coco_record