|
21 | 21 |
|
def build_fp8_linear(model: nn.Module, job_config: JobConfig):
    """Swap ``torch.nn.Linear`` layers in ``model`` for ``Float8Linear``.

    Note that today only dynamic tensor scaling (the default) is supported.
    This mutates the model in place.

    Args:
        model: model whose linear layers are replaced in place.
        job_config: job configuration; the swap is performed only when
            ``job_config.training.fp8_linear`` is truthy.

    Raises:
        ImportError: if fp8 is requested but ``float8_experimental``
            is not installed.
    """
    if not job_config.training.fp8_linear:
        # fp8 disabled: do nothing — in particular, do not require
        # float8_experimental to be installed or touch its global config.
        return
    try:
        from float8_experimental.float8_linear import Float8Linear
        from float8_experimental.float8_linear_utils import (
            swap_linear_with_float8_linear,
        )

        import float8_experimental.config as config
    except ImportError as exc:
        raise ImportError(
            "float8_experimental is not installed. Please install it to use fp8 linear layers."
        ) from exc
    # Communicate FSDP-sharded weights in fp8 during all-gather.
    config.enable_fsdp_fp8_all_gather = True
    # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
    swap_linear_with_float8_linear(model, Float8Linear)
    logger.info("Swapped to Float8Linear layers")