|
2 | 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. |
3 | 3 |
|
4 | 4 | import os |
| 5 | + |
5 | 6 | from dataclasses import dataclass, field |
6 | 7 | from timeit import default_timer as timer |
7 | 8 | from typing import Any, Dict, List |
|
27 | 28 | from torchtrain.parallelisms import models_parallelize_fns, ParallelDims |
28 | 29 |
|
29 | 30 | from torchtrain.profiling import maybe_run_profiler |
30 | | -from torchtrain.utils import dist_max, dist_mean |
| 31 | +from torchtrain.utils import Color, dist_max, dist_mean |
| 32 | + |
| 33 | +_is_local_logging = True |
| 34 | +if "SLURM_JOB_ID" in os.environ: |
| 35 | + _is_local_logging = False |
31 | 36 |
|
32 | 37 |
|
33 | 38 | @dataclass |
@@ -119,9 +124,16 @@ def main(job_config: JobConfig): |
119 | 124 |
|
120 | 125 | # log model size |
121 | 126 | model_param_count = get_num_params(model) |
122 | | - rank0_log( |
123 | | - f"Model {model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" |
124 | | - ) |
| 127 | + if _is_local_logging: |
| 128 | + rank0_log( |
| 129 | + f"{Color.blue}Model {model_name} {job_config.model.flavor} {Color.red}size: {model_param_count:,}" |
| 130 | + f" total parameters{Color.reset}" |
| 131 | + ) |
| 132 | + else: |
| 133 | + rank0_log( |
| 134 | + f"{model_name} {job_config.model.flavor} size: {model_param_count:,} total parameters" |
| 135 | + ) |
| 136 | + |
125 | 137 | gpu_metrics = GPUMemoryMonitor("cuda") |
126 | 138 | rank0_log(f"GPU memory usage: {gpu_metrics}") |
127 | 139 |
|
@@ -268,10 +280,21 @@ def main(job_config: JobConfig): |
268 | 280 | nwords_since_last_log = 0 |
269 | 281 | time_last_log = timer() |
270 | 282 |
|
271 | | - rank0_log( |
272 | | - f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" |
273 | | - f" iter: {curr_iter_time:>7} data: {data_load_time:>5} lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" |
274 | | - ) |
| 283 | + if _is_local_logging: |
| 284 | + rank0_log( |
| 285 | + f"{Color.cyan}step: {train_state.step:>2} {Color.green}loss: {round(train_state.current_loss,4):>7}" |
| 286 | + f" {Color.reset}iter: {Color.blue}{curr_iter_time:>7}{Color.reset}" |
| 287 | + f" data: {Color.blue}{data_load_time:>5} {Color.reset}" |
| 288 | + f"lr: {Color.yellow}{round(float(scheduler.get_last_lr()[0]), 8):<6}{Color.reset}" |
| 289 | + ) |
| 290 | + else: |
| 291 | + rank0_log( |
| 292 | + f"step: {train_state.step:>2} loss: {round(train_state.current_loss,4):>7}" |
| 293 | + f" iter: {curr_iter_time:>7}" |
| 294 | + f" data: {data_load_time:>5} " |
| 295 | + f"lr: {round(float(scheduler.get_last_lr()[0]), 8):<6}" |
| 296 | + ) |
| 297 | + |
275 | 298 | scheduler.step() |
276 | 299 |
|
277 | 300 | checkpoint.save( |
|
0 commit comments