|
10 | 10 | import torch.distributed as dist |
11 | 11 | from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce |
12 | 12 |
|
13 | | -from torchao.float8.config import ( |
14 | | - Float8LinearConfig, |
15 | | - ScalingGranularity, |
16 | | - ScalingType, |
17 | | -) |
| 13 | +from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType |
18 | 14 |
|
19 | 15 | # Helpful visualizer for debugging (only supports fp32): |
20 | 16 | # https://www.h-schmidt.net/FloatConverter/IEEE754.html |
|
33 | 29 |
|
34 | 30 |
|
35 | 31 | @torch.no_grad() |
36 | | -def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype): |
| 32 | +def amax_to_scale( |
| 33 | + amax: torch.Tensor, |
| 34 | + float8_dtype: torch.dtype, |
| 35 | + round_scales_to_power_of_2: bool = False, |
| 36 | +): |
37 | 37 | """Converts the amax value of a tensor to the fp8 scale. |
38 | 38 | Args: |
39 | 39 | amax: The amax value of the tensor. |
40 | 40 | float8_dtype: The float8 dtype. |
| 41 | + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. |
41 | 42 | """ |
42 | 43 | # torch.compile and eager show different numerics for 1.0 / float32, |
43 | 44 | # upcast to float64 to ensure same numeric between compile and eager |
44 | 45 | amax = amax.to(torch.float64) |
45 | 46 | if float8_dtype in FP8_TYPES: |
46 | 47 | res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) |
| 48 | + res = res.to(torch.float32) |
47 | 49 | else: |
48 | 50 | raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") |
49 | | - |
50 | | - return res.to(torch.float32) |
| 51 | + if round_scales_to_power_of_2: |
| 52 | + res = _round_scale_down_to_power_of_2(res) |
| 53 | + return res |
51 | 54 |
|
52 | 55 |
|
53 | 56 | @torch.no_grad() |
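
A minimal sketch of how the updated `amax_to_scale` can be called with the new flag. The values and the import path are illustrative assumptions (this presumes the helpers live in `torchao.float8.float8_utils` and a PyTorch build that provides `torch.float8_e4m3fn`):

```python
import torch
from torchao.float8.float8_utils import amax_to_scale

# Illustrative amax value for some high precision tensor
amax = torch.tensor(0.3, dtype=torch.float32)

# Default behavior: scale = finfo(float8_dtype).max / amax, returned as float32
scale = amax_to_scale(amax, torch.float8_e4m3fn)

# With the new flag, the scale is additionally rounded down to the nearest
# power of 2, so multiplying by it only shifts the exponent of the scaled values
scale_pow2 = amax_to_scale(amax, torch.float8_e4m3fn, round_scales_to_power_of_2=True)
```
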
@@ -119,21 +122,35 @@ def tensor_to_amax( |
119 | 122 |
|
120 | 123 | @torch.no_grad() |
121 | 124 | def tensor_to_scale( |
122 | | - x: torch.Tensor, |
| 125 | + hp_tensor: torch.Tensor, |
123 | 126 | float8_dtype: torch.dtype, |
124 | 127 | reduce_amax: bool = False, |
125 | 128 | device_mesh=None, |
126 | 129 | scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE, |
127 | 130 | axiswise_dim: Optional[int] = None, |
| 131 | + round_scales_to_power_of_2: bool = False, |
128 | 132 | ) -> torch.Tensor: |
| 133 | + """ |
| 134 | + Compute scaling factor for the given high precision tensor. |
| 135 | +
|
| 136 | + Args: |
| 137 | + hp_tensor: high precision tensor |
| 138 | + float8_dtype: the float8 dtype to use |
| 139 | + reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks |
| 140 | + scaling_granularity: Defines the scaling granularity |
| 141 | + axiswise_dim: if axiswise granularity is used, defines the dim to scale across |
| 142 | + round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2. |
| 143 | + """ |
129 | 144 | amax = tensor_to_amax( |
130 | | - x, |
| 145 | + hp_tensor, |
131 | 146 | reduce_amax, |
132 | 147 | device_mesh, |
133 | 148 | scaling_granularity, |
134 | 149 | axiswise_dim, |
135 | 150 | ) |
136 | | - return amax_to_scale(amax, float8_dtype) |
| 151 | + return amax_to_scale( |
| 152 | + amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2 |
| 153 | + ) |
137 | 154 |
|
138 | 155 |
|
139 | 156 | def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype): |
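
The flag is threaded through `tensor_to_scale` as a keyword argument, alongside the `x` → `hp_tensor` rename. A rough usage sketch, with illustrative shapes and dtypes (same import-path assumption as above):

```python
import torch
from torchao.float8.float8_utils import tensor_to_scale

# Illustrative high precision tensor
hp_tensor = torch.randn(16, 32, dtype=torch.bfloat16)

# Tensorwise scale (the default granularity), snapped down to a power of 2
scale = tensor_to_scale(
    hp_tensor,
    torch.float8_e4m3fn,
    round_scales_to_power_of_2=True,
)
```
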
@@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool: |
266 | 283 | or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC |
267 | 284 | or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC |
268 | 285 | ) |
| 286 | + |
| 287 | + |
| 288 | +def _round_scale_down_to_power_of_2(scale: torch.Tensor): |
| 289 | + assert scale.dtype == torch.float32, "scale must be float32 tensor" |
| 290 | + return torch.exp2(torch.floor(torch.log2(scale))) |
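
As a rough check of the new helper's math: `exp2(floor(log2(x)))` rounds a positive float32 value down to the nearest power of 2 elementwise. The values below are illustrative:

```python
import torch

scale = torch.tensor([0.3, 1.0, 100.0, 3000.0], dtype=torch.float32)
rounded = torch.exp2(torch.floor(torch.log2(scale)))
# tensor([   0.2500,    1.0000,   64.0000, 2048.0000])
# Each element is the largest power of 2 not exceeding the input, so scaling
# by it only adjusts the exponent and leaves mantissa bits unchanged.
```
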