From 669e4b0e768216810c286b7b8e9ef12928761c8e Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 30 Jul 2024 23:52:15 -0700 Subject: [PATCH] [BE][3/n] wrap fp8 logic using Float8Handler [ghstack-poisoned] --- estimation.py | 14 +- torchtitan/config_manager.py | 83 ++++----- torchtitan/float8_linear.py | 173 ++++++++++--------- torchtitan/parallelisms/parallelize_llama.py | 2 +- train.py | 17 +- train_configs/debug_model.toml | 4 +- train_configs/llama2_13b.toml | 4 +- train_configs/llama2_70b.toml | 4 +- train_configs/llama2_7b.toml | 4 +- train_configs/llama3_70b.toml | 4 +- train_configs/llama3_8b.toml | 4 +- 11 files changed, 163 insertions(+), 150 deletions(-) diff --git a/estimation.py b/estimation.py index 3adcf66374..acf867d54b 100644 --- a/estimation.py +++ b/estimation.py @@ -16,10 +16,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_tokenizer -from torchtitan.float8_linear import ( - maybe_build_fp8_linear, - maybe_precompute_fp8_dynamic_scale_for_fsdp, -) +from torchtitan.float8_linear import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config from torchtitan.optimizer import build_lr_schedulers, build_optimizers @@ -127,8 +124,10 @@ def loss_fn(pred, labels): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear base on fp8 config - maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + float8_handler.convert_to_float8_training(whole_model) # apply PT-D DP/TP parallelisms and activation checkpointing model_parts = [whole_model] @@ -184,13 +183,14 @@ def loss_fn(pred, labels): torch.nn.utils.clip_grad_norm_( model.parameters(), job_config.training.max_norm, foreach=True ) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step optimizers.step() lr_schedulers.step() - # when fp8 config is on, # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - maybe_precompute_fp8_dynamic_scale_for_fsdp(whole_model, job_config) + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) optimizers.zero_grad() print(f"Peak Memory at iter: {iter_idx}") fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index dd5ba7cde2..2bc37bfbf0 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -348,46 +348,6 @@ def __init__(self): action="store_true", help="Whether to compile the model", ) - self.parser.add_argument( - "--training.enable_float8_linear", - action="store_true", - help=""" - If true, swaps `torch.nn.Linear` with `Float8Linear`. 
- This feature requires you to install 'torchao' which can be found - here: https://github.com/pytorch/ao - """, - ) - self.parser.add_argument( - "--training.enable_fsdp_float8_all_gather", - action="store_true", - default=False, - help="Whether enable float8 all-gather in FSDP", - ) - self.parser.add_argument( - "--training.precompute_float8_dynamic_scale_for_fsdp", - action="store_true", - default=False, - help="Whether precompute float8 scales dynamically for FSDP", - ) - self.parser.add_argument( - "--training.float8_scaling_type_input", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - choices=["dynamic", "delayed"], - ) - self.parser.add_argument( - "--training.float8_scaling_type_weight", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - ) - self.parser.add_argument( - "--training.float8_scaling_type_grad_output", - type=str, - default="dynamic", - help="float8 scaling for input, dynamic (default) or delayed", - ) self.parser.add_argument( "--training.gc_freq", type=int, @@ -483,6 +443,7 @@ def __init__(self): 0 is the default value. """, ) + # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", @@ -500,6 +461,48 @@ def __init__(self): """, ) + # float8 configs + self.parser.add_argument( + "--float8.enable_float8_linear", + action="store_true", + help=""" + If true, swaps `torch.nn.Linear` with `Float8Linear`. + This feature requires you to install 'torchao' which can be found + here: https://github.com/pytorch/ao + """, + ) + self.parser.add_argument( + "--float8.enable_fsdp_float8_all_gather", + action="store_true", + default=False, + help="Whether enable float8 all-gather in FSDP", + ) + self.parser.add_argument( + "--float8.precompute_float8_dynamic_scale_for_fsdp", + action="store_true", + default=False, + help="Whether precompute float8 scales dynamically for FSDP", + ) + self.parser.add_argument( + "--float8.scaling_type_input", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + choices=["dynamic", "delayed"], + ) + self.parser.add_argument( + "--float8.scaling_type_weight", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + self.parser.add_argument( + "--float8.scaling_type_grad_output", + type=str, + default="dynamic", + help="float8 scaling for input, dynamic (default) or delayed", + ) + # communications library settings self.parser.add_argument( "--comm.init_timeout_seconds", diff --git a/torchtitan/float8_linear.py b/torchtitan/float8_linear.py index fa311061d9..494b6046b7 100644 --- a/torchtitan/float8_linear.py +++ b/torchtitan/float8_linear.py @@ -12,127 +12,128 @@ # Note: Performance # Float8 experimental is intended to be ran under `torch.compile`` for competitive performance -import functools -from typing import Optional import torch import torch.nn as nn -from torch._logging import warning_once from torchtitan.config_manager import JobConfig from torchtitan.logging import logger +from torchtitan.parallelisms import ParallelDims -@functools.lru_cache(None) def is_sm90_or_later(): # Float8 is only supported on H100+ GPUs return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) -def maybe_build_fp8_linear( - model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False -): - """ - This function converts the linear layers to `Float8Linear`. 
Note that today, - only dynamic tensor scaling (the default) is supported. - - This will mutate the model inplace. - """ - enable_float8_linear = job_config.training.enable_float8_linear - if not enable_float8_linear: - return - if not is_sm90_or_later(): - warning_once( - logger, - "Failed to swap to Float8Linear because SM90 or later is not available", - ) - return - try: - from torchao.float8 import ( - CastConfig, - convert_to_float8_training, - Float8LinearConfig, - ScalingType, - ) +class Float8Handler: + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): + self.enabled = False + + float8_config = job_config.float8 + if not float8_config.enable_float8_linear: + return + if not is_sm90_or_later(): + logger.warning( + "Failed to swap to Float8Linear because SM90 or later is not available", + ) + return + try: + from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType + except ImportError as e: + raise ImportError( + "torchao is not installed. Please install it to use fp8 linear layers." + ) from e # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear enable_fsdp_float8_all_gather = ( - job_config.training.enable_fsdp_float8_all_gather and dp_enabled - ) - scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input) - scaling_type_weight = ScalingType( - job_config.training.float8_scaling_type_weight + parallel_dims.dp_enabled + and parallel_dims.dp_type == "fsdp" + and float8_config.enable_fsdp_float8_all_gather ) - scaling_type_grad_output = ScalingType( - job_config.training.float8_scaling_type_grad_output - ) - float8_config = Float8LinearConfig( + scaling_type_input = ScalingType(float8_config.scaling_type_input) + scaling_type_weight = ScalingType(float8_config.scaling_type_weight) + scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output) + self.config = Float8LinearConfig( enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, cast_config_input=CastConfig(scaling_type=scaling_type_input), cast_config_weight=CastConfig(scaling_type=scaling_type_weight), cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), enable_pre_and_post_forward=False, ) + + self.enabled = True + + # for precompute_fp8_dynamic_scale_for_fsdp + self.precompute_scale = ( + enable_fsdp_float8_all_gather + and float8_config.precompute_float8_dynamic_scale_for_fsdp + ) + + # for sync_float8_amax_and_scale_history + self.delayed_scaling = ( + scaling_type_input == "delayed" + or scaling_type_weight == "delayed" + or scaling_type_grad_output == "delayed" + ) + self._sync_float8_amax_and_scale_history = None + self.compile = job_config.training.compile + + logger.info("Float8 training active") + + def convert_to_float8_training(self, model: nn.Module): + """ + This function converts the linear layers of `model` to `Float8Linear`. + Note that today, only dynamic tensor scaling (the default) is supported. + This will mutate the model inplace. 
+ """ + if not self.enabled: + return + + from torchao.float8 import convert_to_float8_training + + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear convert_to_float8_training( model, - config=float8_config, + config=self.config, module_filter_fn=lambda mod, fqn: fqn != "output", ) logger.info( - f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" + "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" + f"{self.config.enable_fsdp_float8_all_gather}" ) - except ImportError as exc: - raise ImportError( - "torchao is not installed. Please install it to use fp8 linear layers." - ) from exc - - -def maybe_precompute_fp8_dynamic_scale_for_fsdp( - model: nn.Module, job_config: JobConfig -): - if not ( - job_config.training.enable_float8_linear - and job_config.training.enable_fsdp_float8_all_gather - and job_config.training.precompute_float8_dynamic_scale_for_fsdp - ): - return - if not is_sm90_or_later(): - warning_once( - logger, - "Skipped precomputing fp8 scales because SM90 or later is not available", - ) - return - from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp - precompute_float8_dynamic_scale_for_fsdp(model) + def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module): + if not self.enabled: + return + if not self.precompute_scale: + return -_sync_float8_amax_and_scale_history = None + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp + precompute_float8_dynamic_scale_for_fsdp(model) -def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig): - if not ( - job_config.training.enable_float8_linear - and ( - job_config.training.float8_scaling_type_input == "delayed" - or job_config.training.float8_scaling_type_weight == "delayed" - or job_config.training.float8_scaling_type_grad_output == "delayed" - ) - ): - return + def sync_float8_amax_and_scale_history(self, model: nn.Module): + if not self.enabled: + return - from torchao.float8 import sync_float8_amax_and_scale_history + if not self.delayed_scaling: + return - # TODO(future): see if precalculating the modules to sync over is going to - # meaningfully help performance + from torchao.float8 import sync_float8_amax_and_scale_history - global _sync_float8_amax_and_scale_history - if _sync_float8_amax_and_scale_history is None: - if job_config.training.compile: - _sync_float8_amax_and_scale_history = torch.compile( - sync_float8_amax_and_scale_history - ) - else: - _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history + # TODO(vkuzo): see if precalculating the modules to sync over is going to + # meaningfully help performance + + if self._sync_float8_amax_and_scale_history is None: + if self.compile: + self._sync_float8_amax_and_scale_history = torch.compile( + sync_float8_amax_and_scale_history + ) + else: + self._sync_float8_amax_and_scale_history = ( + sync_float8_amax_and_scale_history + ) - sync_float8_amax_and_scale_history(model) + self._sync_float8_amax_and_scale_history(model) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index e86f93b98d..bdafc8e215 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -541,7 +541,7 @@ def parallelize_llama( model, world_mesh["tp"], loss_parallel=parallel_dims.loss_parallel_enabled, - enable_float8=job_config.training.enable_float8_linear, + enable_float8=job_config.float8.enable_float8_linear, 
enable_async_tp=job_config.experimental.enable_async_tensor_parallel, ) diff --git a/train.py b/train.py index 92e29058b7..615ed4e344 100644 --- a/train.py +++ b/train.py @@ -15,11 +15,7 @@ from torchtitan.checkpoint import CheckpointManager, TrainState from torchtitan.config_manager import JobConfig from torchtitan.datasets import build_hf_data_loader, build_tokenizer -from torchtitan.float8_linear import ( - maybe_build_fp8_linear, - maybe_precompute_fp8_dynamic_scale_for_fsdp, - maybe_sync_float8_amax_and_scale_history, -) +from torchtitan.float8_linear import Float8Handler from torchtitan.logging import init_logger, logger from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config @@ -120,8 +116,10 @@ def main(job_config: JobConfig): with torch.device("meta"): whole_model = model_cls.from_model_args(model_config) + # a no-op hander if fp8 is not enabled + float8_handler = Float8Handler(job_config, parallel_dims) # swap to Float8Linear base on fp8 config - maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled) + float8_handler.convert_to_float8_training(whole_model) # log model size model_param_count = utils.get_num_params(whole_model) @@ -307,18 +305,17 @@ def loss_fn(pred, labels): model.parameters(), job_config.training.max_norm, foreach=True ) - # if float8 is enabled, sync float8 amaxes and scales - maybe_sync_float8_amax_and_scale_history(model, job_config) + # sync float8 amaxes and scales + float8_handler.sync_float8_amax_and_scale_history(model) # optimizer step checkpoint.maybe_wait_for_staging() optimizers.step() lr_schedulers.step() - # when float8 config is on, # calculate float8 dynamic amax/scale for all-parameter for FSDP2 # it issues a single all-reduce for all parameters at once for better performance - maybe_precompute_fp8_dynamic_scale_for_fsdp(model, job_config) + float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model) losses_since_last_log.append(loss) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index b36e9d0c7b..7d4187dc35 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -37,7 +37,6 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) @@ -57,3 +56,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml index 2dc29f2e4d..4727f965fb 100644 --- a/train_configs/llama2_13b.toml +++ b/train_configs/llama2_13b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml index f17496c51b..83114876d1 100644 --- 
a/train_configs/llama2_70b.toml +++ b/train_configs/llama2_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' # ['none', 'selective', 'full'] + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml index 69ae7285e2..22ab6c7601 100644 --- a/train_configs/llama2_7b.toml +++ b/train_configs/llama2_7b.toml @@ -32,7 +32,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml index 660f2c0b17..62d75dfb62 100644 --- a/train_configs/llama3_70b.toml +++ b/train_configs/llama3_70b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 8 # 8-way TP -enable_float8_linear = false compile = false dataset = "c4" @@ -51,3 +50,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'full' + +[float8] +enable_float8_linear = false diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml index 7e5ac63c2f..517dd81ee6 100644 --- a/train_configs/llama3_8b.toml +++ b/train_configs/llama3_8b.toml @@ -33,7 +33,6 @@ max_norm = 1.0 # grad norm clipping steps = 1000 data_parallel_degree = -1 tensor_parallel_degree = 1 -enable_float8_linear = false compile = false dataset = "c4" @@ -52,3 +51,6 @@ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] [activation_checkpoint] mode = 'selective' # ['none', 'selective', 'full'] selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_float8_linear = false
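
Taken together, the new [float8] table gives every train config a single place to opt into float8 training. As a rough illustration, a config on SM90-class GPUs with torchao installed might enable the feature like this (example values only; all keys come from the new --float8.* options added above):

[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
scaling_type_input = "dynamic"
scaling_type_weight = "dynamic"
scaling_type_grad_output = "dynamic"

The same knobs are exposed as command-line overrides under the new --float8.* prefix (e.g. --float8.enable_float8_linear), replacing the old --training.* float8 flags removed in this patch.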
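
For reference, the handler is meant to be driven from the training script exactly as in the train.py hunks above. A minimal sketch of the call sites, assuming job_config, parallel_dims, model, optimizers, lr_schedulers, and data_loader are set up as in train.py (compute_loss and the loop body are placeholders, not a drop-in script):

import torch

from torchtitan.float8_linear import Float8Handler

# a no-op handler when float8 is not enabled in the job config
float8_handler = Float8Handler(job_config, parallel_dims)
# swap nn.Linear -> Float8Linear before parallelizing/compiling; the "output" layer is skipped
float8_handler.convert_to_float8_training(model)

# ... apply parallelisms, activation checkpointing, torch.compile, build optimizers ...

for batch in data_loader:  # illustrative loop
    loss = compute_loss(model, batch)  # placeholder forward pass + loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(
        model.parameters(), job_config.training.max_norm, foreach=True
    )
    # only does work when one of the scaling types is delayed
    float8_handler.sync_float8_amax_and_scale_history(model)
    optimizers.step()
    lr_schedulers.step()
    # with FSDP2 float8 all-gather: a single all-reduce precomputes dynamic scales for all parameters
    float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
    optimizers.zero_grad()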