1818@dataclass
1919class _DataclassStateDictMixin :
2020
21- def __getstate__ (self ) -> dict :
21+ def state_dict (self ) -> dict :
2222 return asdict (self )
2323
24- def __setstate__ (self , state : dict ) -> None :
25- self .__dict__ .update (state )
26-
27- def state_dict (self ) -> dict :
28- return self .__getstate__ ()
24+ def load_state_dict (self , state_dict : dict ) -> None :
25+ self .__dict__ .update (state_dict )
2926
3027 @classmethod
3128 def from_state_dict (cls , state_dict : dict ) -> "_DataclassStateDictMixin" :
3229 obj = cls ()
33- obj .__setstate__ (state_dict )
30+ obj .load_state_dict (state_dict )
3431 return obj
3532
3633
@@ -115,9 +112,9 @@ def increment_completed(self) -> None:
115112 def from_defaults (cls , ** kwargs : Optional [int ]) -> "Progress" :
116113 return cls (total = Tracker (** kwargs ), current = Tracker (** kwargs ))
117114
118- def __setstate__ (self , state : dict ) -> None :
119- self .total .__setstate__ ( state ["total" ])
120- self .current .__setstate__ ( state ["current" ])
115+ def load_state_dict (self , state_dict : dict ) -> None :
116+ self .total .load_state_dict ( state_dict ["total" ])
117+ self .current .load_state_dict ( state_dict ["current" ])
121118
122119
123120class BatchProgress (Progress ):
@@ -147,9 +144,9 @@ class EpochProgress(Progress):
147144 def reset_on_epoch (self ) -> None :
148145 self .batch .current .reset ()
149146
150- def __setstate__ (self , state : dict ) -> None :
151- super ().__setstate__ ( state )
152- self .batch .__setstate__ ( state ["batch" ])
147+ def load_state_dict (self , state_dict : dict ) -> None :
148+ super ().load_state_dict ( state_dict )
149+ self .batch .load_state_dict ( state_dict ["batch" ])
153150
154151
155152@dataclass
@@ -169,9 +166,9 @@ def reset_on_epoch(self) -> None:
169166 self .step .current .reset ()
170167 self .zero_grad .current .reset ()
171168
172- def __setstate__ (self , state : dict ) -> None :
173- self .step .__setstate__ ( state ["step" ])
174- self .zero_grad .__setstate__ ( state ["zero_grad" ])
169+ def load_state_dict (self , state_dict : dict ) -> None :
170+ self .step .load_state_dict ( state_dict ["step" ])
171+ self .zero_grad .load_state_dict ( state_dict ["zero_grad" ])
175172
176173
177174@dataclass
@@ -200,9 +197,9 @@ def reset_on_epoch(self) -> None:
200197 self .optimizer .reset_on_epoch ()
201198 self .scheduler .current .reset ()
202199
203- def __setstate__ (self , state : dict ) -> None :
204- self .optimizer .__setstate__ ( state ["optimizer" ])
205- self .scheduler .__setstate__ ( state ["scheduler" ])
200+ def load_state_dict (self , state_dict : dict ) -> None :
201+ self .optimizer .load_state_dict ( state_dict ["optimizer" ])
202+ self .scheduler .load_state_dict ( state_dict ["scheduler" ])
206203
207204
208205@dataclass
@@ -225,8 +222,8 @@ def reset_on_epoch(self) -> None:
225222 self .epoch .reset_on_epoch ()
226223 self .epoch .current .reset ()
227224
228- def __setstate__ (self , state : dict ) -> None :
229- self .epoch .__setstate__ ( state ["epoch" ])
225+ def load_state_dict (self , state_dict : dict ) -> None :
226+ self .epoch .load_state_dict ( state_dict ["epoch" ])
230227
231228
232229@dataclass
@@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress):
245242 optim : OptimizationProgress = field (default_factory = OptimizationProgress )
246243 val : EpochLoopProgress = field (default_factory = EpochLoopProgress )
247244
248- def __setstate__ (self , state : dict ) -> None :
249- super ().__setstate__ ( state )
250- self .optim .__setstate__ ( state ["optim" ])
251- self .val .__setstate__ ( state ["val" ])
245+ def load_state_dict (self , state_dict : dict ) -> None :
246+ super ().load_state_dict ( state_dict )
247+ self .optim .load_state_dict ( state_dict ["optim" ])
248+ self .val .load_state_dict ( state_dict ["val" ])
252249
253250
254251@dataclass
0 commit comments