Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Progress tracking
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)


- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ module = [
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector.fx_validator",
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
"pytorch_lightning.trainer.progress",
"pytorch_lightning.tuner.auto_gpu_select",
"pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
Expand Down
128 changes: 78 additions & 50 deletions pytorch_lightning/trainer/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from typing import Optional
from typing import Type


@dataclass
class BaseProgress:
"""
Mixin that implements state-loading utiltiies for dataclasses.
Mixin that implements state-loading utilities for dataclasses.
"""

def state_dict(self) -> dict:
Expand All @@ -35,63 +35,83 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress":


@dataclass
class Tracker(BaseProgress):
class ReadyCompletedTracker(BaseProgress):
"""
Track an event's progress.

Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
processed: Intended to be incremented after the event is processed.
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).

These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
Attributes set to ``None`` are treated as unused and are restricted.
"""

ready: Optional[int] = 0
started: Optional[int] = 0
processed: Optional[int] = 0
completed: Optional[int] = 0
ready: int = 0
completed: int = 0

def reset(self) -> None:
if self.ready is not None:
self.ready = 0
if self.started is not None:
self.started = 0
if self.processed is not None:
self.processed = 0
if self.completed is not None:
self.completed = 0

def __setattr__(self, key: str, value: int) -> None:
"""Restrict writing to attributes set to ``None``."""
if getattr(self, key, 0) is None:
raise AttributeError(f"The '{key}' attribute is meant to be unused")
return super().__setattr__(key, value)

def __repr__(self) -> str:
"""Custom implementation to hide ``None`` fields."""
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
return f"{self.__class__.__name__}({', '.join(args)})"
"""Reset the state."""
self.ready = 0
self.completed = 0

def reset_on_restart(self) -> None:
"""
Reset the progress on restart.

If there is a failure before all attributes are increased,
we restore the attributes to the last fully completed value.
restore the attributes to the last fully completed value.
"""
# choose in case `processed` is unused
value = self.completed if self.processed is None else self.processed
self.ready = self.completed


@dataclass
class StartedTracker(ReadyCompletedTracker):
"""
Track an event's progress.

Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).

These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
"""

started: int = 0

def reset(self) -> None:
super().reset()
self.started = 0

def reset_on_restart(self) -> None:
super().reset_on_restart()
self.started = self.completed


@dataclass
class ProcessedTracker(StartedTracker):
"""
Track an event's progress.

Args:
ready: Intended to track the number of events ready to start.
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
processed: Intended to be incremented after the event is processed.
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).

These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
"""

if self.ready is not None:
self.ready = value
if self.started is not None:
self.started = value
if self.processed is not None:
self.processed = value
if self.completed is not None:
self.completed = value
processed: int = 0

def reset(self) -> None:
super().reset()
self.processed = 0

def reset_on_restart(self) -> None:
# use `processed` in this case as the reset value
self.completed = self.processed
super().reset_on_restart()


@dataclass
Expand All @@ -104,18 +124,26 @@ class Progress(BaseProgress):
current: Intended to track the current progress of an event.
"""

total: Tracker = field(default_factory=Tracker)
current: Tracker = field(default_factory=Tracker)
total: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
current: ReadyCompletedTracker = field(default_factory=ProcessedTracker)

def __post_init__(self) -> None:
if type(self.total) is not type(self.current): # noqa: E721
raise ValueError("The `total` and `current` instances should be of the same class")

def increment_ready(self) -> None:
self.total.ready += 1
self.current.ready += 1

def increment_started(self) -> None:
if not isinstance(self.total, StartedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute")
self.total.started += 1
self.current.started += 1

def increment_processed(self) -> None:
if not isinstance(self.total, ProcessedTracker):
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute")
self.total.processed += 1
self.current.processed += 1

Expand All @@ -124,9 +152,9 @@ def increment_completed(self) -> None:
self.current.completed += 1

@classmethod
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress":
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))

def load_state_dict(self, state_dict: dict) -> None:
self.total.load_state_dict(state_dict["total"])
Expand All @@ -144,8 +172,8 @@ class DataLoaderProgress(Progress):
current: Tracks the current dataloader progress.
"""

total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
Expand All @@ -159,8 +187,8 @@ class SchedulerProgress(Progress):
current: Tracks the current scheduler progress.
"""

total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)


@dataclass
Expand All @@ -173,8 +201,8 @@ class OptimizerProgress(BaseProgress):
zero_grad: Tracks ``optimizer.zero_grad`` calls.
"""

step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None))
step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))

def reset_on_epoch(self) -> None:
self.step.current.reset()
Expand Down
32 changes: 10 additions & 22 deletions tests/loops/test_loop_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,25 @@ def test_loops_state_dict_structure():
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
},
"epoch_loop.scheduler_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
"step": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"zero_grad": {
"total": {"ready": 0, "started": 0, "processed": None, "completed": 0},
"current": {"ready": 0, "started": 0, "processed": None, "completed": 0},
"total": {"ready": 0, "started": 0, "completed": 0},
"current": {"ready": 0, "started": 0, "completed": 0},
},
},
"optimizer_idx": 0,
},
"epoch_loop.val_loop.state_dict": {},
"epoch_loop.val_loop.dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.val_loop.epoch_loop.state_dict": {},
"epoch_loop.val_loop.epoch_loop.batch_progress": {
Expand Down Expand Up @@ -102,10 +99,7 @@ def test_loops_state_dict_structure():
},
"validate_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand All @@ -121,10 +115,7 @@ def test_loops_state_dict_structure():
},
"test_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand All @@ -140,10 +131,7 @@ def test_loops_state_dict_structure():
},
"predict_loop": {
"state_dict": {},
"dataloader_progress": {
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
},
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
Expand Down
24 changes: 5 additions & 19 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,8 @@ def val_dataloader(self):

total_dataloader = stop_epoch * n_dataloaders + stop_dataloader
expected = {
"total": {"ready": total_dataloader + 1, "started": None, "processed": None, "completed": total_dataloader},
"current": {"ready": stop_dataloader + 1, "started": None, "processed": None, "completed": stop_dataloader},
"total": {"ready": total_dataloader + 1, "completed": total_dataloader},
"current": {"ready": stop_dataloader + 1, "completed": stop_dataloader},
}
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected

Expand Down Expand Up @@ -452,13 +452,8 @@ def configure_optimizers_multiple(self):
},
},
"epoch_loop.scheduler_progress": {
"total": {
"ready": nbe_sch_steps + be_sch_steps,
"started": None,
"processed": None,
"completed": nbe_sch_steps + be_sch_steps,
},
"current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
Expand All @@ -468,28 +463,19 @@ def configure_optimizers_multiple(self):
"step": {
"total": {
"ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be,
"started": None,
"processed": None,
"completed": nbe_total_opt_steps + be_total_opt_steps,
},
"current": {
"ready": be_total_opt_steps + has_opt_stepped_in_be,
"started": None,
"processed": None,
"completed": be_total_opt_steps,
},
"current": {"ready": be_total_opt_steps + has_opt_stepped_in_be, "completed": be_total_opt_steps},
},
"zero_grad": {
"total": {
"ready": nbe_total_zero_grad + be_total_zero_grad,
"started": nbe_total_zero_grad + be_total_zero_grad,
"processed": None,
"completed": nbe_total_zero_grad + be_total_zero_grad,
},
"current": {
"ready": be_total_zero_grad,
"started": be_total_zero_grad,
"processed": None,
"completed": be_total_zero_grad,
},
},
Expand Down
Loading