3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))


- Added `Loop.replace` to easily switch one loop for another ([#10324](https://github.com/PyTorchLightning/pytorch-lightning/issues/10324))


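Taken together, the two loop entries above mean the trainer now persists loop progress in regular checkpoints and exposes a one-call way to swap a loop implementation. A minimal sketch of `Loop.replace` (#10324), assuming a hypothetical `MyEpochLoop` subclass and that `replace` accepts a loop class and re-instantiates it from the loop it replaces:

```python
import pytorch_lightning as pl
from pytorch_lightning.loops import TrainingEpochLoop


class MyEpochLoop(TrainingEpochLoop):
    """Hypothetical custom epoch loop, used only for illustration."""


trainer = pl.Trainer(max_epochs=1)
# Swap the built-in epoch loop for the custom one; `replace` instantiates the
# class from the old loop's constructor arguments and connects it to the trainer.
trainer.fit_loop.replace(epoch_loop=MyEpochLoop)
```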
7 changes: 7 additions & 0 deletions pytorch_lightning/callbacks/timer.py
@@ -142,6 +142,13 @@ def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._end_time[RunningStage.TESTING] = time.monotonic()

def on_fit_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
# this checks the time after the state is reloaded, regardless of the interval.
# this is necessary in case we load a state whose timer is already depleted
if self._duration is None:
return
self._check_time_remaining(trainer)

def on_train_batch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
if self._interval != Interval.step or self._duration is None:
return
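For context, the new `on_fit_start` hook re-checks the budget as soon as the Timer state is restored, so a resumed run whose time is already spent stops immediately instead of waiting for the next interval check. A short usage sketch (model and checkpoint path are placeholders):

```python
from datetime import timedelta

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Timer

# Stop training after at most 30 minutes of wall-clock time.
timer = Timer(duration=timedelta(minutes=30))
trainer = Trainer(callbacks=[timer])

# On resume, the elapsed time is reloaded with the callback state and checked
# in `on_fit_start` before any training steps run.
# trainer.fit(model, ckpt_path="last.ckpt")  # placeholder model and path
```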
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -303,7 +303,7 @@ def log(
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
rank_zero_only: Optional[bool] = None,
rank_zero_only: bool = False,
) -> None:
"""Log a key, value pair.

@@ -441,7 +441,7 @@ def log_dict(
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
batch_size: Optional[int] = None,
rank_zero_only: Optional[bool] = None,
rank_zero_only: bool = False,
) -> None:
"""Log a dictionary of values at once.

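The default for `rank_zero_only` tightens from `Optional[bool] = None` to a plain `False`. A sketch of how the flag is passed from a user module (model and values are illustrative):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # `rank_zero_only` now defaults to False; pass True to record the value
        # on the global-zero process only (it is then absent on other ranks).
        self.log("train_loss", loss, rank_zero_only=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```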
13 changes: 11 additions & 2 deletions pytorch_lightning/loops/base.py
@@ -21,6 +21,7 @@
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException

T = TypeVar("T") # the output type of `run`
@@ -273,9 +274,11 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di

destination[prefix + "state_dict"] = self.on_save_checkpoint()

# do not get the mode from `self.trainer` because it might not have been attached yet
ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
if ft_enabled and isinstance(v, BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
@@ -302,6 +305,10 @@ def load_state_dict(
def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if key not in state_dict:
# no state for this object, maybe we are loading an old checkpoint
continue

if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
elif (
@@ -330,4 +337,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional
v.reset(metrics=False)

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True

if _FaultTolerantMode.detect_current_mode().is_enabled:
self.restarting = True
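
With this change, progress objects are only serialized, and `restarting` only set, when fault-tolerant training is enabled. A minimal check of the mode detection used above, assuming the `PL_FAULT_TOLERANT_TRAINING` environment variable is the switch (as in the tests further down):

```python
import os

from pytorch_lightning.utilities.enums import _FaultTolerantMode

# Fault-tolerant training is disabled by default.
assert not _FaultTolerantMode.detect_current_mode().is_enabled

# Enabling it makes `Loop.state_dict` include the `BaseProgress` entries and
# makes `_load_from_state_dict` mark the loop as restarting.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"
assert _FaultTolerantMode.detect_current_mode().is_enabled
```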
@@ -373,9 +373,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"global_step": global_step,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
}
if _fault_tolerant_training():
checkpoint["loops"] = self._get_loops_state_dict()

if not weights_only:
# dump callbacks
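Since the `loops` entry is now written unconditionally, it shows up in every checkpoint the trainer dumps, not only under fault-tolerant training. A small inspection sketch (the path is a placeholder):

```python
import torch

# "example.ckpt" stands in for any checkpoint written by the Trainer.
checkpoint = torch.load("example.ckpt")
# The top-level keys now always include "loops" alongside "state_dict",
# "epoch", "global_step", and the rest.
print(sorted(checkpoint))
```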
@@ -616,8 +616,8 @@ def cpu(self) -> "ResultCollection":

def sync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor:
result_metric.sync()
if result_metric.is_tensor and not result_metric._is_synced:
result_metric.sync(should_sync=not result_metric.meta.sync.rank_zero_only)

def unsync(self) -> None:
for result_metric in self.result_metrics:
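The guard above skips metrics that are already synced and honours `rank_zero_only`, so no process waits on a collective that the others never enter. In user code this corresponds to logging calls along these lines (a sketch, not the library's test code):

```python
import torch
import pytorch_lightning as pl


class LitEval(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        value = torch.tensor(1.0)
        # Reduced across processes when the results are synced.
        self.log("val_metric", value, sync_dist=True)
        # Kept on the global-zero process; with the change above, `sync()` does
        # not attempt a cross-rank reduction for this metric.
        self.log("debug_metric", value, rank_zero_only=True)
```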
9 changes: 3 additions & 6 deletions tests/callbacks/test_timer.py
@@ -168,15 +168,12 @@ def test_timer_resume_training(tmpdir):
assert trainer.current_epoch < 99
saved_global_step = trainer.global_step

# resume training (with depleted timer
# resume training (with depleted timer)
timer = Timer(duration=timedelta(milliseconds=200))
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[timer, checkpoint_callback],
)
trainer = Trainer(default_root_dir=tmpdir, callbacks=timer)
trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)
assert timer._offset > 0
assert trainer.global_step == saved_global_step + 1
assert trainer.global_step == saved_global_step


@RunIf(skip_windows=True)
3 changes: 3 additions & 0 deletions tests/loops/test_loop_state_dict.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
@@ -37,6 +39,7 @@ def test_loops_state_dict():
assert fit_loop.state_dict() == new_fit_loop.state_dict()


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_loops_state_dict_structure():
trainer = Trainer()
trainer.train_dataloader = Mock()
1 change: 1 addition & 0 deletions tests/loops/test_loops.py
@@ -213,6 +213,7 @@ def load_state_dict(self, state_dict: Dict) -> None:
assert loop.outputs == list(range(10))


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_loop_hierarchy():
@dataclass
class SimpleProgress(BaseProgress):
2 changes: 2 additions & 0 deletions tests/models/test_hooks.py
@@ -497,6 +497,7 @@ def training_step(self, batch, batch_idx):
"optimizer_states": ANY,
"pytorch-lightning_version": __version__,
"state_dict": ANY,
"loops": ANY,
}
if kwargs.get("amp_backend") == "native":
saved_ckpt["native_amp_scaling_state"] = ANY
@@ -624,6 +625,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
"optimizer_states": ANY,
"pytorch-lightning_version": __version__,
"state_dict": ANY,
"loops": ANY,
}
# TODO: wrong saved epoch, should be 0
saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -660,7 +660,7 @@ def test_benchmark_option(tmpdir):
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
def test_tested_checkpoint_path(tmpdir, ckpt_path, save_top_k, fn):
def test_checkpoint_path_input(tmpdir, ckpt_path, save_top_k, fn):
class TestModel(BoringModel):
def validation_step(self, batch, batch_idx):
self.log("foo", -batch_idx)