diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index db9889567f..8ea6e2e23a 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -85,7 +85,7 @@ def convert_to_float8_training( module: nn.Module, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, - config: Float8LinearConfig = None, + config: Optional[Float8LinearConfig] = None, ) -> nn.Module: """ Swaps `torch.nn.Linear` in `module` with `Float8Linear`.