diff --git a/CHANGELOG.md b/CHANGELOG.md index bd7bd5338de9d..4534a8eaa46b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Progress tracking * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598) + * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)) - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628)) diff --git a/pyproject.toml b/pyproject.toml index 61b620d759c5d..7c1e24b5b8846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ module = [ "pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.connectors.logger_connector.fx_validator", "pytorch_lightning.trainer.connectors.logger_connector.logger_connector", + "pytorch_lightning.trainer.progress", "pytorch_lightning.tuner.auto_gpu_select", "pytorch_lightning.utilities.apply_func", "pytorch_lightning.utilities.argparse", diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index a7e7ef19044a1..fd7034f63e1da 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Optional +from typing import Type @dataclass class BaseProgress: """ - Mixin that implements state-loading utiltiies for dataclasses. + Mixin that implements state-loading utilities for dataclasses. """ def state_dict(self) -> dict: @@ -35,63 +35,83 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress": @dataclass -class Tracker(BaseProgress): +class ReadyCompletedTracker(BaseProgress): """ Track an event's progress. 
Args: ready: Intended to track the number of events ready to start. - started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). - processed: Intended to be incremented after the event is processed. completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. - Attributes set to ``None`` are treated as unused and are restricted. """ - ready: Optional[int] = 0 - started: Optional[int] = 0 - processed: Optional[int] = 0 - completed: Optional[int] = 0 + ready: int = 0 + completed: int = 0 def reset(self) -> None: - if self.ready is not None: - self.ready = 0 - if self.started is not None: - self.started = 0 - if self.processed is not None: - self.processed = 0 - if self.completed is not None: - self.completed = 0 - - def __setattr__(self, key: str, value: int) -> None: - """Restrict writing to attributes set to ``None``.""" - if getattr(self, key, 0) is None: - raise AttributeError(f"The '{key}' attribute is meant to be unused") - return super().__setattr__(key, value) - - def __repr__(self) -> str: - """Custom implementation to hide ``None`` fields.""" - args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None] - return f"{self.__class__.__name__}({', '.join(args)})" + """Reset the state.""" + self.ready = 0 + self.completed = 0 def reset_on_restart(self) -> None: """ Reset the progress on restart. + If there is a failure before all attributes are increased, - we restore the attributes to the last fully completed value. + restore the attributes to the last fully completed value. """ - # choose in case `processed` is unused - value = self.completed if self.processed is None else self.processed + self.ready = self.completed + + +@dataclass +class StartedTracker(ReadyCompletedTracker): + """ + Track an event's progress. 
+ + Args: + ready: Intended to track the number of events ready to start. + started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). + completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). + + These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. + """ + + started: int = 0 + + def reset(self) -> None: + super().reset() + self.started = 0 + + def reset_on_restart(self) -> None: + super().reset_on_restart() + self.started = self.completed + + +@dataclass +class ProcessedTracker(StartedTracker): + """ + Track an event's progress. + + Args: + ready: Intended to track the number of events ready to start. + started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs). + processed: Intended to be incremented after the event is processed. + completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs). + + These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last. + """ - if self.ready is not None: - self.ready = value - if self.started is not None: - self.started = value - if self.processed is not None: - self.processed = value - if self.completed is not None: - self.completed = value + processed: int = 0 + + def reset(self) -> None: + super().reset() + self.processed = 0 + + def reset_on_restart(self) -> None: + # use `processed` in this case as the reset value + self.completed = self.processed + super().reset_on_restart() @dataclass @@ -104,18 +124,26 @@ class Progress(BaseProgress): current: Intended to track the current progress of an event. 
""" - total: Tracker = field(default_factory=Tracker) - current: Tracker = field(default_factory=Tracker) + total: ReadyCompletedTracker = field(default_factory=ProcessedTracker) + current: ReadyCompletedTracker = field(default_factory=ProcessedTracker) + + def __post_init__(self) -> None: + if type(self.total) is not type(self.current): # noqa: E721 + raise ValueError("The `total` and `current` instances should be of the same class") def increment_ready(self) -> None: self.total.ready += 1 self.current.ready += 1 def increment_started(self) -> None: + if not isinstance(self.total, StartedTracker): + raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute") self.total.started += 1 self.current.started += 1 def increment_processed(self) -> None: + if not isinstance(self.total, ProcessedTracker): + raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute") self.total.processed += 1 self.current.processed += 1 @@ -124,9 +152,9 @@ def increment_completed(self) -> None: self.current.completed += 1 @classmethod - def from_defaults(cls, **kwargs: Optional[int]) -> "Progress": + def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress": """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" - return cls(total=Tracker(**kwargs), current=Tracker(**kwargs)) + return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) @@ -144,8 +172,8 @@ class DataLoaderProgress(Progress): current: Tracks the current dataloader progress. 
""" - total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) - current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) + current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) @dataclass @@ -159,8 +187,8 @@ class SchedulerProgress(Progress): current: Tracks the current scheduler progress. """ - total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) - current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) + current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker) @dataclass @@ -173,8 +201,8 @@ class OptimizerProgress(BaseProgress): zero_grad: Tracks ``optimizer.zero_grad`` calls. """ - step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) - zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None)) + step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker)) + zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker)) def reset_on_epoch(self) -> None: self.step.current.reset() diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 4cff47f2f57b1..0010af32b4f99 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -53,28 +53,25 @@ def test_loops_state_dict_structure(): "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, "epoch_loop.scheduler_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "completed": 0}, + "current": {"ready": 0, "completed": 0}, }, 
"epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer": { - "step": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, + "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "zero_grad": { - "total": {"ready": 0, "started": 0, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "completed": 0}, }, }, "optimizer_idx": 0, }, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "completed": 0}, + "current": {"ready": 0, "completed": 0}, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { @@ -102,10 +99,7 @@ def test_loops_state_dict_structure(): }, "validate_loop": { "state_dict": {}, - "dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, + "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, @@ -121,10 +115,7 @@ def test_loops_state_dict_structure(): }, "test_loop": { "state_dict": {}, - "dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, + "dataloader_progress": 
{"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, @@ -140,10 +131,7 @@ def test_loops_state_dict_structure(): }, "predict_loop": { "state_dict": {}, - "dataloader_progress": { - "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, - "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, - }, + "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index c20a0b6261dae..cd7ecab8dd493 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -285,8 +285,8 @@ def val_dataloader(self): total_dataloader = stop_epoch * n_dataloaders + stop_dataloader expected = { - "total": {"ready": total_dataloader + 1, "started": None, "processed": None, "completed": total_dataloader}, - "current": {"ready": stop_dataloader + 1, "started": None, "processed": None, "completed": stop_dataloader}, + "total": {"ready": total_dataloader + 1, "completed": total_dataloader}, + "current": {"ready": stop_dataloader + 1, "completed": stop_dataloader}, } assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected @@ -452,13 +452,8 @@ def configure_optimizers_multiple(self): }, }, "epoch_loop.scheduler_progress": { - "total": { - "ready": nbe_sch_steps + be_sch_steps, - "started": None, - "processed": None, - "completed": nbe_sch_steps + be_sch_steps, - }, - "current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps}, + "total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps}, + "current": {"ready": be_sch_steps, "completed": be_sch_steps}, }, 
"epoch_loop.batch_loop.state_dict": ANY, "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, @@ -468,28 +463,19 @@ def configure_optimizers_multiple(self): "step": { "total": { "ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be, - "started": None, - "processed": None, "completed": nbe_total_opt_steps + be_total_opt_steps, }, - "current": { - "ready": be_total_opt_steps + has_opt_stepped_in_be, - "started": None, - "processed": None, - "completed": be_total_opt_steps, - }, + "current": {"ready": be_total_opt_steps + has_opt_stepped_in_be, "completed": be_total_opt_steps}, }, "zero_grad": { "total": { "ready": nbe_total_zero_grad + be_total_zero_grad, "started": nbe_total_zero_grad + be_total_zero_grad, - "processed": None, "completed": nbe_total_zero_grad + be_total_zero_grad, }, "current": { "ready": be_total_zero_grad, "started": be_total_zero_grad, - "processed": None, "completed": be_total_zero_grad, }, }, diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index 63b8156f24050..ef2b8d2888573 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -15,72 +15,78 @@ import pytest -from pytorch_lightning.trainer.progress import BaseProgress, OptimizerProgress, Progress, Tracker - - -def test_progress_getattr_setattr(): - p = Tracker(ready=10, completed=None) - # can read - assert p.completed is None - # can't read non-existing attr - with pytest.raises(AttributeError, match="object has no attribute 'non_existing_attr'"): - p.non_existing_attr - # can set new attr - p.non_existing_attr = 10 - # can't write unused attr - with pytest.raises(AttributeError, match="'completed' attribute is meant to be unused"): - p.completed = 10 - with pytest.raises(TypeError, match="unsupported operand type"): - # default python error, would need to override `__getattribute__` - # but we want to allow reading the `None` value - p.completed += 10 - - -def test_progress_reset(): - p = Tracker(ready=1, 
started=2, completed=None) +from pytorch_lightning.trainer.progress import ( + BaseProgress, + OptimizerProgress, + ProcessedTracker, + Progress, + ReadyCompletedTracker, + StartedTracker, +) + + +def test_tracker_reset(): + p = StartedTracker(ready=1, started=2) p.reset() - assert p == Tracker(completed=None) + assert p == StartedTracker() -def test_progress_repr(): - assert repr(Tracker(ready=None, started=None)) == "Tracker(processed=0, completed=0)" +def test_tracker_reset_on_restart(): + t = StartedTracker(ready=3, started=3, completed=2) + t.reset_on_restart() + assert t == StartedTracker(ready=2, started=2, completed=2) + + t = ProcessedTracker(ready=4, started=4, processed=3, completed=2) + t.reset_on_restart() + assert t == ProcessedTracker(ready=3, started=3, processed=3, completed=3) @pytest.mark.parametrize("attr", ("ready", "started", "processed", "completed")) -def test_base_progress_increment(attr): +def test_progress_increment(attr): p = Progress() fn = getattr(p, f"increment_{attr}") fn() - expected = Tracker(**{attr: 1}) + expected = ProcessedTracker(**{attr: 1}) assert p.total == expected assert p.current == expected -def test_base_progress_from_defaults(): - actual = Progress.from_defaults(completed=5, started=None) - expected = Progress(total=Tracker(started=None, completed=5), current=Tracker(started=None, completed=5)) +def test_progress_from_defaults(): + actual = Progress.from_defaults(StartedTracker, completed=5) + expected = Progress(total=StartedTracker(completed=5), current=StartedTracker(completed=5)) assert actual == expected -def test_epoch_loop_progress_increment_sequence(): - """Test sequences for incrementing batches reads and epochs.""" +def test_progress_increment_sequence(): + """Test sequence for incrementing.""" batch = Progress() batch.increment_ready() - assert batch.total == Tracker(ready=1) - assert batch.current == Tracker(ready=1) + assert batch.total == ProcessedTracker(ready=1) + assert batch.current == 
ProcessedTracker(ready=1) batch.increment_started() - assert batch.total == Tracker(ready=1, started=1) - assert batch.current == Tracker(ready=1, started=1) + assert batch.total == ProcessedTracker(ready=1, started=1) + assert batch.current == ProcessedTracker(ready=1, started=1) batch.increment_processed() - assert batch.total == Tracker(ready=1, started=1, processed=1) - assert batch.current == Tracker(ready=1, started=1, processed=1) + assert batch.total == ProcessedTracker(ready=1, started=1, processed=1) + assert batch.current == ProcessedTracker(ready=1, started=1, processed=1) batch.increment_completed() - assert batch.total == Tracker(ready=1, started=1, processed=1, completed=1) - assert batch.current == Tracker(ready=1, started=1, processed=1, completed=1) + assert batch.total == ProcessedTracker(ready=1, started=1, processed=1, completed=1) + assert batch.current == ProcessedTracker(ready=1, started=1, processed=1, completed=1) + + +def test_progress_raises(): + with pytest.raises(ValueError, match="instances should be of the same class"): + Progress(ReadyCompletedTracker(), ProcessedTracker()) + + p = Progress(ReadyCompletedTracker(), ReadyCompletedTracker()) + with pytest.raises(TypeError, match="ReadyCompletedTracker` doesn't have a `started` attribute"): + p.increment_started() + with pytest.raises(TypeError, match="ReadyCompletedTracker` doesn't have a `processed` attribute"): + p.increment_processed() def test_optimizer_progress_default_factory(): @@ -99,4 +105,4 @@ def test_optimizer_progress_default_factory(): def test_deepcopy(): _ = deepcopy(BaseProgress()) _ = deepcopy(Progress()) - _ = deepcopy(Tracker()) + _ = deepcopy(ProcessedTracker())