
Commit 8b6aa07

[distributed] add SingletonLogger to avoid duplicate logging, with custom formatting options including file name: line_number (#1124)
* add SingletonLogger with custom formatting options
* ruff formatting
1 parent 4e7332f commit 8b6aa07

File tree

11 files changed: +123 -34 lines changed
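
The change is the same at every call site: the per-module setup_logging(__name__) helper goes away and each module asks the singleton for the shared logger instead. A minimal before/after sketch of that pattern, condensed from the diffs below (the log message is illustrative, not taken from the diff):

# Before: one logger and handler per module, which per the commit title
# could produce duplicate log lines.
#     from distributed.logging_utils import setup_logging
#     logger = setup_logging(__name__)

# After: every module shares the one global logger.
from distributed.logging_utils import SingletonLogger

logger = SingletonLogger.get_logger()
logger.info("stage initialized")  # illustrative message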

dist_run.py

Lines changed: 3 additions & 2 deletions

@@ -15,7 +15,8 @@
 import torch.distributed as dist
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-from distributed.logging_utils import setup_logging
+
+from distributed.logging_utils import SingletonLogger
 
 # TODO - these are not distributed specific, consider moving to new package
 from distributed.safetensor_utils import (
@@ -41,7 +42,7 @@
 SentencePieceProcessor = None
 
 
-logger = setup_logging(__name__)
+logger = SingletonLogger.get_logger()
 
 MODEL_NAME = "Transformer-2-7b-chat-hf"
 NAME_TO_HF_MODEL_ID_AND_DTYPE = {

distributed/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from distributed.checkpoint import load_checkpoints_to_model
-from distributed.logging_utils import setup_logging
+from distributed.logging_utils import SingletonLogger
 from distributed.parallel_config import ParallelDims
 from distributed.parallelize_llama import parallelize_llama
 from distributed.utils import init_distributed

distributed/config_manager.py

Lines changed: 2 additions & 2 deletions

@@ -12,9 +12,9 @@
 
 import torch
 
-from distributed.logging_utils import setup_logging
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
 
-logger = setup_logging(__name__)
 
 try:
     import tomllib

distributed/dtensor_utils.py

Lines changed: 4 additions & 2 deletions

@@ -1,10 +1,12 @@
 import torch
 from torch.distributed._tensor import DTensor, Shard, Replicate
 
-from distributed.logging_utils import setup_logging
+
 from collections import defaultdict
 
-logger = setup_logging(__name__)
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
+
 
 
 def is_dtensor(tensor):

distributed/logging_utils.py

Lines changed: 93 additions & 15 deletions

@@ -7,25 +7,103 @@
 import logging
 import os
 from datetime import datetime
+from typing import Optional
 
-def millisecond_timestamp(*args):
-    return datetime.now().strftime('%m-%d %H:%M:%S.%f')[:-3]
 
-def setup_logging(name=None, log_level=logging.INFO):
-    logger = logging.getLogger(name)
-    logger.setLevel(log_level)
+def millisecond_timestamp(include_year: bool = False) -> str:
+    format_string = "%Y-%m-%d %H:%M:%S.%f" if include_year else "%m-%d %H:%M:%S.%f"
+    return datetime.now().strftime(format_string)[:-3]
 
-    if not logger.handlers:
-        console_handler = logging.StreamHandler()
-        console_handler.setLevel(log_level)
 
-        formatter = logging.Formatter('%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s')
-        formatter.formatTime = millisecond_timestamp
+class CompactFormatter(logging.Formatter):
+    def __init__(
+        self,
+        fmt: Optional[str] = None,
+        datefmt: Optional[str] = None,
+        style: str = "%",
+        validate: bool = True,
+        *,
+        defaults: Optional[dict] = None,
+        show_lower_levels: bool = True,
+    ):
+        super().__init__(fmt, datefmt, style, validate, defaults=defaults)
+        self.show_lower_levels = show_lower_levels
+        self.original_fmt = fmt
 
-        console_handler.setFormatter(formatter)
-        logger.addHandler(console_handler)
+    def format(self, record: logging.LogRecord) -> str:
+        # Remove .py extension from filename
+        record.filename = os.path.splitext(record.filename)[0]
 
-    # suppress verbose torch.profiler logging
-    os.environ["KINETO_LOG_LEVEL"] = "5"
+        if self.show_lower_levels or record.levelno > logging.INFO:
+            return super().format(record)
+        else:
+            # Create a copy of the record and modify it
+            new_record = logging.makeLogRecord(record.__dict__)
+            new_record.levelname = ""
+            # Temporarily change the format string
+            temp_fmt = self.original_fmt.replace(" - %(levelname)s", "")
+            self._style._fmt = temp_fmt
+            formatted_message = super().format(new_record)
+            # Restore the original format string
+            self._style._fmt = self.original_fmt
+            return formatted_message
 
-    return logger
+
+class SingletonLogger:
+    """Singleton (global) logger to avoid logging duplication"""
+
+    _instance = None
+
+    @classmethod
+    def get_logger(
+        cls,
+        name: str = "global_logger",
+        level: int = logging.INFO,
+        include_year: bool = False,
+        show_lower_levels: bool = False,
+    ) -> logging.Logger:
+        """
+        Get or create a singleton logger instance.
+
+        :param name: Name of the logger
+        :param level: Logging level
+        :param include_year: Whether to include the year in timestamps
+        :param show_lower_levels: Whether to show level names for INFO and DEBUG messages
+        :return: Logger instance
+        """
+        if cls._instance is None:
+            cls._instance = cls._setup_logger(
+                name, level, include_year, show_lower_levels
+            )
+        return cls._instance
+
+    @staticmethod
+    def _setup_logger(
+        name: str,
+        level: int,
+        include_year: bool = False,
+        show_lower_levels: bool = False,
+    ) -> logging.Logger:
+        logger = logging.getLogger(name)
+
+        if not logger.handlers:
+            logger.setLevel(level)
+
+            console_handler = logging.StreamHandler()
+            console_handler.setLevel(level)
+
+            formatter = CompactFormatter(
+                "%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s",
+                show_lower_levels=show_lower_levels,
+            )
+            formatter.formatTime = lambda record, datefmt=None: millisecond_timestamp(
+                include_year
+            )
+            console_handler.setFormatter(formatter)
+            logger.addHandler(console_handler)
+
+            # Suppress verbose torch.profiler logging
+            os.environ["KINETO_LOG_LEVEL"] = "5"
+
+        logger.propagate = False
+        return logger
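
A rough usage sketch of the new API, assuming the module layout above; the formatted output lines are illustrative examples of the format string, not captured from a real run:

import logging

from distributed.logging_utils import SingletonLogger

# The first call creates and configures the shared logger; later calls
# return the cached instance, so extra handlers are never attached.
logger = SingletonLogger.get_logger(level=logging.INFO, show_lower_levels=False)

logger.info("loading checkpoint")
# With show_lower_levels=False, CompactFormatter drops the level name for
# INFO/DEBUG records and strips the .py suffix from the filename, e.g.:
#   08-21 14:03:27.512 - dist_run:123 - loading checkpoint

logger.warning("tokenizer mismatch")
# WARNING and above keep the level name:
#   08-21 14:03:27.514 - dist_run:125 - WARNING - tokenizer mismatch

Note that name, level, include_year, and show_lower_levels only take effect on the first get_logger() call; subsequent calls return the cached instance unchanged.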

distributed/parallel_config.py

Lines changed: 2 additions & 2 deletions

@@ -8,8 +8,8 @@
 
 from torch.distributed.device_mesh import init_device_mesh
 
-from distributed.logging_utils import setup_logging
-logger = setup_logging(__name__)
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
 
 @dataclass
 class ParallelDims:

distributed/parallelize_llama.py

Lines changed: 3 additions & 2 deletions

@@ -10,10 +10,11 @@
     RowwiseParallel,
     parallelize_module)
 
-from distributed.logging_utils import setup_logging
+
 from distributed.parallel_config import ParallelDims
 
-logger = setup_logging(__name__)
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
 
 
 def apply_tp(

distributed/safetensor_utils.py

Lines changed: 5 additions & 2 deletions

@@ -13,14 +13,17 @@
 from torch.nn import Module
 from typing import Dict, Tuple, Set, Optional
 
-from distributed.logging_utils import setup_logging
+
 from distributed.dtensor_utils import is_dtensor, load_into_dtensor
 
 
 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
 _CONFIG_NAME = "config.json"
 
-logger = setup_logging(__name__)
+
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
+
 
 
 def compare_and_reverse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:

distributed/utils.py

Lines changed: 2 additions & 2 deletions

@@ -13,9 +13,9 @@
 
 import torch
 
-from distributed.logging_utils import setup_logging
 
-logger = setup_logging(__name__)
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
 
 
 def _warn_overwrite_env(env, val):

distributed/verification_utils.py

Lines changed: 3 additions & 2 deletions

@@ -6,9 +6,10 @@
 import numpy as np
 from distributed.dtensor_utils import is_dtensor
 from typing import Dict, List, Tuple
-from distributed.logging_utils import setup_logging
 
-logger = setup_logging(__name__)
+from distributed.logging_utils import SingletonLogger
+logger = SingletonLogger.get_logger()
+
 
 
 def record_module_dtypes(module):
