5 changes: 3 additions & 2 deletions dist_run.py
@@ -15,7 +15,8 @@
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe

from distributed.logging_utils import setup_logging

from distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
from distributed.safetensor_utils import (
@@ -41,7 +42,7 @@
SentencePieceProcessor = None


logger = setup_logging(__name__)
logger = SingletonLogger.get_logger()

MODEL_NAME = "Transformer-2-7b-chat-hf"
NAME_TO_HF_MODEL_ID_AND_DTYPE = {
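The same two-line substitution repeats in every file below; stripped of surrounding context, the removed and added lines reduce to this pattern (both halves are taken directly from the diff hunks):

# Before (removed in this PR)
from distributed.logging_utils import setup_logging
logger = setup_logging(__name__)

# After (added in this PR)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()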
2 changes: 1 addition & 1 deletion distributed/__init__.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from distributed.checkpoint import load_checkpoints_to_model
from distributed.logging_utils import setup_logging
from distributed.logging_utils import SingletonLogger
from distributed.parallel_config import ParallelDims
from distributed.parallelize_llama import parallelize_llama
from distributed.utils import init_distributed
4 changes: 2 additions & 2 deletions distributed/config_manager.py
@@ -12,9 +12,9 @@

import torch

from distributed.logging_utils import setup_logging
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()

logger = setup_logging(__name__)

try:
import tomllib
6 changes: 4 additions & 2 deletions distributed/dtensor_utils.py
@@ -1,10 +1,12 @@
import torch
from torch.distributed._tensor import DTensor, Shard, Replicate

from distributed.logging_utils import setup_logging

from collections import defaultdict

logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()



def is_dtensor(tensor):
108 changes: 93 additions & 15 deletions distributed/logging_utils.py
@@ -7,25 +7,103 @@
import logging
import os
from datetime import datetime
from typing import Optional

def millisecond_timestamp(*args):
return datetime.now().strftime('%m-%d %H:%M:%S.%f')[:-3]

def setup_logging(name=None, log_level=logging.INFO):
logger = logging.getLogger(name)
logger.setLevel(log_level)
def millisecond_timestamp(include_year: bool = False) -> str:
format_string = "%Y-%m-%d %H:%M:%S.%f" if include_year else "%m-%d %H:%M:%S.%f"
return datetime.now().strftime(format_string)[:-3]

if not logger.handlers:
console_handler = logging.StreamHandler()
console_handler.setLevel(log_level)

formatter = logging.Formatter('%(asctime)s - %(name)s:%(lineno)d - %(levelname)s - %(message)s')
formatter.formatTime = millisecond_timestamp
class CompactFormatter(logging.Formatter):
def __init__(
self,
fmt: Optional[str] = None,
datefmt: Optional[str] = None,
style: str = "%",
validate: bool = True,
*,
defaults: Optional[dict] = None,
show_lower_levels: bool = True,
):
super().__init__(fmt, datefmt, style, validate, defaults=defaults)
self.show_lower_levels = show_lower_levels
self.original_fmt = fmt

console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
def format(self, record: logging.LogRecord) -> str:
# Remove .py extension from filename
record.filename = os.path.splitext(record.filename)[0]

# suppress verbose torch.profiler logging
os.environ["KINETO_LOG_LEVEL"] = "5"
if self.show_lower_levels or record.levelno > logging.INFO:
return super().format(record)
else:
# Create a copy of the record and modify it
new_record = logging.makeLogRecord(record.__dict__)
new_record.levelname = ""
# Temporarily change the format string
temp_fmt = self.original_fmt.replace(" - %(levelname)s", "")
self._style._fmt = temp_fmt
formatted_message = super().format(new_record)
# Restore the original format string
self._style._fmt = self.original_fmt
return formatted_message

return logger

class SingletonLogger:
"""Singleton (global) logger to avoid logging duplication"""

_instance = None

@classmethod
def get_logger(
cls,
name: str = "global_logger",
level: int = logging.INFO,
include_year: bool = False,
show_lower_levels: bool = False,
) -> logging.Logger:
"""
Get or create a singleton logger instance.

:param name: Name of the logger
:param level: Logging level
:param include_year: Whether to include the year in timestamps
:param show_lower_levels: Whether to show level names for INFO and DEBUG messages
:return: Logger instance
"""
if cls._instance is None:
cls._instance = cls._setup_logger(
name, level, include_year, show_lower_levels
)
return cls._instance

@staticmethod
def _setup_logger(
name: str,
level: int,
include_year: bool = False,
show_lower_levels: bool = False,
) -> logging.Logger:
logger = logging.getLogger(name)

if not logger.handlers:
logger.setLevel(level)

console_handler = logging.StreamHandler()
console_handler.setLevel(level)

formatter = CompactFormatter(
"%(asctime)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s",
show_lower_levels=show_lower_levels,
)
formatter.formatTime = lambda record, datefmt=None: millisecond_timestamp(
include_year
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Suppress verbose torch.profiler logging
os.environ["KINETO_LOG_LEVEL"] = "5"

logger.propagate = False
return logger
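A minimal usage sketch of the new logger follows. It is illustrative only and assumes the distributed.logging_utils module from this PR is on the import path; the log messages and assertions are invented for the example. The point of the singleton is that every module receives the same logger object, so the console handler is attached exactly once and log lines are not duplicated; with show_lower_levels left at its default of False, INFO and DEBUG records are printed without a level label while WARNING and above keep theirs.

import logging

from distributed.logging_utils import SingletonLogger

log_a = SingletonLogger.get_logger()
log_b = SingletonLogger.get_logger(level=logging.DEBUG)  # parameters are ignored after the first call

assert log_a is log_b              # every caller receives the same logger object
assert len(log_a.handlers) == 1    # the console handler is attached only once

log_a.info("loading safetensor shards")      # printed without an "INFO" label
log_a.warning("config.json missing dtype")   # WARNING and above keep their level name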
4 changes: 2 additions & 2 deletions distributed/parallel_config.py
@@ -8,8 +8,8 @@

from torch.distributed.device_mesh import init_device_mesh

from distributed.logging_utils import setup_logging
logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()

@dataclass
class ParallelDims:
5 changes: 3 additions & 2 deletions distributed/parallelize_llama.py
@@ -10,10 +10,11 @@
RowwiseParallel,
parallelize_module)

from distributed.logging_utils import setup_logging

from distributed.parallel_config import ParallelDims

logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()


def apply_tp(
7 changes: 5 additions & 2 deletions distributed/safetensor_utils.py
@@ -13,14 +13,17 @@
from torch.nn import Module
from typing import Dict, Tuple, Set, Optional

from distributed.logging_utils import setup_logging

from distributed.dtensor_utils import is_dtensor, load_into_dtensor


_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
_CONFIG_NAME = "config.json"

logger = setup_logging(__name__)

from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()



def compare_and_reverse(tensor1: torch.Tensor, tensor2: torch.Tensor) -> torch.Tensor:
4 changes: 2 additions & 2 deletions distributed/utils.py
@@ -13,9 +13,9 @@

import torch

from distributed.logging_utils import setup_logging

logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()


def _warn_overwrite_env(env, val):
5 changes: 3 additions & 2 deletions distributed/verification_utils.py
@@ -6,9 +6,10 @@
import numpy as np
from distributed.dtensor_utils import is_dtensor
from typing import Dict, List, Tuple
from distributed.logging_utils import setup_logging

logger = setup_logging(__name__)
from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()



def record_module_dtypes(module):
7 changes: 5 additions & 2 deletions distributed/world_maker.py
@@ -9,13 +9,16 @@

from torch.distributed.device_mesh import DeviceMesh

from distributed.logging_utils import setup_logging

from distributed.parallel_config import ParallelDims
from distributed.utils import init_distributed

from .config_manager import InferenceConfig

logger = setup_logging(__name__)

from distributed.logging_utils import SingletonLogger
logger = SingletonLogger.get_logger()


def launch_distributed(
toml_config: str,