
Commit 04a6f47

awaelchli authored and Borda committed
Remove the redundant precision attribute from LightningModule (#16203)
1 parent a59d497 commit 04a6f47

File tree

7 files changed: +10 −15 lines


src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed the `FitLoop.split_idx` property
 * Removed the `LoggerConnector.on_train_split_start` method

+- Removed the `LightningModule.precision` attribute ([#16203](https://github.com/Lightning-AI/lightning/pull/16203))
+

 ### Fixed

src/pytorch_lightning/core/module.py

Lines changed: 0 additions & 3 deletions

@@ -108,9 +108,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         # pointer to the trainer object
         self._trainer: Optional["pl.Trainer"] = None

-        # the precision used
-        self.precision: Union[int, str] = 32
-
         # optionally can be set by user
         self._example_input_array: Optional[Union[Tensor, Tuple, Dict]] = None
         self._current_fx_name: Optional[str] = None
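A hypothetical user module (not part of this commit) illustrating the migration: code that previously read `self.precision` on a LightningModule should now read the value from the attached Trainer, which exposes it as a string.

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def on_fit_start(self) -> None:
            # before this commit: self.precision (an int such as 32)
            # after this commit:  self.trainer.precision (a string such as "32", "16", or "bf16")
            print(f"fit starting with precision={self.trainer.precision}")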

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 1 addition & 6 deletions

@@ -144,13 +144,8 @@ def attach_data(
         elif self.trainer.state.fn == TrainerFn.PREDICTING:
             _check_dataloader_none(predict_dataloaders, self._predict_dataloader_source, self.trainer.state.fn)

-        # set local properties on the model
-        self._copy_trainer_model_properties(model)
-
-    def _copy_trainer_model_properties(self, model: "pl.LightningModule") -> None:
+        # Attach the trainer to the LightningModule
         model.trainer = proxy(self.trainer)
-        # for backward compatibility
-        model.precision = int(self.trainer.precision) if self.trainer.precision != "bf16" else "bf16"

     def attach_dataloaders(
         self,
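With the backward-compatibility copy removed, the Trainer is the single source of truth for precision. A minimal sketch (a hypothetical callback, not part of this commit) of reading it directly:

    import pytorch_lightning as pl

    class PrecisionLogger(pl.Callback):
        def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
            # trainer.precision is the normalized string form, e.g. "32", "16", or "bf16"
            print(f"{stage}: running with precision={trainer.precision}")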

src/pytorch_lightning/utilities/model_summary/model_summary.py

Lines changed: 2 additions & 1 deletion

@@ -189,7 +189,8 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
         self._layer_summary = self.summarize()
         # 1 byte -> 8 bits
         # TODO: how do we compute precision_megabytes in case of mixed precision?
-        precision = self._model.precision if isinstance(self._model.precision, int) else 32
+        precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
+        precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32
         self._precision_megabytes = (precision / 8.0) * 1e-6

     @property
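A rough worked example (not from this commit) of the size estimate the new code yields, under the same assumption the TODO notes, namely that every parameter is stored at the trainer's precision:

    precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16}
    bits = precision_to_bits.get("16", 32)        # e.g. a trainer running at 16-bit precision
    precision_megabytes = (bits / 8.0) * 1e-6     # megabytes per parameter
    num_params = 1_000_000
    print(num_params * precision_megabytes)       # ~2.0 MB for a 1M-parameter model at 16-bit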

tests/tests_pytorch/accelerators/test_ipu.py

Lines changed: 1 addition & 1 deletion

@@ -189,7 +189,7 @@ def test_optimization(tmpdir):
 def test_half_precision(tmpdir):
     class TestCallback(Callback):
         def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
-            assert trainer.strategy.model.precision == 16
+            assert trainer.precision == "16"
             raise SystemExit

     model = IPUModel()
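The assertion now goes through `trainer.precision`, which is normalized to a string. A minimal sketch (not part of this commit, assuming a CPU machine with bf16-capable PyTorch) of that behaviour:

    from pytorch_lightning import Trainer

    trainer = Trainer(accelerator="cpu", precision="bf16", logger=False)
    assert trainer.precision == "bf16"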

tests/tests_pytorch/plugins/precision/hpu/test_hpu.py

Lines changed: 2 additions & 2 deletions

@@ -42,7 +42,7 @@ def test_precision_plugin(hmp_params):
 def test_mixed_precision(tmpdir, hmp_params: dict):
     class TestCallback(Callback):
         def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
-            assert trainer.strategy.model.precision == "bf16"
+            assert trainer.precision == "bf16"
             raise SystemExit

     model = BoringModel()
@@ -65,7 +65,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
 def test_pure_half_precision(tmpdir, hmp_params: dict):
     class TestCallback(Callback):
         def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
-            assert trainer.strategy.model.precision == "16"
+            assert trainer.precision == "16"
             for param in trainer.strategy.model.parameters():
                 assert param.dtype == torch.float16
             raise SystemExit

tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py

Lines changed: 2 additions & 2 deletions

@@ -64,7 +64,7 @@ def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx) -> None:
     def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
-        precision = torch.float16 if self.precision == 16 else torch.bfloat16
+        precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16
         assert self.layer.mixed_precision.param_dtype == precision
         assert self.layer.mixed_precision.reduce_dtype == precision
         assert self.layer.mixed_precision.buffer_dtype == precision
@@ -100,7 +100,7 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)

-        precision = torch.float16 if self.precision == 16 else torch.bfloat16
+        precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16
         for layer_num in [0, 2]:
             assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
             assert self.layer[layer_num].mixed_precision.param_dtype == precision
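A small helper sketch (not an official API, only illustrating the mapping the FSDP assertions above rely on) from the trainer's precision string to a torch dtype:

    import torch

    def dtype_for(precision: str) -> torch.dtype:
        # "16" -> float16, "bf16" -> bfloat16, anything else treated as full precision
        if precision == "16":
            return torch.float16
        if precision == "bf16":
            return torch.bfloat16
        return torch.float32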
