diff --git a/src/python/torchdistx/optimizers/anyprecision_optimizer.py b/src/python/torchdistx/optimizers/anyprecision_optimizer.py
index ef9c090..e36eb60 100644
--- a/src/python/torchdistx/optimizers/anyprecision_optimizer.py
+++ b/src/python/torchdistx/optimizers/anyprecision_optimizer.py
@@ -8,11 +8,15 @@
 # with optional Kahan summation for high precision weight updates.
 # Allows direct control over momentum, variance and auxiliary compensation
 # buffer dtypes.
-# Optional Kahan summation is used to offset precision reduction for
-# the weight updates. This allows full training in BFloat16 (equal or
-# better than FP32 results in many cases) due to high precision weight upates.
+# Optional Kahan summation is used to enable high precision for
+# the weight updates. This allows successful training in pure BFloat16
+# (often equal or better than FP32 results) due to high precision weight
+# updates, while also reducing GPU memory usage and
+# increasing training speed.
 
 import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
 
 from torch.optim.optimizer import Optimizer
 
@@ -31,33 +35,61 @@ def __init__(
     ):
         """
         Args:
-        params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups
-        lr (float, optional): learning rate (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability (default: 1e-8)
-        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
-
-        # Any Precision specific
-        use_kahan_summation = creates auxiliary buffer to ensure high precision
-            model param updates (default: False)
-        momentum_dtype = dtype for momentum (default: BFloat32)
-        variance_dtype = dtype for uncentered variance (default: BFloat16)
-        compensation_buffer_dtype = dtype for Kahan summation
-            buffer (default: BFloat16). Only used if
-            ``use_kahan_summation=True``.
-
-        # Usage
-        This optimizer implements optimizer states, and Kahan summation
-        for high precision updates, all in user controlled dtypes.
-        Defaults are variance in BF16, Momentum in FP32.
-        This can be run in FSDP mixed precision, amp, or full precision,
-        depending on what training pipeline you wish to work with.
-
-        Setting to use_kahan_summation = False, and changing momentum and
-        variance dtypes to FP32, reverts this to a standard AdamW optimizer.
+            params (iterable): iterable of parameters to optimize or dicts defining
+                parameter groups
+            lr (float, optional): learning rate (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients used for computing
+                running averages of gradient and its square (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve
+                numerical stability (default: 1e-8)
+            weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+
+            # AnyPrecision specific
+            use_kahan_summation = use auxiliary buffer to ensure high precision
+                model param updates (default: False)
+            momentum_dtype = dtype for momentum (default: FP32)
+            variance_dtype = dtype for uncentered variance (default: BFloat16)
+            compensation_buffer_dtype = dtype for Kahan summation
+                buffer (default: BFloat16); only used if use_kahan_summation=True.
+
+            # Usage
+            This optimizer implements optimizer states and Kahan summation
+            for high precision updates, all in user controlled dtypes.
+            The high precision updates enable successful training in pure
+            BF16 with corresponding reductions in memory and increases in
+            training speed.
+
+            Defaults are variance in BF16, momentum in FP32.
+            This can be run in FSDP mixed precision, amp, or full precision,
+            depending on what training pipeline you wish to work with.
+
+            Setting use_kahan_summation = False, and changing momentum and
+            variance dtypes to FP32, reverts this to a standard AdamW optimizer.
+
+            AnyPrecision will automatically verify that proper BF16 support
+            is present for both the GPU and the network (NCCL).
+
+            To train in pure BF16:
+            1 - use model.to(torch.bfloat16) to move your model
+                to BF16.
+            2 - Set momentum_dtype and variance_dtype to torch.bfloat16
+            3 - Set use_kahan_summation = True
+
+            Example:
+            # init model
+            my_model = build_model(config_args)
+
+            # ensure model is moved to all bf16
+            my_model.to(torch.bfloat16)
+
+            # setup AnyPrecision to run in pure BF16 with high precision updates
+            optimizer = AnyPrecisionAdamW(my_model.parameters(), lr=lr, ...,
+                                          momentum_dtype=torch.bfloat16,
+                                          variance_dtype=torch.bfloat16,
+                                          use_kahan_summation=True
+                                          )
+
+
         """
         defaults = dict(
             lr=lr,
@@ -72,6 +104,28 @@ def __init__(
 
         super().__init__(params, defaults)
 
+        # confirm bfloat16 support if applicable
+        if (
+            torch.bfloat16
+            in [
+                momentum_dtype,
+                variance_dtype,
+            ]
+            or (torch.bfloat16 in [compensation_buffer_dtype]
+                and use_kahan_summation)
+        ):
+            gpu_support, network_support = self._verify_bfloat_support()
+
+            if not gpu_support or not network_support:
+                reason = ""
+                if not gpu_support:
+                    reason += "Your GPU does not support native BFloat16. "
+
+                if not network_support:
+                    reason += "Your NCCL version does not support BFloat16. "
+
+                raise ValueError(f"Missing BFloat16 support. Details: {reason}")
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -180,3 +234,26 @@ def step(self, closure=None):
                 else:
                     # usual AdamW updates
                     p.data.addcdiv_(exp_avg, centered_variance, value=-step_size)
+
+    def _verify_bfloat_support(
+        self,
+    ):
+        """verify gpu and network support for BF16"""
+        # requires cuda >= 11.0
+        required_cuda_major = 11
+
+        # requires nccl >= 2.10
+        required_nccl_major = 2
+        required_nccl_minor = 10
+
+        gpu_support = bool(torch.version.cuda) and torch.cuda.is_bf16_supported()
+        if not gpu_support:  # also covers CPU-only builds where torch.version.cuda is None
+            return False, False
+
+        cuda_version_major, _ = torch.version.cuda.split(".", maxsplit=1)
+        network_support = (
+            int(cuda_version_major) >= required_cuda_major
+            and dist.is_nccl_available()
+            and nccl.version() >= (required_nccl_major, required_nccl_minor)
+        )
+        return gpu_support, network_support
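Usage sketch: a minimal end-to-end run of the pure-BF16 recipe described in the new docstring. This is only a sketch under assumptions: that AnyPrecisionAdamW is importable from the module this diff touches (torchdistx.optimizers.anyprecision_optimizer), that a BFloat16-capable GPU and NCCL build is present (otherwise the new constructor check raises ValueError), and that the toy model, data, and hyperparameters are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchdistx.optimizers.anyprecision_optimizer import AnyPrecisionAdamW

# step 1 of the docstring recipe: move the entire model to BFloat16 (placed on GPU,
# since the constructor verifies GPU/NCCL BF16 support)
model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 10))
model = model.to(device="cuda", dtype=torch.bfloat16)

# steps 2 and 3: keep momentum/variance in BF16 and enable Kahan-compensated updates
optimizer = AnyPrecisionAdamW(
    model.parameters(),
    lr=1e-3,
    momentum_dtype=torch.bfloat16,
    variance_dtype=torch.bfloat16,
    use_kahan_summation=True,
)

# one illustrative training step on random BF16 data
inputs = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
targets = torch.randint(0, 10, (32,), device="cuda")

loss = F.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()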