utils_others.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. import re
  2. import time
  3. import json
  4. import socket
  5. import imghdr
  6. import logging
  7. import argparse
  8. import numbers
  9. import datetime
  10. import warnings
  11. from enum import Enum
  12. import requests
  13. def print_with_no(obj):
  14. if hasattr(obj, '__len__'):
  15. for k, item in enumerate(obj):
  16. print('[{}/{}] {}'.format(k+1, len(obj), item))
  17. elif hasattr(obj, '__iter__'):
  18. for k, item in enumerate(obj):
  19. print('[{}] {}'.format(k+1, item))
  20. else:
  21. print('[1] {}'.format(obj))
  22. def get_file_line_count(filename):
  23. line_count = 0
  24. buffer_size = 1024 * 1024 * 8
  25. with open(filename, 'r') as f:
  26. while True:
  27. data = f.read(buffer_size)
  28. if not data:
  29. break
  30. line_count += data.count('\n')
  31. return line_count
  32. def get_host_ip():
  33. try:
  34. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  35. s.connect(('8.8.8.8', 80))
  36. ip = s.getsockname()[0]
  37. finally:
  38. s.close()
  39. return ip
  40. class ContextTimer(object):
  41. """
  42. References:
  43. WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
  44. """
  45. def __init__(self, name=None, use_log=False, quiet=False):
  46. self.use_log = use_log
  47. self.quiet = quiet
  48. if name is None:
  49. self.name = ''
  50. else:
  51. self.name = '{}, '.format(name.rstrip())
  52. def __enter__(self):
  53. self.start_time = time.time()
  54. if not self.quiet:
  55. self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
  56. return self
  57. def __exit__(self, exc_type, exc_val, exc_tb):
  58. if not self.quiet:
  59. self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
  60. self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))
  61. @property
  62. def _now_time_str(self):
  63. return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
  64. def _print_or_log(self, output_str):
  65. if self.use_log:
  66. logging.info(output_str)
  67. else:
  68. print(output_str)
  69. def get_eplased_time(self):
  70. return time.time() - self.start_time
  71. def enter(self):
  72. """Manually trigger enter"""
  73. self.__enter__()
  74. def set_logger(filename, level=logging.INFO, logger_name=None):
  75. logger = logging.getLogger(logger_name)
  76. logger.setLevel(level)
  77. # Never mutate (insert/remove elements) the list you're currently iterating on.
  78. # If you need, make a copy.
  79. for handler in logger.handlers[:]:
  80. if isinstance(handler, logging.FileHandler):
  81. logger.removeHandler(handler)
  82. # FileHandler is subclass of StreamHandler, so isinstance(handler,
  83. # logging.StreamHandler) is True even if handler is FileHandler.
  84. # if (type(handler) == logging.StreamHandler) and (handler.stream == sys.stderr):
  85. elif type(handler) == logging.StreamHandler:
  86. logger.removeHandler(handler)
  87. file_handler = logging.FileHandler(filename)
  88. file_handler.setFormatter(logging.Formatter('%(message)s'))
  89. logger.addHandler(file_handler)
  90. console_handler = logging.StreamHandler()
  91. console_handler.setFormatter(logging.Formatter('%(message)s'))
  92. logger.addHandler(console_handler)
  93. return logger
  94. def print_arguments(args):
  95. assert isinstance(args, argparse.Namespace)
  96. arg_list = sorted(vars(args).items())
  97. for key, value in arg_list:
  98. print('{}: {}'.format(key, value))
  99. def save_arguments(filename, args, sort=True):
  100. assert isinstance(args, argparse.Namespace)
  101. args = vars(args)
  102. with open(filename, 'w') as f:
  103. json.dump(args, f, indent=4, sort_keys=sort)
  104. def strip_content_in_paren(string):
  105. """
  106. Notes:
  107. strip_content_in_paren cannot process nested paren correctly
  108. """
  109. return re.sub(r"\([^)]*\)|([^)]*)", "", string)
  110. def _to_timestamp(val):
  111. if val is None:
  112. timestamp = time.time()
  113. elif isinstance(val, numbers.Real):
  114. timestamp = float(val)
  115. elif isinstance(val, time.struct_time):
  116. timestamp = time.mktime(val)
  117. elif isinstance(val, datetime.datetime):
  118. timestamp = val.timestamp()
  119. elif isinstance(val, datetime.date):
  120. dt = datetime.datetime.combine(val, datetime.time())
  121. timestamp = dt.timestamp()
  122. elif isinstance(val, str):
  123. try:
  124. # The full format looks like 'YYYY-MM-DD HH:MM:SS.mmmmmm'.
  125. dt = datetime.datetime.fromisoformat(val)
  126. timestamp = dt.timestamp()
  127. except:
  128. raise TypeError('when argument is str, it should conform to isoformat')
  129. else:
  130. raise TypeError('unsupported type!')
  131. return timestamp
  132. def get_timestamp(time_val=None, rounded=True):
  133. """timestamp in seconds
  134. """
  135. timestamp = _to_timestamp(time_val)
  136. if rounded:
  137. timestamp = round(timestamp)
  138. return timestamp
  139. def get_timestamp_ms(time_val=None, rounded=True):
  140. """timestamp in milliseconds
  141. """
  142. timestamp = _to_timestamp(time_val) * 1000
  143. if rounded:
  144. timestamp = round(timestamp)
  145. return timestamp
  146. def get_utc8now():
  147. tz = datetime.timezone(datetime.timedelta(hours=8))
  148. utc8now = datetime.datetime.now(tz)
  149. return utc8now
  150. class DownloadStatusCode(Enum):
  151. FILE_SIZE_TOO_LARGE = (-100, 'the size of file from url is too large')
  152. FILE_SIZE_TOO_SMALL = (-101, 'the size of file from url is too small')
  153. FILE_SIZE_IS_ZERO = (-102, 'the size of file from url is zero')
  154. URL_IS_NOT_IMAGE = (-103, 'URL is not an image')
  155. @property
  156. def code(self):
  157. return self.value[0]
  158. @property
  159. def message(self):
  160. return self.value[1]
  161. class DownloadError(Exception):
  162. def __init__(self, status_code: DownloadStatusCode, extra_str: str=None):
  163. self.name = status_code.name
  164. self.code = status_code.code
  165. if extra_str is None:
  166. self.message = status_code.message
  167. else:
  168. self.message = f'{status_code.message}: {extra_str}'
  169. Exception.__init__(self)
  170. def __repr__(self):
  171. return f'[{self.__class__.__name__} {self.code}] {self.message}'
  172. __str__ = __repr__
  173. def download_image(image_url, min_filesize=None, max_filesize=None,
  174. imghdr_check=False, params=None, **kwargs) -> bytes:
  175. """
  176. References:
  177. https://httpwg.org/specs/rfc9110.html#field.content-length
  178. https://requests.readthedocs.io/en/latest/user/advanced/#body-content-workflow
  179. """
  180. stream = kwargs.pop('stream', True)
  181. min_filesize = min_filesize or 0
  182. max_filesize = max_filesize or 100 * 1024 * 1024
  183. with requests.get(image_url, stream=stream, params=params, **kwargs) as response:
  184. response.raise_for_status()
  185. content_type = response.headers.get('content-type')
  186. if content_type is None:
  187. warnings.warn('No Content-Type!')
  188. else:
  189. if not content_type.startswith(('image/', 'application/octet-stream')):
  190. raise DownloadError(DownloadStatusCode.URL_IS_NOT_IMAGE)
  191. # when Transfer-Encoding == chunked, Content-Length does not exist.
  192. content_length = response.headers.get('content-length')
  193. if content_length is None:
  194. warnings.warn('No Content-Length!')
  195. else:
  196. content_length = int(content_length)
  197. if content_length > max_filesize:
  198. raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
  199. if content_length < min_filesize:
  200. raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
  201. filesize = 0
  202. first_chunk = True
  203. chunks = []
  204. for chunk in response.iter_content(chunk_size=10*1024):
  205. if imghdr_check and first_chunk:
  206. # imghdr.what fails to determine image format sometimes!
  207. extension = imghdr.what('', chunk[:64])
  208. if extension is None:
  209. raise DownloadError(DownloadStatusCode.URL_IS_NOT_IMAGE)
  210. chunks.append(chunk)
  211. first_chunk = False
  212. else:
  213. chunks.append(chunk)
  214. filesize += len(chunk)
  215. if filesize > max_filesize:
  216. raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_LARGE)
  217. if filesize < min_filesize:
  218. raise DownloadError(DownloadStatusCode.FILE_SIZE_TOO_SMALL)
  219. image_bytes = b''.join(chunks)
  220. return image_bytes