Skip to content

Commit 33a1f52

Browse files
carmoccaananthsub
andauthored
[2/N] Define dataclasses for progress tracking (#7574)
Co-authored-by: ananthsub <[email protected]>
1 parent 110e49d commit 33a1f52

File tree

3 files changed

+192
-139
lines changed

3 files changed

+192
-139
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Added support for checkpointing based on a provided time interval during training ([#7515](https://github.com/PyTorchLightning/pytorch-lightning/pull/7515))
2222

2323

24-
- Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603))
24+
- Added dataclasses for progress tracking (
25+
[#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603),
26+
[#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574))
2527

2628

2729
- Added argument `trainer.predict(ckpt_path)` ([#7430](https://github.com/PyTorchLightning/pytorch-lightning/pull/7430))

pytorch_lightning/trainer/progress.py

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,68 +11,97 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
1514
from dataclasses import dataclass, field
15+
from typing import Optional
1616

1717

1818
@dataclass
19-
class ProgressState:
19+
class Tracker:
2020
"""
21-
Basic dataclass to track event progress.
21+
Track an event's progress.
2222
2323
Args:
2424
ready: Intended to track the number of events ready to start.
25-
started: Intended to be incremented after the event is started (e.g. after `on_*_start runs).
25+
started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
2626
processed: Intended to be incremented after the event is processed.
27-
completed: Intended to be incremented after the event completes (e.g. after `on_*_end` runs).
27+
completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
28+
29+
Attributes set to ``None`` are treated as unused and are restricted.
2830
"""
29-
ready: int = 0
30-
started: int = 0
31-
processed: int = 0
32-
completed: int = 0
31+
ready: Optional[int] = 0
32+
started: Optional[int] = 0
33+
processed: Optional[int] = 0
34+
completed: Optional[int] = 0
3335

3436
def reset(self) -> None:
35-
self.ready = 0
36-
self.started = 0
37-
self.processed = 0
38-
self.completed = 0
37+
if self.ready is not None:
38+
self.ready = 0
39+
if self.started is not None:
40+
self.started = 0
41+
if self.processed is not None:
42+
self.processed = 0
43+
if self.completed is not None:
44+
self.completed = 0
45+
46+
def __setattr__(self, key: str, value: int) -> None:
47+
if getattr(self, key, 0) is None:
48+
raise AttributeError(f"The '{key}' attribute is meant to be unused")
49+
return super().__setattr__(key, value)
50+
51+
def __repr__(self):
52+
# hide `None` fields
53+
args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
54+
return f"{self.__class__.__name__}({', '.join(args)})"
3955

4056

4157
@dataclass
4258
class Progress:
4359
"""
44-
Basic dataclass to track aggregated and current progress states.
60+
Track aggregated and current progress.
4561
4662
Args:
4763
total: Intended to track the total progress of an event
4864
current: Intended to track the current progress of an event
4965
"""
50-
total: ProgressState = field(default_factory=ProgressState)
51-
current: ProgressState = field(default_factory=ProgressState)
66+
total: Tracker = field(default_factory=Tracker)
67+
current: Tracker = field(default_factory=Tracker)
5268

5369
def increment_ready(self) -> None:
70+
if self.total.ready is None or self.current.ready is None:
71+
return
5472
self.total.ready += 1
5573
self.current.ready += 1
5674

5775
def increment_started(self) -> None:
76+
if self.total.started is None or self.current.started is None:
77+
return
5878
self.total.started += 1
5979
self.current.started += 1
6080

6181
def increment_processed(self) -> None:
82+
if self.total.processed is None or self.current.processed is None:
83+
return
6284
self.total.processed += 1
6385
self.current.processed += 1
6486

6587
def increment_completed(self) -> None:
88+
if self.total.completed is None or self.current.completed is None:
89+
return
6690
self.total.completed += 1
6791
self.current.completed += 1
6892

93+
@classmethod
94+
def from_defaults(cls, **kwargs: Optional[int]) -> 'Progress':
95+
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
96+
6997

7098
@dataclass
7199
class LoopProgress:
72100
"""
73-
Dataclass to track loop progress during execution.
101+
Track loop progress during execution.
74102
75103
These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
104+
76105
Args:
77106
epoch: Tracks epochs progress.
78107
batch: Tracks batch progress.
@@ -87,3 +116,65 @@ def increment_epoch_completed(self) -> None:
87116
def reset_on_epoch(self) -> None:
88117
self.batch.current.reset()
89118
self.epoch.current.reset()
119+
120+
121+
@dataclass
122+
class OptimizationProgress:
123+
"""
124+
Track optimization progress.
125+
126+
Args:
127+
optimizer: Tracks optimizer progress.
128+
scheduler: Tracks scheduler progress.
129+
"""
130+
optimizer: Progress = Progress.from_defaults(processed=None)
131+
scheduler: Progress = Progress.from_defaults(started=None, processed=None)
132+
zero_grad: Progress = Progress.from_defaults(processed=None)
133+
134+
@property
135+
def optimizer_steps(self) -> int:
136+
return self.optimizer.total.completed
137+
138+
@property
139+
def scheduler_steps(self) -> int:
140+
return self.scheduler.total.completed
141+
142+
143+
@dataclass
144+
class TrainingProgress(Progress):
145+
"""
146+
Extends ``Progress`` with training specific attributes
147+
148+
Args:
149+
optimization: Tracks optimization progress
150+
"""
151+
optimization: OptimizationProgress = field(default_factory=OptimizationProgress)
152+
153+
154+
@dataclass
155+
class TrainingLoopProgress(LoopProgress):
156+
epoch: TrainingProgress = field(default_factory=TrainingProgress)
157+
158+
def reset_on_epoch(self) -> None:
159+
# override to avoid resetting `epoch.current`
160+
self.batch.current.reset()
161+
162+
163+
@dataclass
164+
class FitLoopProgress:
165+
train: TrainingLoopProgress = field(default_factory=TrainingLoopProgress)
166+
val: LoopProgress = field(default_factory=LoopProgress)
167+
168+
169+
@dataclass
170+
class LoopState:
171+
"""
172+
Basic dataclass to track loop progress across trainer functions during trainer execution.
173+
174+
This class will be removed and these attributes will live in each loop.
175+
"""
176+
177+
fit: FitLoopProgress = field(default_factory=FitLoopProgress)
178+
val: LoopProgress = field(default_factory=LoopProgress)
179+
test: LoopProgress = field(default_factory=LoopProgress)
180+
predict: LoopProgress = field(default_factory=LoopProgress)

0 commit comments

Comments
 (0)