1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14-
1514from dataclasses import dataclass , field
15+ from typing import Optional
1616
1717
1818@dataclass
19- class ProgressState :
19+ class Tracker :
2020 """
21- Basic dataclass to track event progress.
21+ Track an event's progress.
2222
2323 Args:
2424 ready: Intended to track the number of events ready to start.
25- started: Intended to be incremented after the event is started (e.g. after `on_*_start runs).
25+ started: Intended to be incremented after the event is started (e.g. after `` on_*_start`` runs).
2626 processed: Intended to be incremented after the event is processed.
27- completed: Intended to be incremented after the event completes (e.g. after `on_*_end` runs).
27+ completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
28+
29+ Attributes set to ``None`` are treated as unused and are restricted.
2830 """
29- ready : int = 0
30- started : int = 0
31- processed : int = 0
32- completed : int = 0
31+ ready : Optional [ int ] = 0
32+ started : Optional [ int ] = 0
33+ processed : Optional [ int ] = 0
34+ completed : Optional [ int ] = 0
3335
3436 def reset (self ) -> None :
35- self .ready = 0
36- self .started = 0
37- self .processed = 0
38- self .completed = 0
37+ if self .ready is not None :
38+ self .ready = 0
39+ if self .started is not None :
40+ self .started = 0
41+ if self .processed is not None :
42+ self .processed = 0
43+ if self .completed is not None :
44+ self .completed = 0
45+
46+ def __setattr__ (self , key : str , value : int ) -> None :
47+ if getattr (self , key , 0 ) is None :
48+ raise AttributeError (f"The '{ key } ' attribute is meant to be unused" )
49+ return super ().__setattr__ (key , value )
50+
51+ def __repr__ (self ):
52+ # hide `None` fields
53+ args = [f"{ k } ={ v } " for k , v in self .__dict__ .items () if v is not None ]
54+ return f"{ self .__class__ .__name__ } ({ ', ' .join (args )} )"
3955
4056
4157@dataclass
4258class Progress :
4359 """
44- Basic dataclass to track aggregated and current progress states .
60+ Track aggregated and current progress.
4561
4662 Args:
4763 total: Intended to track the total progress of an event
4864 current: Intended to track the current progress of an event
4965 """
50- total : ProgressState = field (default_factory = ProgressState )
51- current : ProgressState = field (default_factory = ProgressState )
66+ total : Tracker = field (default_factory = Tracker )
67+ current : Tracker = field (default_factory = Tracker )
5268
5369 def increment_ready (self ) -> None :
70+ if self .total .ready is None or self .current .ready is None :
71+ return
5472 self .total .ready += 1
5573 self .current .ready += 1
5674
5775 def increment_started (self ) -> None :
76+ if self .total .started is None or self .current .started is None :
77+ return
5878 self .total .started += 1
5979 self .current .started += 1
6080
6181 def increment_processed (self ) -> None :
82+ if self .total .processed is None or self .current .processed is None :
83+ return
6284 self .total .processed += 1
6385 self .current .processed += 1
6486
6587 def increment_completed (self ) -> None :
88+ if self .total .completed is None or self .current .completed is None :
89+ return
6690 self .total .completed += 1
6791 self .current .completed += 1
6892
93+ @classmethod
94+ def from_defaults (cls , ** kwargs : Optional [int ]) -> 'Progress' :
95+ return cls (total = Tracker (** kwargs ), current = Tracker (** kwargs ))
96+
6997
7098@dataclass
7199class LoopProgress :
72100 """
73- Dataclass to track loop progress during execution.
101+ Track loop progress during execution.
74102
75103 These counters are local to a trainer rank. By default, they are not globally synced across all ranks.
104+
76105 Args:
77106 epoch: Tracks epochs progress.
78107 batch: Tracks batch progress.
@@ -87,3 +116,65 @@ def increment_epoch_completed(self) -> None:
87116 def reset_on_epoch (self ) -> None :
88117 self .batch .current .reset ()
89118 self .epoch .current .reset ()
119+
120+
121+ @dataclass
122+ class OptimizationProgress :
123+ """
124+ Track optimization progress.
125+
126+ Args:
127+ optimizer: Tracks optimizer progress.
128+ scheduler: Tracks scheduler progress.
129+ """
130+ optimizer : Progress = Progress .from_defaults (processed = None )
131+ scheduler : Progress = Progress .from_defaults (started = None , processed = None )
132+ zero_grad : Progress = Progress .from_defaults (processed = None )
133+
134+ @property
135+ def optimizer_steps (self ) -> int :
136+ return self .optimizer .total .completed
137+
138+ @property
139+ def scheduler_steps (self ) -> int :
140+ return self .scheduler .total .completed
141+
142+
143+ @dataclass
144+ class TrainingProgress (Progress ):
145+ """
146+ Extends ``Progress`` with training specific attributes
147+
148+ Args:
149+ optimization: Tracks optimization progress
150+ """
151+ optimization : OptimizationProgress = field (default_factory = OptimizationProgress )
152+
153+
154+ @dataclass
155+ class TrainingLoopProgress (LoopProgress ):
156+ epoch : TrainingProgress = field (default_factory = TrainingProgress )
157+
158+ def reset_on_epoch (self ) -> None :
159+ # override to avoid resetting `epoch.current`
160+ self .batch .current .reset ()
161+
162+
163+ @dataclass
164+ class FitLoopProgress :
165+ train : TrainingLoopProgress = field (default_factory = TrainingLoopProgress )
166+ val : LoopProgress = field (default_factory = LoopProgress )
167+
168+
169+ @dataclass
170+ class LoopState :
171+ """
172+ Basic dataclass to track loop progress across trainer functions during trainer execution.
173+
174+ This class will be removed and these attributes will live in each loop.
175+ """
176+
177+ fit : FitLoopProgress = field (default_factory = FitLoopProgress )
178+ val : LoopProgress = field (default_factory = LoopProgress )
179+ test : LoopProgress = field (default_factory = LoopProgress )
180+ predict : LoopProgress = field (default_factory = LoopProgress )
0 commit comments