# utils_others.py
import argparse
import json
import logging
import re
import socket
import time
  7. def print_with_no(obj):
  8. if hasattr(obj, '__len__'):
  9. for k, item in enumerate(obj):
  10. print('[{}/{}] {}'.format(k+1, len(obj), item))
  11. elif hasattr(obj, '__iter__'):
  12. for k, item in enumerate(obj):
  13. print('[{}] {}'.format(k+1, item))
  14. else:
  15. print('[1] {}'.format(obj))
  16. def get_file_line_count(filename):
  17. line_count = 0
  18. buffer_size = 1024 * 1024 * 8
  19. with open(filename, 'r') as f:
  20. while True:
  21. data = f.read(buffer_size)
  22. if not data:
  23. break
  24. line_count += data.count('\n')
  25. return line_count
  26. def get_host_ip():
  27. try:
  28. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  29. s.connect(('8.8.8.8', 80))
  30. ip = s.getsockname()[0]
  31. finally:
  32. s.close()
  33. return ip
  34. class ContextTimer(object):
  35. """
  36. References:
  37. WithTimer in https://github.com/uber/ludwig/blob/master/ludwig/utils/time_utils.py
  38. """
  39. def __init__(self, name=None, use_log=False, quiet=False):
  40. self.use_log = use_log
  41. self.quiet = quiet
  42. if name is None:
  43. self.name = ''
  44. else:
  45. self.name = '{}, '.format(name.rstrip())
  46. def __enter__(self):
  47. self.start_time = time.time()
  48. if not self.quiet:
  49. self._print_or_log('{}{} starts'.format(self.name, self._now_time_str))
  50. return self
  51. def __exit__(self, exc_type, exc_val, exc_tb):
  52. if not self.quiet:
  53. self._print_or_log('{}elapsed_time = {:.5}s'.format(self.name, self.get_eplased_time()))
  54. self._print_or_log('{}{} ends'.format(self.name, self._now_time_str))
  55. @property
  56. def _now_time_str(self):
  57. return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
  58. def _print_or_log(self, output_str):
  59. if self.use_log:
  60. logging.info(output_str)
  61. else:
  62. print(output_str)
  63. def get_eplased_time(self):
  64. return time.time() - self.start_time
  65. def enter(self):
  66. """Manually trigger enter"""
  67. self.__enter__()
  68. def set_logger(filename, level=logging.INFO, logger_name=None):
  69. logger = logging.getLogger(logger_name)
  70. logger.setLevel(level)
  71. # Never mutate (insert/remove elements) the list you're currently iterating on.
  72. # If you need, make a copy.
  73. for handler in logger.handlers[:]:
  74. if isinstance(handler, logging.FileHandler):
  75. logger.removeHandler(handler)
  76. # FileHandler is subclass of StreamHandler, so isinstance(handler,
  77. # logging.StreamHandler) is True even if handler is FileHandler.
  78. # if (type(handler) == logging.StreamHandler) and (handler.stream == sys.stderr):
  79. elif type(handler) == logging.StreamHandler:
  80. logger.removeHandler(handler)
  81. file_handler = logging.FileHandler(filename)
  82. file_handler.setFormatter(logging.Formatter('%(message)s'))
  83. logger.addHandler(file_handler)
  84. console_handler = logging.StreamHandler()
  85. console_handler.setFormatter(logging.Formatter('%(message)s'))
  86. logger.addHandler(console_handler)
  87. return logger
  88. def print_arguments(args):
  89. assert isinstance(args, argparse.Namespace)
  90. arg_list = sorted(vars(args).items())
  91. for key, value in arg_list:
  92. print('{}: {}'.format(key, value))
  93. def save_arguments(filename, args, sort=True):
  94. assert isinstance(args, argparse.Namespace)
  95. args = vars(args)
  96. with open(filename, 'w') as f:
  97. json.dump(args, f, indent=4, sort_keys=sort)
  98. def strip_content_in_paren(string):
  99. """
  100. Notes:
  101. strip_content_in_paren cannot process nested paren correctly
  102. """
  103. return re.sub(r"\([^)]*\)|([^)]*)", "", string)