12 | 12 |
13 | 13 | # Note: Performance |
14 | 14 | # Float8 experimental is intended to be run under `torch.compile` for competitive performance
15 | | -import functools |
16 | | -from typing import Optional |
17 | 15 |
18 | 16 | import torch |
19 | 17 | import torch.nn as nn |
20 | | -from torch._logging import warning_once |
21 | 18 |
22 | 19 | from torchtitan.config_manager import JobConfig |
23 | 20 | from torchtitan.logging import logger |
| 21 | +from torchtitan.parallelisms import ParallelDims |
24 | 22 |
25 | 23 |
26 | | -@functools.lru_cache(None) |
27 | 24 | def is_sm90_or_later(): |
28 | 25 | # Float8 is only supported on H100+ GPUs |
29 | 26 | return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) |
30 | 27 |
31 | 28 |
32 | | -def maybe_build_fp8_linear( |
33 | | - model: nn.Module, job_config: JobConfig, dp_enabled: Optional[bool] = False |
34 | | -): |
35 | | - """ |
36 | | - This function converts the linear layers to `Float8Linear`. Note that today, |
37 | | - only dynamic tensor scaling (the default) is supported. |
38 | | -
39 | | - This will mutate the model inplace. |
40 | | - """ |
41 | | - enable_float8_linear = job_config.training.enable_float8_linear |
42 | | - if not enable_float8_linear: |
43 | | - return |
44 | | - if not is_sm90_or_later(): |
45 | | - warning_once( |
46 | | - logger, |
47 | | - "Failed to swap to Float8Linear because SM90 or later is not available", |
48 | | - ) |
49 | | - return |
50 | | - try: |
51 | | - from torchao.float8 import ( |
52 | | - CastConfig, |
53 | | - convert_to_float8_training, |
54 | | - Float8LinearConfig, |
55 | | - ScalingType, |
56 | | - ) |
| 29 | +class Float8Handler: |
| 30 | + def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): |
| 31 | + self.enabled = False |
| 32 | + |
| 33 | + float8_config = job_config.float8 |
| 34 | + if not float8_config.enable_float8_linear: |
| 35 | + return |
| 36 | + if not is_sm90_or_later(): |
| 37 | + logger.warning( |
| 38 | + "Failed to swap to Float8Linear because SM90 or later is not available", |
| 39 | + ) |
| 40 | + return |
| 41 | + try: |
| 42 | + from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType |
| 43 | + except ImportError as e: |
| 44 | + raise ImportError( |
| 45 | + "torchao is not installed. Please install it to use fp8 linear layers." |
| 46 | + ) from e |
57 | 47 |
58 | | -# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
| 48 | +# FSDP float8 all-gather is only usable when FSDP data parallelism is enabled
59 | 49 | enable_fsdp_float8_all_gather = ( |
60 | | - job_config.training.enable_fsdp_float8_all_gather and dp_enabled |
61 | | - ) |
62 | | - scaling_type_input = ScalingType(job_config.training.float8_scaling_type_input) |
63 | | - scaling_type_weight = ScalingType( |
64 | | - job_config.training.float8_scaling_type_weight |
| 50 | + parallel_dims.dp_enabled |
| 51 | + and parallel_dims.dp_type == "fsdp" |
| 52 | + and float8_config.enable_fsdp_float8_all_gather |
65 | 53 | ) |
66 | | - scaling_type_grad_output = ScalingType( |
67 | | - job_config.training.float8_scaling_type_grad_output |
68 | | - ) |
69 | | - float8_config = Float8LinearConfig( |
| 54 | + scaling_type_input = ScalingType(float8_config.scaling_type_input) |
| 55 | + scaling_type_weight = ScalingType(float8_config.scaling_type_weight) |
| 56 | + scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output) |
| 57 | + self.config = Float8LinearConfig( |
70 | 58 | enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, |
71 | 59 | cast_config_input=CastConfig(scaling_type=scaling_type_input), |
72 | 60 | cast_config_weight=CastConfig(scaling_type=scaling_type_weight), |
73 | 61 | cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output), |
74 | 62 | enable_pre_and_post_forward=False, |
75 | 63 | ) |
| 64 | + |
| 65 | + self.enabled = True |
| 66 | + |
| 67 | + # for precompute_fp8_dynamic_scale_for_fsdp |
| 68 | + self.precompute_scale = ( |
| 69 | + enable_fsdp_float8_all_gather |
| 70 | + and float8_config.precompute_float8_dynamic_scale_for_fsdp |
| 71 | + ) |
| 72 | + |
| 73 | + # for sync_float8_amax_and_scale_history |
| 74 | + self.delayed_scaling = ( |
| 75 | + scaling_type_input is ScalingType.DELAYED
| 76 | + or scaling_type_weight is ScalingType.DELAYED
| 77 | + or scaling_type_grad_output is ScalingType.DELAYED
| 78 | + ) |
| 79 | + self._sync_float8_amax_and_scale_history = None |
| 80 | + self.compile = job_config.training.compile |
| 81 | + |
| 82 | + logger.info("Float8 training active") |
| 83 | + |
| 84 | + def convert_to_float8_training(self, model: nn.Module): |
| 85 | + """ |
| 86 | + This method converts the linear layers of `model` to `Float8Linear`.
| 87 | + Note that today, only dynamic tensor scaling (the default) is supported.
| 88 | + This will mutate the model in place.
| 89 | + """ |
| 90 | + if not self.enabled: |
| 91 | + return |
| 92 | + |
| 93 | + from torchao.float8 import convert_to_float8_training |
| 94 | + |
| 95 | + # Mutates the model inplace replacing instances of nn.Linear with Float8Linear |
76 | 96 | convert_to_float8_training( |
77 | 97 | model, |
78 | | - config=float8_config, |
| 98 | + config=self.config, |
79 | 99 | module_filter_fn=lambda mod, fqn: fqn != "output", |
80 | 100 | ) |
81 | 101 | logger.info( |
82 | | - f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}" |
| 102 | + "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" |
| 103 | + f"{self.config.enable_fsdp_float8_all_gather}" |
83 | 104 | ) |
84 | | - except ImportError as exc: |
85 | | - raise ImportError( |
86 | | - "torchao is not installed. Please install it to use fp8 linear layers." |
87 | | - ) from exc |
88 | | - |
89 | | - |
90 | | -def maybe_precompute_fp8_dynamic_scale_for_fsdp( |
91 | | - model: nn.Module, job_config: JobConfig |
92 | | -): |
93 | | - if not ( |
94 | | - job_config.training.enable_float8_linear |
95 | | - and job_config.training.enable_fsdp_float8_all_gather |
96 | | - and job_config.training.precompute_float8_dynamic_scale_for_fsdp |
97 | | - ): |
98 | | - return |
99 | | - if not is_sm90_or_later(): |
100 | | - warning_once( |
101 | | - logger, |
102 | | - "Skipped precomputing fp8 scales because SM90 or later is not available", |
103 | | - ) |
104 | | - return |
105 | | - from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp |
106 | 105 |
107 | | - precompute_float8_dynamic_scale_for_fsdp(model) |
| 106 | + def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module): |
| 107 | + if not self.enabled: |
| 108 | + return |
108 | 109 |
| 110 | + if not self.precompute_scale: |
| 111 | + return |
109 | 112 |
110 | | -_sync_float8_amax_and_scale_history = None |
| 113 | + from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp |
111 | 114 |
| 115 | + precompute_float8_dynamic_scale_for_fsdp(model) |
112 | 116 |
113 | | -def maybe_sync_float8_amax_and_scale_history(model: nn.Module, job_config: JobConfig): |
114 | | - if not ( |
115 | | - job_config.training.enable_float8_linear |
116 | | - and ( |
117 | | - job_config.training.float8_scaling_type_input == "delayed" |
118 | | - or job_config.training.float8_scaling_type_weight == "delayed" |
119 | | - or job_config.training.float8_scaling_type_grad_output == "delayed" |
120 | | - ) |
121 | | - ): |
122 | | - return |
| 117 | + def sync_float8_amax_and_scale_history(self, model: nn.Module): |
| 118 | + if not self.enabled: |
| 119 | + return |
123 | 120 |
124 | | - from torchao.float8 import sync_float8_amax_and_scale_history |
| 121 | + if not self.delayed_scaling: |
| 122 | + return |
125 | 123 |
126 | | - # TODO(future): see if precalculating the modules to sync over is going to |
127 | | - # meaningfully help performance |
| 124 | + from torchao.float8 import sync_float8_amax_and_scale_history |
128 | 125 |
129 | | - global _sync_float8_amax_and_scale_history |
130 | | - if _sync_float8_amax_and_scale_history is None: |
131 | | - if job_config.training.compile: |
132 | | - _sync_float8_amax_and_scale_history = torch.compile( |
133 | | - sync_float8_amax_and_scale_history |
134 | | - ) |
135 | | - else: |
136 | | - _sync_float8_amax_and_scale_history = sync_float8_amax_and_scale_history |
| 126 | + # TODO(vkuzo): see if precalculating the modules to sync over is going to |
| 127 | + # meaningfully help performance |
| 128 | + |
| 129 | + if self._sync_float8_amax_and_scale_history is None: |
| 130 | + if self.compile: |
| 131 | + self._sync_float8_amax_and_scale_history = torch.compile( |
| 132 | + sync_float8_amax_and_scale_history |
| 133 | + ) |
| 134 | + else: |
| 135 | + self._sync_float8_amax_and_scale_history = ( |
| 136 | + sync_float8_amax_and_scale_history |
| 137 | + ) |
137 | 138 |
138 | | - sync_float8_amax_and_scale_history(model) |
| 139 | + self._sync_float8_amax_and_scale_history(model) |
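To make the new surface area concrete, here is a minimal sketch of how a training script might drive `Float8Handler` end to end. The `torchtitan.float8` import path, the loop body, and names such as `loss_fn`, `data_loader`, and `optimizer` are assumptions for illustration, not code from this change; only the handler's constructor and its three methods come from the diff above.

```python
# Illustrative wiring only: everything except the Float8Handler API is a placeholder.
from torchtitan.float8 import Float8Handler  # assumed module path for the new handler

float8_handler = Float8Handler(job_config, parallel_dims)

# Swap nn.Linear (except the "output" projection) for Float8Linear before
# applying FSDP / torch.compile. No-op when float8 is disabled or unsupported.
float8_handler.convert_to_float8_training(model)

for batch, labels in data_loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch), labels)  # placeholder loss computation
    loss.backward()

    # Only does work when one of the configured scaling types is "delayed".
    float8_handler.sync_float8_amax_and_scale_history(model)

    optimizer.step()

    # Only does work when FSDP float8 all-gather and dynamic-scale precompute
    # are both enabled in the job config; runs once per step after the update.
    float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
```

Because every method early-returns when its feature is off, the calls can stay in the loop unconditionally; compared with the old `maybe_*` free functions, the config checks and the (optionally compiled) sync function are now resolved once in `__init__` and kept on the handler instead of being re-derived from `job_config` or stashed in a module-level global on each call.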