4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -240,6 +240,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035))


+- Deprecated Trainer attribute `accelerator_backend` in favor of `accelerator` ([#6034](https://github.com/PyTorchLightning/pytorch-lightning/pull/6034))



### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
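Migration note (editor's sketch, not part of the diff): downstream code only needs a one-line rename. The trainer construction below is illustrative.

```python
import pytorch_lightning as pl

trainer = pl.Trainer()

# Deprecated since 1.2, scheduled for removal in v1.4 -- still works,
# but emits a DeprecationWarning.
backend = trainer.accelerator_backend

# Preferred spelling: the same Accelerator instance under its new name.
backend = trainer.accelerator
```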
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/conv_sequential_example.py
@@ -222,6 +222,6 @@ def instantiate_datamodule(args):
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm)

-if trainer.accelerator_backend.rpc_enabled:
+if trainer.accelerator.rpc_enabled:
# Called at the end of trainer to ensure all processes are killed
trainer.training_type_plugin.exit_rpc_process()
7 changes: 1 addition & 6 deletions pytorch_lightning/core/lightning.py
@@ -41,7 +41,6 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

@@ -448,11 +447,7 @@ def all_gather(
the output will also be a collection with tensors of this shape.
"""
group = group if group is not None else torch.distributed.group.WORLD
-if self.trainer.accelerator_backend is not None:
-all_gather = self.trainer.accelerator_backend.all_gather
-else:
-all_gather = all_gather_ddp_if_available
-
+all_gather = self.trainer.accelerator.all_gather
data = convert_to_tensors(data, device=self.device)
all_gather = partial(all_gather, group=group, sync_grads=sync_grads)
return apply_to_collection(data, torch.Tensor, all_gather)
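With the fallback removed, `LightningModule.all_gather` always delegates to `trainer.accelerator.all_gather`. A minimal usage sketch (the module and tensor shapes below are illustrative, not part of this PR):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # Routed through trainer.accelerator.all_gather; under DDP with N
        # processes the returned tensor gains a leading dimension of size N.
        gathered = self.all_gather(loss, sync_grads=False)
        self.log("val_loss_gathered_mean", gathered.mean())
        return loss
```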
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
@@ -132,7 +132,7 @@ def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: st
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
-trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
+trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

if self._trainer.train_loop.automatic_optimization:
trainer.train_loop.on_before_zero_grad(optimizer)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -185,7 +185,7 @@ def init_deepspeed(self):
self._format_config()
self._config_initialized = True

-precision = self.lightning_module.trainer.accelerator_backend.precision
+precision = self.lightning_module.trainer.accelerator.precision
model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

if self.lightning_module.trainer.training:
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -90,7 +90,7 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
trainer.progress_bar_callback.disable()

self.model_to_device()
-trainer.accelerator_backend.setup_optimizers(trainer)
+trainer.accelerator.setup_optimizers(trainer)
trainer.precision_plugin.connect(self._model, None, None)

# replace trainer save_checkpoint to use `xm.save`
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -219,8 +219,7 @@ def hpc_save(self, folderpath: str, logger):

model.on_hpc_save(checkpoint)

-if self.trainer.accelerator_backend:
-checkpoint = self.trainer.accelerator_backend.on_save(checkpoint)
+checkpoint = self.trainer.accelerator.on_save(checkpoint)

# do the actual save
# TODO: fix for anything with multiprocess DP, DDP, DDP2
@@ -286,7 +285,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
# Rely on accelerator to dump optimizer state
-optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+optimizer_state = self.trainer.accelerator.optimizer_state(optimizer)
optimizer_states.append(optimizer_state)

checkpoint['optimizer_states'] = optimizer_states
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/data_loading.py
@@ -51,7 +51,7 @@ class TrainerDataLoadingMixin(ABC):
limit_val_batches: Union[int, float]
limit_test_batches: Union[int, float]
replace_sampler_ddp: bool
-accelerator_backend: Accelerator
+accelerator: Accelerator
num_nodes: int
num_processes: int
distributed_backend: Optional[str]
@@ -398,8 +398,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
dataloader = dataloader_fx()
dataloader = self._flatten_dl_only(dataloader)

-if self.accelerator_backend is not None:
-self.accelerator_backend.barrier('get_dataloaders')
+self.accelerator.barrier('get_dataloaders')
return dataloader

def _flatten_dl_only(self, dataloaders):
12 changes: 11 additions & 1 deletion pytorch_lightning/trainer/deprecated_api.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from pytorch_lightning.accelerators import Accelerator
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
@@ -133,10 +134,19 @@ def use_single_gpu(self, val: bool) -> None:
self.accelerator_connector._device_type = DeviceType.GPU


-class DeprecatedModelAttributes:
+class DeprecatedTrainerAttributes:

+accelerator: Accelerator
lightning_module = LightningModule

+@property
+def accelerator_backend(self) -> Accelerator:
+rank_zero_warn(
+"The `Trainer.accelerator_backend` attribute is deprecated in favor of `Trainer.accelerator`"
+" since 1.2 and will be removed in v1.4.", DeprecationWarning
+)
+return self.accelerator

def get_model(self) -> LightningModule:
rank_zero_warn(
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
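A hedged sketch of how the new deprecation path could be exercised in a test, in the style of the project's existing deprecation tests (the test name below is illustrative):

```python
import pytest
from pytorch_lightning import Trainer


def test_trainer_accelerator_backend_deprecated():
    trainer = Trainer()
    # The deprecated property warns and then falls through to the new name.
    with pytest.deprecated_call(match="accelerator_backend"):
        backend = trainer.accelerator_backend
    assert backend is trainer.accelerator
```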
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -157,11 +157,11 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx):
if self.testing:
model_ref._current_fx_name = "test_step"
with self.trainer.profiler.profile("test_step"):
-output = self.trainer.accelerator_backend.test_step(args)
+output = self.trainer.accelerator.test_step(args)
else:
model_ref._current_fx_name = "validation_step"
with self.trainer.profiler.profile("validation_step"):
-output = self.trainer.accelerator_backend.validation_step(args)
+output = self.trainer.accelerator.validation_step(args)

# capture any logged information
self.trainer.logger_connector.cache_logged_metrics()
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/predict_loop.py
@@ -74,7 +74,7 @@ def predict(self, batch, batch_idx, dataloader_idx):
model_ref = self.trainer.lightning_module

model_ref._current_fx_name = "predict"
-predictions = self.trainer.accelerator_backend.predict(args)
+predictions = self.trainer.accelerator.predict(args)
self._predictions[dataloader_idx].append(predictions)
self.trainer._progress_bar_callback.on_predict_batch_end(
self.trainer, model_ref, predictions, batch, batch_idx, dataloader_idx
9 changes: 2 additions & 7 deletions pytorch_lightning/trainer/properties.py
@@ -62,11 +62,6 @@ class TrainerProperties(ABC):
def accelerator(self) -> Accelerator:
return self.accelerator_connector.accelerator

-@property
-def accelerator_backend(self) -> Accelerator:
-# for backward compatibility
-return self.accelerator

@property
def distributed_backend(self) -> Optional[str]:
# for backward compatibility
@@ -138,7 +133,7 @@ def log_dir(self) -> Optional[str]:
else:
dirpath = getattr(self.logger, 'log_dir' if isinstance(self.logger, TensorBoardLogger) else 'save_dir')

-dirpath = self.accelerator_backend.broadcast(dirpath)
+dirpath = self.accelerator.broadcast(dirpath)
return dirpath

@property
@@ -360,7 +355,7 @@ def lightning_optimizers(self) -> List[LightningOptimizer]:

@property
def lightning_module(self) -> LightningModule:
-return self.accelerator_backend.lightning_module
+return self.accelerator.lightning_module

@property
def optimizers(self) -> Optional[List[Optimizer]]:
26 changes: 13 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -45,7 +45,7 @@
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedModelAttributes
+from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes, DeprecatedTrainerAttributes
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -80,7 +80,7 @@ class Trainer(
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
-DeprecatedModelAttributes,
+DeprecatedTrainerAttributes,
):

@overwrite_by_env_vars
@@ -470,7 +470,7 @@ def fit(
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
-self.accelerator_backend.setup(self, model)
+self.accelerator.setup(self, model)
self.setup_trainer(model)

# ----------------------------
@@ -533,24 +533,24 @@

self._set_running_stage(None, model)

-return self.accelerator_backend.results or 1
+return self.accelerator.results or 1

def pre_dispatch(self):
-self.accelerator_backend.pre_dispatch()
+self.accelerator.pre_dispatch()

def post_dispatch(self):
-self.accelerator_backend.post_dispatch()
-self.accelerator_backend.teardown()
+self.accelerator.post_dispatch()
+self.accelerator.teardown()

def dispatch(self):
if self.testing:
-self.accelerator_backend.start_testing(self)
+self.accelerator.start_testing(self)

elif self.predicting:
-self.accelerator_backend.start_predicting(self)
+self.accelerator.start_predicting(self)

else:
-self.accelerator_backend.start_training(self)
+self.accelerator.start_training(self)

def train_or_test_or_predict(self):
if self.testing:
@@ -949,7 +949,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
)
return {}
if not self._device_type == DeviceType.TPU:
-self.accelerator_backend.barrier()
+self.accelerator.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
@@ -1109,8 +1109,8 @@ def call_hook(self, hook_name, *args, **kwargs):

# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
-elif hasattr(self.accelerator_backend, hook_name):
-accelerator_hook = getattr(self.accelerator_backend, hook_name)
+elif hasattr(self.accelerator, hook_name):
+accelerator_hook = getattr(self.accelerator, hook_name)
output = accelerator_hook(*args, **kwargs)

if not skip:
12 changes: 6 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
@@ -290,8 +290,8 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
model_ref._current_fx_name = 'training_step'
model_ref._results = Result()
with self.trainer.profiler.profile("training_step"):
-training_step_output = self.trainer.accelerator_backend.training_step(args)
-self.trainer.accelerator_backend.post_training_step()
+training_step_output = self.trainer.accelerator.training_step(args)
+self.trainer.accelerator.post_training_step()

self.trainer.logger_connector.cache_logged_metrics()

@@ -438,14 +438,14 @@ def on_before_zero_grad(self, optimizer):
self.trainer.call_hook('on_before_zero_grad', optimizer)

def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx):
-self.trainer.accelerator_backend.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def track_and_norm_grad(self, optimizer):
# track gradient norms
grad_norm_dic = self._track_gradient_norm()

# clip gradients
-self.trainer.accelerator_backend.clip_gradients(optimizer, self.trainer.gradient_clip_val)
+self.trainer.accelerator.clip_gradients(optimizer, self.trainer.gradient_clip_val)
self._cur_grad_norm_dict = grad_norm_dic

def _track_gradient_norm(self):
@@ -769,9 +769,9 @@ def backward(self, result, optimizer, opt_idx, *args, **kwargs):

# backward can be called manually in the training loop
if isinstance(result, torch.Tensor):
-self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
+self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
else:
-result.closure_loss = self.trainer.accelerator_backend.backward(
+result.closure_loss = self.trainer.accelerator.backward(
result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)
