Skip to content

Commit 8671c91

Browse files
authored
Add color to console output if local logging, auto avoid color logging on slurm (#93)
This PR adds the ability to do colored console outputs in order to highlight the training data outputs. It also adds a check to not use this color formatting on slurm, where it would otherwise print raw escape codes such as `33m` instead of the color if not avoided. Note that I've just added some color to highlight the main training data. Users that fork/clone can use it to enhance their outputs as desired. <img width="1372" alt="Screenshot 2024-02-26 at 10 20 15 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/44849821-1677-40bf-896c-39344cd661d6"> Note that on slurm it remains plain: <img width="847" alt="Screenshot 2024-02-26 at 10 46 24 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/172eaa58-4f5c-48f5-8ec1-bc349e3e82f2"> if you don't check this, then it would otherwise look like this (this does not happen with this PR, just showing if we didn't check, and credit to Yifu for noting this would be an issue): <img width="847" alt="Screenshot 2024-02-26 at 10 39 23 PM" src="https://github.com/pytorch/torchtrain/assets/46302957/4a87fb9a-dd3a-417c-a29e-286ded069358">
1 parent 5dec536 commit 8671c91

File tree

2 files changed

+66
-8
lines changed

2 files changed

+66
-8
lines changed

torchtrain/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

4+
from dataclasses import dataclass
45
from typing import Union
56

67
import torch
@@ -17,3 +18,37 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
1718
def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
1819
tensor = torch.tensor(x).cuda()
1920
return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)
21+
22+
23+
@dataclass
24+
class Color:
25+
black = "\033[30m"
26+
red = "\033[31m"
27+
green = "\033[32m"
28+
yellow = "\033[33m"
29+
blue = "\033[34m"
30+
magenta = "\033[35m"
31+
cyan = "\033[36m"
32+
white = "\033[37m"
33+
reset = "\033[39m"
34+
35+
36+
@dataclass
37+
class Background:
38+
black = "\033[40m"
39+
red = "\033[41m"
40+
green = "\033[42m"
41+
yellow = "\033[43m"
42+
blue = "\033[44m"
43+
magenta = "\033[45m"
44+
cyan = "\033[46m"
45+
white = "\033[47m"
46+
reset = "\033[49m"
47+
48+
49+
@dataclass
50+
class Style:
51+
bright = "\033[1m"
52+
dim = "\033[2m"
53+
normal = "\033[22m"
54+
reset = "\033[0m"

train.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
33

44
import os
5+
56
from dataclasses import dataclass, field
67
from timeit import default_timer as timer
78
from typing import Any, Dict, List
@@ -27,7 +28,11 @@
2728
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
2829

2930
from torchtrain.profiling import maybe_run_profiler
30-
from torchtrain.utils import dist_max, dist_mean
31+
from torchtrain.utils import Color, dist_max, dist_mean
32+
33+
_is_local_logging = True
34+
if "SLURM_JOB_ID" in os.environ:
35+
_is_local_logging = False
3136

3237

3338
@dataclass
@@ -119,9 +124,16 @@ def main(job_config: JobConfig):
119124

120125
# log model size
121126
model_param_count = get_num_params(model)
122-
rank0_log(
123-
f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
124-
)
127+
if _is_local_logging:
128+
rank0_log(
129+
f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}"
130+
f" total parameters{Color.reset}"
131+
)
132+
else:
133+
rank0_log(
134+
f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters"
135+
)
136+
125137
gpu_metrics = GPUMemoryMonitor("cuda")
126138
rank0_log(f"GPU memory usage: {gpu_metrics}")
127139

@@ -268,10 +280,21 @@ def main(job_config: JobConfig):
268280
nwords_since_last_log = 0
269281
time_last_log = timer()
270282

271-
rank0_log(
272-
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
273-
f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
274-
)
283+
if _is_local_logging:
284+
rank0_log(
285+
f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}"
286+
f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}"
287+
f" data: {Color.blue}{data_load_time:>5} {Color.reset}"
288+
f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}"
289+
)
290+
else:
291+
rank0_log(
292+
f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}"
293+
f" iter: {curr_iter_time:>7}"
294+
f" data: {data_load_time:>5} "
295+
f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}"
296+
)
297+
275298
scheduler.step()
276299

277300
checkpoint.save(

0 commit comments

Comments
 (0)