Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dataclasses import dataclass
from functools import partial
from typing import Iterable, Optional, Union
from weakref import proxy

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
Expand Down Expand Up @@ -186,7 +187,15 @@ def attach_data(
)
self.attach_datamodule(model, datamodule=datamodule)
# set local properties on the model
self.trainer.model_connector.copy_trainer_model_properties(model)
self._copy_trainer_model_properties(model)

def _copy_trainer_model_properties(self, model):
ref_model = self.trainer.lightning_module or model

for m in [model, ref_model]:
m.trainer = proxy(self.trainer)
m.use_amp = self.trainer.amp_backend is not None
m.precision = self.trainer.precision

def attach_dataloaders(
self,
Expand Down
27 changes: 0 additions & 27 deletions pytorch_lightning/trainer/connectors/model_connector.py

This file was deleted.

2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
Expand Down Expand Up @@ -450,7 +449,6 @@ def __init__(
plugins,
)
self.logger_connector = LoggerConnector(self, log_gpu_memory)
self.model_connector = ModelConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)
Expand Down