Commit cf5ef32

four4fish and awaelchli authored
Deprecate Trainer.training_type_plugin in favor of trainer.strategy (#11141)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 17ad1a4 commit cf5ef32

55 files changed: +387 -381 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -158,6 +158,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


+- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
+
+
 ### Deprecated

 - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
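
A deprecation like this is typically backed by a thin property alias on the Trainer that warns and forwards to the new attribute. A minimal sketch of that pattern, using a simplified stand-in class rather than the real pytorch_lightning.Trainer:

import warnings


class MyTrainer:
    """Illustrative stand-in, not the real pytorch_lightning.Trainer."""

    def __init__(self, strategy):
        self.strategy = strategy

    @property
    def training_type_plugin(self):
        # Hypothetical deprecation shim: warn, then delegate to the renamed
        # attribute so old call sites keep working during the grace period.
        warnings.warn(
            "`Trainer.training_type_plugin` is deprecated, use `Trainer.strategy` instead.",
            DeprecationWarning,
        )
        return self.strategy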

pl_examples/loop_examples/kfold.py

Lines changed: 2 additions & 2 deletions
@@ -205,8 +205,8 @@ def on_run_end(self) -> None:
     voting_model = EnsembleVotingModel(type(self.trainer.lightning_module), checkpoint_paths)
     voting_model.trainer = self.trainer
     # This requires to connect the new model and move it the right device.
-    self.trainer.training_type_plugin.connect(voting_model)
-    self.trainer.training_type_plugin.model_to_device()
+    self.trainer.strategy.connect(voting_model)
+    self.trainer.strategy.model_to_device()
     self.trainer.test_loop.run()

 def on_save_checkpoint(self) -> Dict[str, int]:
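
The renamed calls above are the general pattern for handing a new module to an already-configured trainer: attach it, then let the strategy place it on the right device. A minimal sketch of that pattern as a standalone helper (the helper name is hypothetical, not part of this commit):

import pytorch_lightning as pl


def attach_new_model(trainer: "pl.Trainer", new_model: "pl.LightningModule") -> None:
    # Hypothetical helper mirroring the KFold loop above: attach a freshly
    # built LightningModule to the trainer's strategy and move it onto the
    # strategy's device so subsequent loops can run it.
    new_model.trainer = trainer
    trainer.strategy.connect(new_model)
    trainer.strategy.model_to_device()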

pl_examples/loop_examples/yielding_training_step.py

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ def _get_generator(self, split_batch, batch_idx, opt_idx):
     # Here we are basically calling `lightning_module.training_step()`
     # and this returns a generator! The `training_step` is handled by the
     # accelerator to enable distributed training.
-    return self.trainer.training_type_plugin.training_step(*step_kwargs.values())
+    return self.trainer.strategy.training_step(*step_kwargs.values())

 def _training_step(self, generator):
     # required for logging
@@ -86,7 +86,7 @@ def _training_step(self, generator):
     # Here, instead of calling `lightning_module.training_step()`
     # we call next() on the generator!
     training_step_output = next(generator)
-    self.trainer.training_type_plugin.post_training_step()
+    self.trainer.strategy.post_training_step()

     model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
     strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
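
For custom loops like this example, the migration is a one-for-one attribute rename; the call signatures are unchanged. A minimal sketch, assuming `trainer` and `step_kwargs` are the objects such a loop already holds:

def run_manual_training_step(trainer, step_kwargs):
    # Same calls as the example above, reached via the renamed attribute.
    output = trainer.strategy.training_step(*step_kwargs.values())
    trainer.strategy.post_training_step()
    return output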

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
     should_stop, reason = self._evaluate_stopping_criteria(current)

     # stop every ddp process if any world process decides to stop
-    should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+    should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
     trainer.should_stop = trainer.should_stop or should_stop
     if should_stop:
         self.stopped_epoch = trainer.current_epoch
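
The same `reduce_boolean_decision` call is available to user callbacks, where it keeps every rank's stop decision in sync under DDP. A minimal sketch of a hypothetical callback using the pattern (the metric name and threshold are assumptions, not from this commit):

import pytorch_lightning as pl


class StopOnHighLoss(pl.Callback):
    """Hypothetical callback: stop every rank once any rank exceeds a loss threshold."""

    def __init__(self, threshold: float = 10.0) -> None:
        self.threshold = threshold

    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        loss = trainer.callback_metrics.get("val_loss")
        should_stop = loss is not None and float(loss) > self.threshold
        # Same pattern as EarlyStopping above: make the decision unanimous across ranks.
        should_stop = trainer.strategy.reduce_boolean_decision(should_stop)
        trainer.should_stop = trainer.should_stop or should_stop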

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 7 deletions
@@ -286,7 +286,7 @@ def on_train_batch_end(
     skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
     # in case we have time differences across ranks
     # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
-    skip_time = trainer.training_type_plugin.broadcast(skip_time)
+    skip_time = trainer.strategy.broadcast(skip_time)

     if skip_batch and skip_time:
         return
@@ -492,7 +492,7 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[torch.Ten
     should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

     # If using multiple devices, make sure all processes are unanimous on the decision.
-    should_update_best_and_save = trainer.training_type_plugin.reduce_boolean_decision(should_update_best_and_save)
+    should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)

     return should_update_best_and_save

@@ -598,7 +598,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
     else:
         ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

-    ckpt_path = trainer.training_type_plugin.broadcast(ckpt_path)
+    ckpt_path = trainer.strategy.broadcast(ckpt_path)

     self.dirpath = ckpt_path

@@ -646,7 +646,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
     trainer.save_checkpoint(filepath, self.save_weights_only)

     if self.last_model_path and self.last_model_path != filepath:
-        trainer.training_type_plugin.remove_checkpoint(self.last_model_path)
+        trainer.strategy.remove_checkpoint(self.last_model_path)

     self.last_model_path = filepath

@@ -671,7 +671,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
     trainer.save_checkpoint(filepath, self.save_weights_only)

     if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
-        trainer.training_type_plugin.remove_checkpoint(self.best_model_path)
+        trainer.strategy.remove_checkpoint(self.best_model_path)

     self.best_model_path = filepath

@@ -718,7 +718,7 @@ def _update_best_and_save(
     trainer.save_checkpoint(filepath, self.save_weights_only)

     if del_filepath is not None and filepath != del_filepath:
-        trainer.training_type_plugin.remove_checkpoint(del_filepath)
+        trainer.strategy.remove_checkpoint(del_filepath)

 def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
     """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
@@ -733,4 +733,4 @@ def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
     """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
     state to diverge between ranks."""
     exists = self._fs.exists(filepath)
-    return trainer.training_type_plugin.broadcast(exists)
+    return trainer.strategy.broadcast(exists)
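
The `broadcast` calls above all follow the rank-0 pattern spelled out in the `file_exists` docstring: one rank makes the filesystem decision and every rank receives the same answer. A minimal sketch of that pattern as a standalone helper (the function name is hypothetical):

import os

import pytorch_lightning as pl


def file_exists_everywhere(trainer: "pl.Trainer", filepath: str) -> bool:
    # Each rank checks its own filesystem view, but rank 0's answer is
    # broadcast to all ranks so their internal state cannot diverge.
    exists = os.path.exists(filepath)
    return trainer.strategy.broadcast(exists)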

pytorch_lightning/callbacks/timer.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ def on_load_checkpoint(
 def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
     assert self._duration is not None
     should_stop = self.time_elapsed() >= self._duration
-    should_stop = trainer.training_type_plugin.broadcast(should_stop)
+    should_stop = trainer.strategy.broadcast(should_stop)
     trainer.should_stop = trainer.should_stop or should_stop
     if should_stop and self._verbose:
         elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))

pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 4 additions & 4 deletions
@@ -77,7 +77,7 @@ def on_train_start(self, trainer, pl_module) -> None:
     )

     memory_info = xm.get_memory_info(pl_module.device)
-    total_memory = trainer.training_type_plugin.reduce(memory_info["kb_total"]) * 0.001
+    total_memory = trainer.strategy.reduce(memory_info["kb_total"]) * 0.001
     rank_zero_info(f"Average Total memory: {total_memory:.2f} MB")

 def on_train_epoch_start(self, trainer, pl_module) -> None:
@@ -91,9 +91,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
     free_memory = memory_info["kb_free"]
     peak_memory = memory_info["kb_total"] - free_memory

-    free_memory = trainer.training_type_plugin.reduce(free_memory) * 0.001
-    peak_memory = trainer.training_type_plugin.reduce(peak_memory) * 0.001
-    epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+    free_memory = trainer.strategy.reduce(free_memory) * 0.001
+    peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
+    epoch_time = trainer.strategy.reduce(epoch_time)

     logs["avg. free memory (MB)"] = free_memory
     logs["avg. peak memory (MB)"] = peak_memory

pytorch_lightning/core/lightning.py

Lines changed: 3 additions & 3 deletions
@@ -421,7 +421,7 @@ def log(
     add_dataloader_idx=add_dataloader_idx,
     batch_size=batch_size,
     sync_dist=sync_dist and distributed_available(),
-    sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
+    sync_dist_fn=self.trainer.strategy.reduce or sync_ddp,
     sync_dist_group=sync_dist_group,
     metric_attribute=metric_attribute,
     rank_zero_only=rank_zero_only,
@@ -536,7 +536,7 @@ def all_gather(
     the output will also be a collection with tensors of this shape.
     """
     group = group if group is not None else torch.distributed.group.WORLD
-    all_gather = self.trainer.training_type_plugin.all_gather
+    all_gather = self.trainer.strategy.all_gather
     data = convert_to_tensors(data, device=self.device)
     return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)

@@ -1337,7 +1337,7 @@ def training_step(...):
     **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
     """
     self._verify_is_manual_optimization("manual_backward")
-    self.trainer.training_type_plugin.backward(loss, None, None, *args, **kwargs)
+    self.trainer.strategy.backward(loss, None, None, *args, **kwargs)

 def backward(
     self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
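
The `manual_backward` change above is invisible to user code: a module using manual optimization still calls `self.manual_backward(loss)`, which now delegates to `trainer.strategy.backward`. A minimal sketch of such a module (layer sizes and learning rate are arbitrary placeholders):

import pytorch_lightning as pl
import torch


class ManualOptimModel(pl.LightningModule):
    """Minimal manual-optimization module; `manual_backward` routes through
    `trainer.strategy.backward` as shown in the diff above."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).mean()
        opt.zero_grad()
        self.manual_backward(loss)  # delegates to trainer.strategy.backward(...)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)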

pytorch_lightning/core/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -161,4 +161,4 @@ def closure_dis():
     trainer = self._trainer
     assert trainer is not None
     with trainer.profiler.profile(profiler_action):
-        trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
+        trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

pytorch_lightning/loops/base.py

Lines changed: 1 addition & 3 deletions
@@ -329,9 +329,7 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
     # Python primitives. However, their states are saved with the model's `state_dict`.
     # On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
     # The references are provided through the `metric_attributes` dictionary.
-    v.load_state_dict(
-        state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
-    )
+    v.load_state_dict(state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.strategy.reduce)

     if not self.trainer.is_global_zero:
         v.reset(metrics=False)
