|
14 | 14 | from QEfficient.utils.constants import ROOT_DIR
|
15 | 15 |
|
16 | 16 |
|
17 |
| -class QEffFormatter(logging.Formatter): |
18 |
| - """ |
19 |
| - Formatter class used to set colors for printing different logging levels of messages on console. |
20 |
| - """ |
21 |
| - |
22 |
| - cyan: str = "\x1b[38;5;14m" |
23 |
| - yellow: str = "\x1b[33;20m" |
24 |
| - red: str = "\x1b[31;20m" |
25 |
| - bold_red: str = "\x1b[31;1m" |
26 |
| - reset: str = "\x1b[0m" |
27 |
| - common_format: str = "%(levelname)s - %(name)s - %(message)s" # type: ignore |
28 |
| - format_with_line_info = "%(levelname)s - %(name)s - %(message)s (%(filename)s:%(lineno)d)" # type: ignore |
29 |
| - |
30 |
| - FORMATS = { |
31 |
| - logging.DEBUG: cyan + format_with_line_info + reset, |
32 |
| - logging.INFO: cyan + common_format + reset, |
33 |
| - logging.WARNING: yellow + common_format + reset, |
34 |
| - logging.ERROR: red + format_with_line_info + reset, |
35 |
| - logging.CRITICAL: bold_red + format_with_line_info + reset, |
36 |
| - } |
37 |
| - |
38 |
| - def format(self, record): |
39 |
| - """ |
40 |
| - Overriding the base class method to Choose format based on log level. |
41 |
| - """ |
42 |
| - log_fmt = self.FORMATS.get(record.levelno) |
43 |
| - formatter = logging.Formatter(log_fmt) |
44 |
| - return formatter.format(record) |
45 |
| - |
46 |
| - |
47 |
| -def create_logger() -> logging.Logger: |
48 |
| - """ |
49 |
| - Creates a logger object with Colored QEffFormatter. |
50 |
| - """ |
51 |
| - logger = logging.getLogger("QEfficient") |
52 |
| - |
53 |
| - # create console handler and set level |
54 |
| - ch = logging.StreamHandler() |
55 |
| - ch.setLevel(logging.INFO) |
56 |
| - ch.setFormatter(QEffFormatter()) |
57 |
| - logger.addHandler(ch) |
58 |
| - |
59 |
| - return logger |
60 |
| - |
61 |
| - |
62 |
| -class CustomLogger(logging.Logger): |
63 |
| - def raise_runtimeerror(self, message): |
64 |
| - self.error(message) |
65 |
| - raise RuntimeError(message) |
66 |
| - |
67 |
| - def log_rank_zero(self, msg: str, level: int = logging.INFO) -> None: |
68 |
| - rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 |
69 |
| - if rank != 0: |
70 |
| - return |
71 |
| - self.log(level, msg, stacklevel=2) |
72 |
| - |
73 |
| - def prepare_dump_logs(self, dump_logs=False): |
74 |
| - if dump_logs: |
75 |
| - logs_path = os.path.join(ROOT_DIR, "logs") |
76 |
| - if not os.path.exists(logs_path): |
77 |
| - os.makedirs(logs_path, exist_ok=True) |
78 |
| - file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt" |
79 |
| - log_file = os.path.join(logs_path, file_name) |
80 |
| - |
81 |
| - # create file handler and set level |
82 |
| - fh = logging.FileHandler(log_file) |
83 |
| - fh.setLevel(logging.INFO) |
84 |
| - formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s") |
85 |
| - fh.setFormatter(formatter) |
86 |
| - logger.addHandler(fh) |
87 |
| - |
88 |
| - |
89 |
| -logging.setLoggerClass(CustomLogger) |
90 |
| - |
91 |
| -# Define the logger object that can be used for logging purposes throughout the module. |
92 |
| -logger = create_logger() |
| 17 | +class FTLogger: |
| 18 | + def __init__(self, level=logging.DEBUG): |
| 19 | + self.logger = logging.getLogger("QEfficient") |
| 20 | + if not getattr(self.logger, "_custom_methods_added", False): |
| 21 | + self._bind_custom_methods() |
| 22 | + self.logger._custom_methods_added = True # Prevent adding handlers/methods twice |
| 23 | + |
| 24 | + def _bind_custom_methods(self): |
| 25 | + def raise_runtimeerror(message): |
| 26 | + self.logger.error(message) |
| 27 | + raise RuntimeError(message) |
| 28 | + |
| 29 | + def log_rank_zero(msg: str, level: int = logging.INFO): |
| 30 | + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 |
| 31 | + if rank != 0: |
| 32 | + return |
| 33 | + self.logger.log(level, msg, stacklevel=2) |
| 34 | + |
| 35 | + def prepare_dump_logs(dump_logs=False, level=logging.INFO): |
| 36 | + if dump_logs: |
| 37 | + logs_path = os.path.join(ROOT_DIR, "logs") |
| 38 | + if not os.path.exists(logs_path): |
| 39 | + os.makedirs(logs_path, exist_ok=True) |
| 40 | + file_name = f"log-file-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + ".txt" |
| 41 | + log_file = os.path.join(logs_path, file_name) |
| 42 | + |
| 43 | + fh = logging.FileHandler(log_file) |
| 44 | + fh.setLevel(level) |
| 45 | + formatter = logging.Formatter("%(levelname)s - %(name)s - %(message)s") |
| 46 | + fh.setFormatter(formatter) |
| 47 | + self.logger.addHandler(fh) |
| 48 | + |
| 49 | + self.logger.raise_runtimeerror = raise_runtimeerror |
| 50 | + self.logger.log_rank_zero = log_rank_zero |
| 51 | + self.logger.prepare_dump_logs = prepare_dump_logs |
| 52 | + |
| 53 | + def get_logger(self): |
| 54 | + return self.logger |
| 55 | + |
| 56 | + |
| 57 | +logger = FTLogger().get_logger() |
0 commit comments