Skip to content

Commit 74c55c4

Browse files
committed
add proper deprecation
1 parent 58d9b80 commit 74c55c4

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from pytorch_lightning.core.lightning import LightningModule
1415
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
1516
from pytorch_lightning.trainer.states import RunningStage
1617
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
@@ -23,6 +24,7 @@ class DeprecatedDistDeviceAttributes:
2324
_running_stage: RunningStage
2425
num_gpus: int
2526
accelerator_connector: AcceleratorConnector
27+
lightning_module = LightningModule
2628

2729
@property
2830
def on_cpu(self) -> bool:
@@ -130,3 +132,11 @@ def use_single_gpu(self, val: bool) -> None:
130132
)
131133
if val:
132134
self.accelerator_connector._device_type = DeviceType.GPU
135+
136+
def get_model(self) -> LightningModule:
137+
rank_zero_warn(
138+
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
139+
" and will be removed in v1.4.",
140+
DeprecationWarning,
141+
)
142+
return self.lightning_module

pytorch_lightning/trainer/properties.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,6 @@ def model(self, model: torch.nn.Module) -> None:
352352
"""
353353
self.accelerator.model = model
354354

355-
def get_model(self) -> LightningModule:
356-
# backward compatible
357-
return self.lightning_module
358-
359355
@property
360356
def lightning_optimizers(self) -> List[LightningOptimizer]:
361357
if self._lightning_optimizers is None:

tests/deprecated_api/test_remove_1-4.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@
3030
from tests.helpers import BoringModel
3131

3232

33+
def test_v1_4_0_deprecated_trainer_methods():
34+
with pytest.deprecated_call(match='will be removed in v1.4'):
35+
trainer = Trainer()
36+
_ = trainer.get_model()
37+
assert trainer.get_model() == trainer.lightning_module
38+
39+
3340
def test_v1_4_0_deprecated_imports():
3441
_soft_unimport_module('pytorch_lightning.utilities.argparse_utils')
3542
with pytest.deprecated_call(match='will be removed in v1.4'):

0 commit comments

Comments
 (0)