Skip to content

Commit 49c0485

Browse files
authored
Avoid optional Tracker attributes and enable mypy (#9320)
1 parent ff1e691 commit 49c0485

File tree

6 files changed

+143
-133
lines changed

6 files changed

+143
-133
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2424

2525
- Progress tracking
2626
* Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
27+
* Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
2728

2829

2930
- Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ module = [
6666
"pytorch_lightning.trainer.evaluation_loop",
6767
"pytorch_lightning.trainer.connectors.logger_connector.fx_validator",
6868
"pytorch_lightning.trainer.connectors.logger_connector.logger_connector",
69+
"pytorch_lightning.trainer.progress",
6970
"pytorch_lightning.tuner.auto_gpu_select",
7071
"pytorch_lightning.utilities.apply_func",
7172
"pytorch_lightning.utilities.argparse",

pytorch_lightning/trainer/progress.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from dataclasses import asdict, dataclass, field
15-
from typing import Optional
15+
from typing import Type
1616

1717

1818
@dataclass
1919
class BaseProgress:
2020
"""
21-
Mixin that implements state-loading utiltiies for dataclasses.
21+
Mixin that implements state-loading utilities for dataclasses.
2222
"""
2323

2424
def state_dict(self) -> dict:
@@ -35,63 +35,83 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress":
3535

3636

3737
@dataclass
38-
class Tracker(BaseProgress):
38+
class ReadyCompletedTracker(BaseProgress):
3939
"""
4040
Track an event's progress.
4141
4242
Args:
4343
ready: Intended to track the number of events ready to start.
44-
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
45-
processed: Intended to be incremented after the event is processed.
4644
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
4745
4846
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
49-
Attributes set to ``None`` are treated as unused and are restricted.
5047
"""
5148

52-
ready: Optional[int] = 0
53-
started: Optional[int] = 0
54-
processed: Optional[int] = 0
55-
completed: Optional[int] = 0
49+
ready: int = 0
50+
completed: int = 0
5651

5752
def reset(self) -> None:
58-
if self.ready is not None:
59-
self.ready = 0
60-
if self.started is not None:
61-
self.started = 0
62-
if self.processed is not None:
63-
self.processed = 0
64-
if self.completed is not None:
65-
self.completed = 0
66-
67-
def __setattr__(self, key: str, value: int) -> None:
68-
"""Restrict writing to attributes set to ``None``."""
69-
if getattr(self, key, 0) is None:
70-
raise AttributeError(f"The '{key}' attribute is meant to be unused")
71-
return super().__setattr__(key, value)
72-
73-
def __repr__(self) -> str:
74-
"""Custom implementation to hide ``None`` fields."""
75-
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
76-
return f"{self.__class__.__name__}({', '.join(args)})"
53+
"""Reset the state."""
54+
self.ready = 0
55+
self.completed = 0
7756

7857
def reset_on_restart(self) -> None:
7958
"""
8059
Reset the progress on restart.
60+
8161
If there is a failure before all attributes are increased,
82-
we restore the attributes to the last fully completed value.
62+
restore the attributes to the last fully completed value.
8363
"""
84-
# choose in case `processed` is unused
85-
value = self.completed if self.processed is None else self.processed
64+
self.ready = self.completed
65+
66+
67+
@dataclass
68+
class StartedTracker(ReadyCompletedTracker):
69+
"""
70+
Track an event's progress.
71+
72+
Args:
73+
ready: Intended to track the number of events ready to start.
74+
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
75+
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
76+
77+
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
78+
"""
79+
80+
started: int = 0
81+
82+
def reset(self) -> None:
83+
super().reset()
84+
self.started = 0
85+
86+
def reset_on_restart(self) -> None:
87+
super().reset_on_restart()
88+
self.started = self.completed
89+
90+
91+
@dataclass
92+
class ProcessedTracker(StartedTracker):
93+
"""
94+
Track an event's progress.
95+
96+
Args:
97+
ready: Intended to track the number of events ready to start.
98+
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
99+
processed: Intended to be incremented after the event is processed.
100+
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
101+
102+
These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
103+
"""
86104

87-
if self.ready is not None:
88-
self.ready = value
89-
if self.started is not None:
90-
self.started = value
91-
if self.processed is not None:
92-
self.processed = value
93-
if self.completed is not None:
94-
self.completed = value
105+
processed: int = 0
106+
107+
def reset(self) -> None:
108+
super().reset()
109+
self.processed = 0
110+
111+
def reset_on_restart(self) -> None:
112+
# use `processed` in this case as the reset value
113+
self.completed = self.processed
114+
super().reset_on_restart()
95115

96116

97117
@dataclass
@@ -104,18 +124,26 @@ class Progress(BaseProgress):
104124
current: Intended to track the current progress of an event.
105125
"""
106126

107-
total: Tracker = field(default_factory=Tracker)
108-
current: Tracker = field(default_factory=Tracker)
127+
total: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
128+
current: ReadyCompletedTracker = field(default_factory=ProcessedTracker)
129+
130+
def __post_init__(self) -> None:
131+
if type(self.total) is not type(self.current): # noqa: E721
132+
raise ValueError("The `total` and `current` instances should be of the same class")
109133

110134
def increment_ready(self) -> None:
111135
self.total.ready += 1
112136
self.current.ready += 1
113137

114138
def increment_started(self) -> None:
139+
if not isinstance(self.total, StartedTracker):
140+
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `started` attribute")
115141
self.total.started += 1
116142
self.current.started += 1
117143

118144
def increment_processed(self) -> None:
145+
if not isinstance(self.total, ProcessedTracker):
146+
raise TypeError(f"`{self.total.__class__.__name__}` doesn't have a `processed` attribute")
119147
self.total.processed += 1
120148
self.current.processed += 1
121149

@@ -124,9 +152,9 @@ def increment_completed(self) -> None:
124152
self.current.completed += 1
125153

126154
@classmethod
127-
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
155+
def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int) -> "Progress":
128156
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
129-
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
157+
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))
130158

131159
def load_state_dict(self, state_dict: dict) -> None:
132160
self.total.load_state_dict(state_dict["total"])
@@ -144,8 +172,8 @@ class DataLoaderProgress(Progress):
144172
current: Tracks the current dataloader progress.
145173
"""
146174

147-
total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
148-
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
175+
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
176+
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
149177

150178

151179
@dataclass
@@ -159,8 +187,8 @@ class SchedulerProgress(Progress):
159187
current: Tracks the current scheduler progress.
160188
"""
161189

162-
total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
163-
current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None))
190+
total: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
191+
current: ReadyCompletedTracker = field(default_factory=ReadyCompletedTracker)
164192

165193

166194
@dataclass
@@ -173,8 +201,8 @@ class OptimizerProgress(BaseProgress):
173201
zero_grad: Tracks ``optimizer.zero_grad`` calls.
174202
"""
175203

176-
step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None))
177-
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None))
204+
step: Progress = field(default_factory=lambda: Progress.from_defaults(ReadyCompletedTracker))
205+
zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(StartedTracker))
178206

179207
def reset_on_epoch(self) -> None:
180208
self.step.current.reset()

tests/loops/test_loop_state_dict.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,25 @@ def test_loops_state_dict_structure():
5353
"current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
5454
},
5555
"epoch_loop.scheduler_progress": {
56-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
57-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
56+
"total": {"ready": 0, "completed": 0},
57+
"current": {"ready": 0, "completed": 0},
5858
},
5959
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
6060
"epoch_loop.batch_loop.state_dict": {},
6161
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
6262
"optimizer": {
63-
"step": {
64-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
65-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
66-
},
63+
"step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
6764
"zero_grad": {
68-
"total": {"ready": 0, "started": 0, "processed": None, "completed": 0},
69-
"current": {"ready": 0, "started": 0, "processed": None, "completed": 0},
65+
"total": {"ready": 0, "started": 0, "completed": 0},
66+
"current": {"ready": 0, "started": 0, "completed": 0},
7067
},
7168
},
7269
"optimizer_idx": 0,
7370
},
7471
"epoch_loop.val_loop.state_dict": {},
7572
"epoch_loop.val_loop.dataloader_progress": {
76-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
77-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
73+
"total": {"ready": 0, "completed": 0},
74+
"current": {"ready": 0, "completed": 0},
7875
},
7976
"epoch_loop.val_loop.epoch_loop.state_dict": {},
8077
"epoch_loop.val_loop.epoch_loop.batch_progress": {
@@ -102,10 +99,7 @@ def test_loops_state_dict_structure():
10299
},
103100
"validate_loop": {
104101
"state_dict": {},
105-
"dataloader_progress": {
106-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
107-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
108-
},
102+
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
109103
"epoch_loop.state_dict": {},
110104
"epoch_loop.batch_progress": {
111105
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
@@ -121,10 +115,7 @@ def test_loops_state_dict_structure():
121115
},
122116
"test_loop": {
123117
"state_dict": {},
124-
"dataloader_progress": {
125-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
126-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
127-
},
118+
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
128119
"epoch_loop.state_dict": {},
129120
"epoch_loop.batch_progress": {
130121
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
@@ -140,10 +131,7 @@ def test_loops_state_dict_structure():
140131
},
141132
"predict_loop": {
142133
"state_dict": {},
143-
"dataloader_progress": {
144-
"total": {"ready": 0, "started": None, "processed": None, "completed": 0},
145-
"current": {"ready": 0, "started": None, "processed": None, "completed": 0},
146-
},
134+
"dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
147135
"epoch_loop.state_dict": {},
148136
"epoch_loop.batch_progress": {
149137
"total": {"ready": 0, "started": 0, "processed": 0, "completed": 0},

tests/loops/test_loops.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ def val_dataloader(self):
285285

286286
total_dataloader = stop_epoch * n_dataloaders + stop_dataloader
287287
expected = {
288-
"total": {"ready": total_dataloader + 1, "started": None, "processed": None, "completed": total_dataloader},
289-
"current": {"ready": stop_dataloader + 1, "started": None, "processed": None, "completed": stop_dataloader},
288+
"total": {"ready": total_dataloader + 1, "completed": total_dataloader},
289+
"current": {"ready": stop_dataloader + 1, "completed": stop_dataloader},
290290
}
291291
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected
292292

@@ -452,13 +452,8 @@ def configure_optimizers_multiple(self):
452452
},
453453
},
454454
"epoch_loop.scheduler_progress": {
455-
"total": {
456-
"ready": nbe_sch_steps + be_sch_steps,
457-
"started": None,
458-
"processed": None,
459-
"completed": nbe_sch_steps + be_sch_steps,
460-
},
461-
"current": {"ready": be_sch_steps, "started": None, "processed": None, "completed": be_sch_steps},
455+
"total": {"ready": nbe_sch_steps + be_sch_steps, "completed": nbe_sch_steps + be_sch_steps},
456+
"current": {"ready": be_sch_steps, "completed": be_sch_steps},
462457
},
463458
"epoch_loop.batch_loop.state_dict": ANY,
464459
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
@@ -468,28 +463,19 @@ def configure_optimizers_multiple(self):
468463
"step": {
469464
"total": {
470465
"ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be,
471-
"started": None,
472-
"processed": None,
473466
"completed": nbe_total_opt_steps + be_total_opt_steps,
474467
},
475-
"current": {
476-
"ready": be_total_opt_steps + has_opt_stepped_in_be,
477-
"started": None,
478-
"processed": None,
479-
"completed": be_total_opt_steps,
480-
},
468+
"current": {"ready": be_total_opt_steps + has_opt_stepped_in_be, "completed": be_total_opt_steps},
481469
},
482470
"zero_grad": {
483471
"total": {
484472
"ready": nbe_total_zero_grad + be_total_zero_grad,
485473
"started": nbe_total_zero_grad + be_total_zero_grad,
486-
"processed": None,
487474
"completed": nbe_total_zero_grad + be_total_zero_grad,
488475
},
489476
"current": {
490477
"ready": be_total_zero_grad,
491478
"started": be_total_zero_grad,
492-
"processed": None,
493479
"completed": be_total_zero_grad,
494480
},
495481
},

0 commit comments

Comments
 (0)