|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import contextlib |
| 8 | +import gc |
| 9 | +import os |
| 10 | + |
| 11 | +import torch |
| 12 | +import torch.nn.functional as F |
| 13 | +from torch._guards import active_fake_mode |
| 14 | +from torch._subclasses.fake_tensor import FakeTensorMode |
| 15 | +from torch.distributed import destroy_process_group |
| 16 | +from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker |
| 17 | +from torch.distributed.tensor.parallel import loss_parallel |
| 18 | +from torch.testing._internal.distributed.fake_pg import FakeStore |
| 19 | + |
| 20 | +from torchtitan.config_manager import JobConfig |
| 21 | +from torchtitan.datasets import create_tokenizer |
| 22 | +from torchtitan.float8_linear import build_fp8_linear |
| 23 | +from torchtitan.logging_utils import init_logger, logger |
| 24 | +from torchtitan.lr_scheduling import get_lr_schedulers |
| 25 | +from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config |
| 26 | +from torchtitan.parallelisms import models_parallelize_fns, ParallelDims |
| 27 | +from train import build_optimizers |
| 28 | + |
| 29 | + |
| 30 | +def estimate_memory(job_config: JobConfig): |
| 31 | + init_logger() |
| 32 | + logger.info("Estimating memory usage...") |
| 33 | + gc.disable() |
| 34 | + gc.collect(1) |
| 35 | + |
| 36 | + # Get the world size |
| 37 | + world_size = int(os.environ["WORLD_SIZE"]) |
| 38 | + |
| 39 | + # if tp > or pp > 1, we exit |
| 40 | + if ( |
| 41 | + job_config.training.tensor_parallel_degree > 1 |
| 42 | + or job_config.experimental.pipeline_parallel_degree > 1 |
| 43 | + ): |
| 44 | + logger.info( |
| 45 | + "Tensor parallelism and pipeline parallelism are not supported yet." |
| 46 | + ) |
| 47 | + return |
| 48 | + |
| 49 | + # fake tensor doesn't work with fused rmsnorm |
| 50 | + if ( |
| 51 | + job_config.model.norm_type == "fused_rmsnorm" |
| 52 | + and job_config.estimate.mode == "fake" |
| 53 | + ): |
| 54 | + logger.info( |
| 55 | + "Fused RMSNorm is not supported yet under fake estimation mode. " |
| 56 | + "Switching to rmsnorm." |
| 57 | + ) |
| 58 | + job_config.model.norm_type = "rmsnorm" |
| 59 | + |
| 60 | + parallel_dims = ParallelDims( |
| 61 | + dp=job_config.training.data_parallel_degree, |
| 62 | + tp=job_config.training.tensor_parallel_degree, |
| 63 | + pp=job_config.experimental.pipeline_parallel_degree, |
| 64 | + world_size=world_size, |
| 65 | + enable_loss_parallel=job_config.training.enable_loss_parallel, |
| 66 | + ) |
| 67 | + |
| 68 | + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") |
| 69 | + torch.cuda.set_device(device) |
| 70 | + |
| 71 | + # init fake pg |
| 72 | + store = FakeStore() |
| 73 | + torch.distributed.init_process_group( |
| 74 | + "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store |
| 75 | + ) |
| 76 | + |
| 77 | + # build meshes |
| 78 | + world_mesh = parallel_dims.build_mesh(device_type="cuda") |
| 79 | + |
| 80 | + if not parallel_dims.dp_enabled: |
| 81 | + logger.info("Data parallelism is not enabled. Skipping memory estimation.") |
| 82 | + return |
| 83 | + |
| 84 | + model_name = job_config.model.name |
| 85 | + |
| 86 | + # build tokenizer |
| 87 | + tokenizer_type = model_name_to_tokenizer[model_name] |
| 88 | + tokenizer = create_tokenizer(tokenizer_type, job_config.model.tokenizer_path) |
| 89 | + |
| 90 | + # loss_parallel enables dispatching to efficient loss operators |
| 91 | + loss_parallel_ctx = ( |
| 92 | + loss_parallel if parallel_dims.loss_parallel_enabled else contextlib.nullcontext |
| 93 | + ) |
| 94 | + |
| 95 | + # loss fn can be shared by pipeline-parallel or non-pp execution |
| 96 | + def loss_fn(pred, labels): |
| 97 | + return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1)) |
| 98 | + |
| 99 | + # build model (using meta init) |
| 100 | + model_cls = model_name_to_cls[model_name] |
| 101 | + model_config = models_config[model_name][job_config.model.flavor] |
| 102 | + # set the model configs from training inputs: |
| 103 | + # 1. norm type to decide which norm layer to use |
| 104 | + # 2. vocab size from tokenizer |
| 105 | + # 3. max_seq_len base on inputs |
| 106 | + model_config.norm_type = job_config.model.norm_type |
| 107 | + model_config.vocab_size = tokenizer.n_words |
| 108 | + model_config.max_seq_len = job_config.training.seq_len |
| 109 | + |
| 110 | + with FakeTensorMode() if job_config.estimate.mode == "fake" else contextlib.nullcontext(): |
| 111 | + |
| 112 | + logger.info( |
| 113 | + f"Building {model_name} {job_config.model.flavor} with {model_config}" |
| 114 | + ) |
| 115 | + with torch.device("meta"): |
| 116 | + whole_model = model_cls.from_model_args(model_config) |
| 117 | + |
| 118 | + # apply fp8 linear module swap |
| 119 | + if job_config.training.fp8_linear: |
| 120 | + build_fp8_linear(whole_model, job_config) |
| 121 | + |
| 122 | + # apply PT-D DP/TP parallelisms and activation checkpointing |
| 123 | + model_parts = [whole_model] |
| 124 | + model_parts = [ |
| 125 | + models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config) |
| 126 | + for m in model_parts |
| 127 | + ] |
| 128 | + |
| 129 | + init_device = "cuda" |
| 130 | + for model in model_parts: |
| 131 | + model.to_empty(device=init_device) |
| 132 | + |
| 133 | + if not active_fake_mode(): |
| 134 | + whole_model.init_weights() |
| 135 | + |
| 136 | + # build optimizer after applying parallelisms to the model |
| 137 | + optimizers = build_optimizers(model_parts, job_config) |
| 138 | + lr_schedulers = get_lr_schedulers(optimizers.optimizers, job_config) |
| 139 | + |
| 140 | + for model in model_parts: |
| 141 | + model.train() |
| 142 | + logger.info(f"Vocab size: {model_config.vocab_size}") |
| 143 | + # Create a dummy batch instead of loading from a dataset |
| 144 | + batch = ( |
| 145 | + torch.randint( |
| 146 | + 0, |
| 147 | + model_config.vocab_size, |
| 148 | + (job_config.training.batch_size, model_config.max_seq_len), |
| 149 | + device="cuda", |
| 150 | + ), |
| 151 | + torch.randint( |
| 152 | + 0, |
| 153 | + model_config.vocab_size, |
| 154 | + (job_config.training.batch_size, model_config.max_seq_len), |
| 155 | + device="cuda", |
| 156 | + ), |
| 157 | + ) |
| 158 | + fsdp_memtracker = FSDPMemTracker(mod=whole_model, optm=optimizers.optimizers[0]) |
| 159 | + fsdp_memtracker.track_inputs(batch) |
| 160 | + |
| 161 | + with fsdp_memtracker: |
| 162 | + for iter_idx in range(2): |
| 163 | + input_ids, labels = batch |
| 164 | + # train step |
| 165 | + with loss_parallel_ctx(): |
| 166 | + pred = whole_model(input_ids) |
| 167 | + loss = loss_fn(pred, labels) |
| 168 | + del pred |
| 169 | + loss.backward() |
| 170 | + |
| 171 | + # clip gradients |
| 172 | + for model in model_parts: |
| 173 | + torch.nn.utils.clip_grad_norm_( |
| 174 | + model.parameters(), job_config.training.max_norm, foreach=True |
| 175 | + ) |
| 176 | + # optimizer step |
| 177 | + optimizers.step() |
| 178 | + lr_schedulers.step() |
| 179 | + optimizers.zero_grad() |
| 180 | + print(f"Peak Memory at iter: {iter_idx}") |
| 181 | + fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) |
| 182 | + if iter_idx == 0: |
| 183 | + fsdp_memtracker.reset_mod_stats() # iter 0 does not have optimizer state |
| 184 | + gc.collect(1) |
| 185 | + |
| 186 | + fsdp_memtracker.display_modulewise_snapshots( |
| 187 | + depth=3, units="MiB", tabulate=True |
| 188 | + ) |
| 189 | + mem_stats = torch.cuda.memory_stats() |
| 190 | + peak_active = mem_stats["active_bytes.all.peak"] |
| 191 | + peak_reserved = mem_stats["reserved_bytes.all.peak"] |
| 192 | + num_retries = mem_stats["num_alloc_retries"] |
| 193 | + dev = torch.device(torch.cuda.current_device()) |
| 194 | + tracker_peak = fsdp_memtracker.get_tracker_snapshot("peak")[dev]["Total"] |
| 195 | + gib = 1024**3 |
| 196 | + print( |
| 197 | + f"peak active: {peak_active / gib} GiB | peak reserved:" |
| 198 | + f" {peak_reserved / gib} GiB | num_retries: {num_retries}" |
| 199 | + ) |
| 200 | + print(f"Tracker Max: {tracker_peak / gib} GiB") |
| 201 | + if job_config.estimate.mode == "real": |
| 202 | + print(f"Tracker Accuracy: {tracker_peak/peak_active}") |
| 203 | + gc.enable() |
| 204 | + |
| 205 | + |
| 206 | +if __name__ == "__main__": |
| 207 | + config = JobConfig() |
| 208 | + config.parse_args() |
| 209 | + try: |
| 210 | + estimate_memory(config) |
| 211 | + finally: |
| 212 | + destroy_process_group() |
0 commit comments