1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414from dataclasses import asdict , dataclass , field
15- from typing import Optional
15+ from typing import Type
1616
1717
1818@dataclass
1919class BaseProgress :
2020 """
21- Mixin that implements state-loading utiltiies for dataclasses.
21+ Mixin that implements state-loading utilities for dataclasses.
2222 """
2323
2424 def state_dict (self ) -> dict :
@@ -35,63 +35,83 @@ def from_state_dict(cls, state_dict: dict) -> "BaseProgress":
3535
3636
3737@dataclass
38- class Tracker (BaseProgress ):
38+ class ReadyCompletedTracker (BaseProgress ):
3939 """
4040 Track an event's progress.
4141
4242 Args:
4343 ready: Intended to track the number of events ready to start.
44- started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
45- processed: Intended to be incremented after the event is processed.
4644 completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
4745
4846 These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
49- Attributes set to ``None`` are treated as unused and are restricted.
5047 """
5148
52- ready : Optional [int ] = 0
53- started : Optional [int ] = 0
54- processed : Optional [int ] = 0
55- completed : Optional [int ] = 0
49+ ready : int = 0
50+ completed : int = 0
5651
5752 def reset (self ) -> None :
58- if self .ready is not None :
59- self .ready = 0
60- if self .started is not None :
61- self .started = 0
62- if self .processed is not None :
63- self .processed = 0
64- if self .completed is not None :
65- self .completed = 0
66-
67- def __setattr__ (self , key : str , value : int ) -> None :
68- """Restrict writing to attributes set to ``None``."""
69- if getattr (self , key , 0 ) is None :
70- raise AttributeError (f"The '{ key } ' attribute is meant to be unused" )
71- return super ().__setattr__ (key , value )
72-
73- def __repr__ (self ) -> str :
74- """Custom implementation to hide ``None`` fields."""
75- args = [f"{ k } ={ v } " for k , v in self .__dict__ .items () if v is not None ]
76- return f"{ self .__class__ .__name__ } ({ ', ' .join (args )} )"
53+ """Reset the state."""
54+ self .ready = 0
55+ self .completed = 0
7756
7857 def reset_on_restart (self ) -> None :
7958 """
8059 Reset the progress on restart.
60+
8161 If there is a failure before all attributes are increased,
82- we restore the attributes to the last fully completed value.
62+ restore the attributes to the last fully completed value.
8363 """
84- # choose in case `processed` is unused
85- value = self .completed if self .processed is None else self .processed
64+ self .ready = self .completed
65+
66+
67+ @dataclass
68+ class StartedTracker (ReadyCompletedTracker ):
69+ """
70+ Track an event's progress.
71+
72+ Args:
73+ ready: Intended to track the number of events ready to start.
74+ started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
75+ completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
76+
77+ These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
78+ """
79+
80+ started : int = 0
81+
82+ def reset (self ) -> None :
83+ super ().reset ()
84+ self .started = 0
85+
86+ def reset_on_restart (self ) -> None :
87+ super ().reset_on_restart ()
88+ self .started = self .completed
89+
90+
91+ @dataclass
92+ class ProcessedTracker (StartedTracker ):
93+ """
94+ Track an event's progress.
95+
96+ Args:
97+ ready: Intended to track the number of events ready to start.
98+ started: Intended to be incremented after the event is started (e.g. after ``on_*_start`` runs).
99+ processed: Intended to be incremented after the event is processed.
100+ completed: Intended to be incremented after the event completes (e.g. after ``on_*_end`` runs).
101+
102+ These attributes should be increased in order, that is, :attr:`ready` first and :attr:`completed` last.
103+ """
86104
87- if self .ready is not None :
88- self .ready = value
89- if self .started is not None :
90- self .started = value
91- if self .processed is not None :
92- self .processed = value
93- if self .completed is not None :
94- self .completed = value
105+ processed : int = 0
106+
107+ def reset (self ) -> None :
108+ super ().reset ()
109+ self .processed = 0
110+
111+ def reset_on_restart (self ) -> None :
112+ # use `processed` in this case as the reset value
113+ self .completed = self .processed
114+ super ().reset_on_restart ()
95115
96116
97117@dataclass
@@ -104,18 +124,26 @@ class Progress(BaseProgress):
104124 current: Intended to track the current progress of an event.
105125 """
106126
107- total : Tracker = field (default_factory = Tracker )
108- current : Tracker = field (default_factory = Tracker )
127+ total : ReadyCompletedTracker = field (default_factory = ProcessedTracker )
128+ current : ReadyCompletedTracker = field (default_factory = ProcessedTracker )
129+
130+ def __post_init__ (self ) -> None :
131+ if type (self .total ) is not type (self .current ): # noqa: E721
132+ raise ValueError ("The `total` and `current` instances should be of the same class" )
109133
110134 def increment_ready (self ) -> None :
111135 self .total .ready += 1
112136 self .current .ready += 1
113137
114138 def increment_started (self ) -> None :
139+ if not isinstance (self .total , StartedTracker ):
140+ raise TypeError (f"`{ self .total .__class__ .__name__ } ` doesn't have a `started` attribute" )
115141 self .total .started += 1
116142 self .current .started += 1
117143
118144 def increment_processed (self ) -> None :
145+ if not isinstance (self .total , ProcessedTracker ):
146+ raise TypeError (f"`{ self .total .__class__ .__name__ } ` doesn't have a `processed` attribute" )
119147 self .total .processed += 1
120148 self .current .processed += 1
121149
@@ -124,9 +152,9 @@ def increment_completed(self) -> None:
124152 self .current .completed += 1
125153
126154 @classmethod
127- def from_defaults (cls , ** kwargs : Optional [ int ] ) -> "Progress" :
155+ def from_defaults (cls , tracker_cls : Type [ ReadyCompletedTracker ], ** kwargs : int ) -> "Progress" :
128156 """Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
129- return cls (total = Tracker (** kwargs ), current = Tracker (** kwargs ))
157+ return cls (total = tracker_cls (** kwargs ), current = tracker_cls (** kwargs ))
130158
131159 def load_state_dict (self , state_dict : dict ) -> None :
132160 self .total .load_state_dict (state_dict ["total" ])
@@ -144,8 +172,8 @@ class DataLoaderProgress(Progress):
144172 current: Tracks the current dataloader progress.
145173 """
146174
147- total : Tracker = field (default_factory = lambda : Tracker ( started = None , processed = None ) )
148- current : Tracker = field (default_factory = lambda : Tracker ( started = None , processed = None ) )
175+ total : ReadyCompletedTracker = field (default_factory = ReadyCompletedTracker )
176+ current : ReadyCompletedTracker = field (default_factory = ReadyCompletedTracker )
149177
150178
151179@dataclass
@@ -159,8 +187,8 @@ class SchedulerProgress(Progress):
159187 current: Tracks the current scheduler progress.
160188 """
161189
162- total : Tracker = field (default_factory = lambda : Tracker ( started = None , processed = None ) )
163- current : Tracker = field (default_factory = lambda : Tracker ( started = None , processed = None ) )
190+ total : ReadyCompletedTracker = field (default_factory = ReadyCompletedTracker )
191+ current : ReadyCompletedTracker = field (default_factory = ReadyCompletedTracker )
164192
165193
166194@dataclass
@@ -173,8 +201,8 @@ class OptimizerProgress(BaseProgress):
173201 zero_grad: Tracks ``optimizer.zero_grad`` calls.
174202 """
175203
176- step : Progress = field (default_factory = lambda : Progress .from_defaults (started = None , processed = None ))
177- zero_grad : Progress = field (default_factory = lambda : Progress .from_defaults (processed = None ))
204+ step : Progress = field (default_factory = lambda : Progress .from_defaults (ReadyCompletedTracker ))
205+ zero_grad : Progress = field (default_factory = lambda : Progress .from_defaults (StartedTracker ))
178206
179207 def reset_on_epoch (self ) -> None :
180208 self .step .current .reset ()
0 commit comments