From ad9c215f01c8c0ac2313f05dd1281c804ebff4b6 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Mon, 12 Feb 2024 18:45:41 -0800
Subject: [PATCH 1/2] add TensorBoard logging with loss and wps

[ghstack-poisoned]
---
 requirements.txt                           |  1 +
 run_llama_train.sh                         |  4 +-
 torchtrain/metrics.py                      | 49 ++++++++++++++++++++++
 torchtrain/train_configs/train_config.toml |  4 ++
 torchtrain/utils.py                        | 19 +++++++++
 train.py                                   | 46 +++++++++++++++++++-
 6 files changed, 120 insertions(+), 3 deletions(-)
 create mode 100644 torchtrain/metrics.py
 create mode 100644 torchtrain/utils.py

diff --git a/requirements.txt b/requirements.txt
index 9bc33ca39b..8e089a3e21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ torch >= 2.2.0.dev
 sentencepiece
 datasets
 tomli >= 1.1.0 ; python_version < "3.11"
+tensorboard

diff --git a/run_llama_train.sh b/run_llama_train.sh
index 2749b01db2..ffecaeecac 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -24,6 +24,6 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
 torchrun --nproc_per_node=${NGPU} \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 10 --compile \
---pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP}
+train.py --steps 41 --compile \
+--pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}

diff --git a/torchtrain/metrics.py b/torchtrain/metrics.py
new file mode 100644
index 0000000000..1be7cdca99
--- /dev/null
+++ b/torchtrain/metrics.py
@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+import torch
+from torch.utils.tensorboard import SummaryWriter
+
+from torchtrain.logging_utils import rank0_log
+from torchtrain.profiling import get_config_from_toml
+
+
+class MetricLogger:
+    def __init__(self, log_dir, tag, enable_tb):
+        self.tag = tag
+        self.writer: Optional[SummaryWriter] = None
+        if enable_tb:
+            self.writer = SummaryWriter(log_dir, max_queue=1000)
+
+    def log(self, metrics: Dict[str, Any], step: int):
+        if self.writer is not None:
+            for k, v in metrics.items():
+                tag = k if self.tag is None else f"{self.tag}/{k}"
+                self.writer.add_scalar(tag, v, step)
+
+    def close(self):
+        if self.writer is not None:
+            self.writer.close()
+
+
+def build_metric_logger(tag: Optional[str] = None):
+    config = get_config_from_toml()
+
+    dump_dir = config["global"]["dump_folder"]
+    save_tb_folder = config["metrics"]["save_tb_folder"]
+    # since we don't have run id yet, use current minute as identifier
+    datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
+    log_dir = os.path.join(dump_dir, save_tb_folder, datetime_str)
+
+    enable_tb = config["metrics"].get("enable_tensorboard", False)
+    if enable_tb:
+        rank0_log(
+            f"Metrics logging active. Tensorboard logs will be saved at {log_dir}."
+        )
+
+    rank_str = f"rank_{torch.distributed.get_rank()}"
+    return MetricLogger(os.path.join(log_dir, rank_str), tag, enable_tb)

diff --git a/torchtrain/train_configs/train_config.toml b/torchtrain/train_configs/train_config.toml
index a3b02917eb..da0161e05f 100644
--- a/torchtrain/train_configs/train_config.toml
+++ b/torchtrain/train_configs/train_config.toml
@@ -7,3 +7,7 @@ run_profiler = true
 save_traces_folder = "profiling/traces"
 # profiling frequency - example: 10 means every 10th iter will be profiled
 profile_every_x_iter = 10
+
+[metrics]
+enable_tensorboard = true
+save_tb_folder = "tb"

diff --git a/torchtrain/utils.py b/torchtrain/utils.py
new file mode 100644
index 0000000000..9ae71caefd
--- /dev/null
+++ b/torchtrain/utils.py
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from typing import Union
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+from torch.distributed.device_mesh import DeviceMesh
+
+
+def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh)
+
+
+def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float:
+    tensor = torch.tensor(x).cuda()
+    return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh)

diff --git a/train.py b/train.py
index 4048101d1e..0bde88f67e 100644
--- a/train.py
+++ b/train.py
@@ -4,8 +4,11 @@
 import argparse
 import os
 from dataclasses import dataclass, field
+from timeit import default_timer as timer
 from typing import Any, Dict, List, Union
 
+import numpy as np
+
 # torch imports
 import torch
 import torch.nn.functional as F
@@ -18,11 +21,13 @@
 from torchtrain.datasets import create_tokenizer, dataloader_fn
 from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.metrics import build_metric_logger
 from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
 from torchtrain.profiling import maybe_run_profiler
+from torchtrain.utils import dist_max, dist_mean
 
 
 @dataclass
@@ -116,7 +121,7 @@ def main(args):
 
     scaler = build_grad_scaler(model)
 
-    # TODO: add metrics
+    metric_logger = build_metric_logger()
 
     # torch.compile model for improved performance
     if args.compile:
@@ -146,6 +151,10 @@ def main(args):
 
     with maybe_run_profiler() as torch_profiler:
         checkpoint.reset()
+        # variables used to keep info for metrics logging
+        losses_since_last_log: List[float] = []
+        nwords_since_last_log = 0
+        time_last_log = timer()
         while train_state.step < args.steps or args.steps == -1:
             train_state.step += 1
             # get batch
@@ -153,6 +162,7 @@ def main(args):
             input_ids, labels = batch
             input_ids = input_ids.cuda()
             labels = labels.cuda()
+            nwords_since_last_log += labels.numel()
 
             optimizer.zero_grad()
 
@@ -184,6 +194,32 @@ def main(args):
 
             train_state.current_loss = loss.item()
             train_state.losses.append(train_state.current_loss)
+            losses_since_last_log.append(train_state.current_loss)
+
+            # log metrics
+            if (train_state.step - 1) % args.log_freq == 0:
+                avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
+                    losses_since_last_log
+                )
+                global_avg_loss, global_max_loss = dist_mean(
+                    avg_loss, world_mesh
+                ), dist_max(max_loss, world_mesh)
+
+                time_delta = timer() - time_last_log
+                wps = nwords_since_last_log / (
+                    time_delta * parallel_dims.sp * parallel_dims.pp
+                )
+
+                metrics = {
+                    "global_avg_loss": global_avg_loss,
+                    "global_max_loss": global_max_loss,
+                    "wps": wps,
+                }
+                metric_logger.log(metrics, step=train_state.step)
+
+                losses_since_last_log.clear()
+                nwords_since_last_log = 0
+                time_last_log = timer()
 
             rank0_log(
                 f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
@@ -192,6 +228,8 @@ def main(args):
 
         checkpoint.save(train_state.step, force=(train_state.step == args.steps))
 
+    metric_logger.close()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -282,6 +320,12 @@ def main(args):
             "is an empty string, checkpointing is disabled."
         ),
     )
+    parser.add_argument(
+        "--log_freq",
+        type=int,
+        default=10,
+        help="how often to log metrics to TensorBoard",
+    )
 
     args = parser.parse_args()
     main(args)

From e1abc87a10174e2819a1d693ed5146c769d33fd9 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Mon, 12 Feb 2024 18:48:50 -0800
Subject: [PATCH 2/2] Update on "add TensorBoard logging with loss and wps"

[ghstack-poisoned]
---
 run_llama_train.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/run_llama_train.sh b/run_llama_train.sh
index ffecaeecac..06db09f9b0 100755
--- a/run_llama_train.sh
+++ b/run_llama_train.sh
@@ -24,6 +24,6 @@ CHECKPOINT_INTERVAL=${CHECKPOINT_INTERVAL:-5}
 
 torchrun --nproc_per_node=${NGPU} \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-train.py --steps 41 --compile \
+train.py --steps 10 --compile \
 --pp_degree ${PP} --sp_degree ${SP} --dp_degree ${DP} \
 --checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}
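
Not part of the patch: a minimal sketch of how the MetricLogger added in torchtrain/metrics.py could be exercised on its own, assuming torchtrain is importable and the tensorboard package from requirements.txt is installed. The log directory, tag, and metric values below are made-up placeholders, not values used by the patch.

    # standalone sketch: drive MetricLogger directly, outside the training loop
    from torchtrain.metrics import MetricLogger

    logger = MetricLogger(log_dir="./outputs/tb/example/rank_0", tag=None, enable_tb=True)
    for step in range(1, 11):
        # fake metric values just to show the shape of the logging call
        logger.log({"global_avg_loss": 10.0 / step, "wps": 1000.0 * step}, step=step)
    logger.close()
    # inspect the scalars with: tensorboard --logdir ./outputs/tb/example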
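
Also not part of the patch: a small worked example of the wps expression used in the training loop above, which divides the local token count by the elapsed time scaled by the sp and pp degrees. All numbers here are made up.

    # wps calculation as written in train.py, with illustrative values
    nwords_since_last_log = 81_920   # tokens counted via labels.numel() since the last log
    time_delta = 4.0                 # seconds since the last log
    sp, pp = 2, 2                    # sequence-parallel and pipeline-parallel degrees
    wps = nwords_since_last_log / (time_delta * sp * pp)
    print(wps)                       # 5120.0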