diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3227bd90555f9..2f6d4692ec076 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -188,6 +188,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
 
+- Fixed a bug where the `trainer.precision` check failed to account for it changing to `'mixed'` when specifying 16 in the `Trainer` ([#7825](https://github.com/PyTorchLightning/pytorch-lightning/pull/7825))
+
+
 ## [1.3.2] - 2021-05-18
 
 ### Changed
 
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 02da937286dcc..fceafddd66ec0 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -54,7 +54,8 @@ def _reinit_optimizers_with_oss(self):
                 optim_class = type(optimizer)
                 zero_optimizer = OSS(params=optimizer.param_groups, optim=optim_class, **optimizer.defaults)
                 if _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE:
-                    is_fp16 = self.lightning_module.trainer.precision == 16
+                    precision = self.lightning_module.trainer.precision
+                    is_fp16 = precision in ("mixed", 16)
                     # For multi-node training, compressing the model shards in fp16 before broadcasting
                     # improves performance. When using PyTorch AMP, it will not degrade
                     # the model performance.
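
For context, a minimal standalone sketch (not part of the diff) of the condition this change fixes. `should_broadcast_fp16` is a hypothetical helper name standing in for the inline check inside `_reinit_optimizers_with_oss`, and `reported_precision` stands in for `self.lightning_module.trainer.precision`, which, per the changelog entry, can be the string `'mixed'` rather than the integer 16:

```python
from typing import Union


def should_broadcast_fp16(reported_precision: Union[str, int]) -> bool:
    """Return True when the trainer is effectively running in fp16 precision.

    The previous check, ``reported_precision == 16``, evaluated to False for
    ``'mixed'``, so fp16 shard compression was silently skipped in that case;
    the membership test introduced by this diff accepts both spellings.
    """
    return reported_precision in ("mixed", 16)


# Both spellings the sharded plugin may see should enable fp16 broadcast.
assert should_broadcast_fp16(16)
assert should_broadcast_fp16("mixed")
assert not should_broadcast_fp16(32)
```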