utils_others.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import time
  2. import json
  3. import socket
  4. import logging
  5. import argparse
  6. def print_with_no(obj):
  7. if hasattr(obj, '__len__'):
  8. for k, item in enumerate(obj):
  9. print('[{}/{}] {}'.format(k+1, len(obj), item))
  10. elif hasattr(obj, '__iter__'):
  11. for k, item in enumerate(obj):
  12. print('[{}] {}'.format(k+1, item))
  13. else:
  14. print('[1] {}'.format(obj))
  15. def get_file_line_count(filename):
  16. line_count = 0
  17. buffer_size = 1024 * 1024 * 8
  18. with open(filename, 'r') as f:
  19. while True:
  20. data = f.read(buffer_size)
  21. if not data:
  22. break
  23. line_count += data.count('\n')
  24. return line_count
  25. def get_host_ip():
  26. try:
  27. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  28. s.connect(('8.8.8.8', 80))
  29. ip = s.getsockname()[0]
  30. finally:
  31. s.close()
  32. return ip
  33. class ContextTimer(object):
  34. """
  35. References:
  36. WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
  37. """
  38. def __init__(self, name=None, use_log=False, quiet=False):
  39. self.use_log = use_log
  40. self.quiet = quiet
  41. if name is None:
  42. self.name = ''
  43. else:
  44. self.name = '{}, '.format(name.rstrip())
  45. def __enter__(self):
  46. self.start_time = time.time()
  47. if not self.quiet:
  48. self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
  49. return self
  50. def __exit__(self, exc_type, exc_val, exc_tb):
  51. if not self.quiet:
  52. self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
  53. self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))
  54. @property
  55. def _now_time_str(self):
  56. return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
  57. def _print_or_log(self, output_str):
  58. if self.use_log:
  59. logging.info(output_str)
  60. else:
  61. print(output_str)
  62. def get_eplased_time(self):
  63. return time.time() - self.start_time
  64. def enter(self):
  65. """Manually trigger enter"""
  66. self.__enter__()
  67. def set_logger(filename, level=logging.INFO, logger_name=None):
  68. logger = logging.getLogger(logger_name)
  69. logger.setLevel(level)
  70. # Never mutate (insert/remove elements) the list you're currently iterating on.
  71. # If you need, make a copy.
  72. for handler in logger.handlers[:]:
  73. if isinstance(handler, logging.FileHandler):
  74. logger.removeHandler(handler)
  75. # FileHandler is subclass of StreamHandler, so isinstance(handler,
  76. # logging.StreamHandler) is True even if handler is FileHandler.
  77. # if (type(handler) == logging.StreamHandler) and (handler.stream == sys.stderr):
  78. elif type(handler) == logging.StreamHandler:
  79. logger.removeHandler(handler)
  80. file_handler = logging.FileHandler(filename)
  81. file_handler.setFormatter(logging.Formatter('%(message)s'))
  82. logger.addHandler(file_handler)
  83. console_handler = logging.StreamHandler()
  84. console_handler.setFormatter(logging.Formatter('%(message)s'))
  85. logger.addHandler(console_handler)
  86. return logger
  87. def print_arguments(args):
  88. assert isinstance(args, argparse.Namespace)
  89. arg_list = sorted(vars(args).items())
  90. for key, value in arg_list:
  91. print('{}: {}'.format(key, value))
  92. def save_arguments(filename, args, sort=True):
  93. assert isinstance(args, argparse.Namespace)
  94. args = vars(args)
  95. with open(filename, 'w') as f:
  96. json.dump(args, f, indent=4, sort_keys=sort)