Skip to content

Commit 83039ba

Browse files
authored
Test deepcopy for progress tracking dataclasses (#8265)
1 parent ea88105 commit 83039ba

File tree

2 files changed

+33
-27
lines changed

2 files changed

+33
-27
lines changed

pytorch_lightning/trainer/progress.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,16 @@
1818
@dataclass
1919
class _DataclassStateDictMixin:
2020

21-
def __getstate__(self) -> dict:
21+
def state_dict(self) -> dict:
2222
return asdict(self)
2323

24-
def __setstate__(self, state: dict) -> None:
25-
self.__dict__.update(state)
26-
27-
def state_dict(self) -> dict:
28-
return self.__getstate__()
24+
def load_state_dict(self, state_dict: dict) -> None:
25+
self.__dict__.update(state_dict)
2926

3027
@classmethod
3128
def from_state_dict(cls, state_dict: dict) -> "_DataclassStateDictMixin":
3229
obj = cls()
33-
obj.__setstate__(state_dict)
30+
obj.load_state_dict(state_dict)
3431
return obj
3532

3633

@@ -115,9 +112,9 @@ def increment_completed(self) -> None:
115112
def from_defaults(cls, **kwargs: Optional[int]) -> "Progress":
116113
return cls(total=Tracker(**kwargs), current=Tracker(**kwargs))
117114

118-
def __setstate__(self, state: dict) -> None:
119-
self.total.__setstate__(state["total"])
120-
self.current.__setstate__(state["current"])
115+
def load_state_dict(self, state_dict: dict) -> None:
116+
self.total.load_state_dict(state_dict["total"])
117+
self.current.load_state_dict(state_dict["current"])
121118

122119

123120
class BatchProgress(Progress):
@@ -147,9 +144,9 @@ class EpochProgress(Progress):
147144
def reset_on_epoch(self) -> None:
148145
self.batch.current.reset()
149146

150-
def __setstate__(self, state: dict) -> None:
151-
super().__setstate__(state)
152-
self.batch.__setstate__(state["batch"])
147+
def load_state_dict(self, state_dict: dict) -> None:
148+
super().load_state_dict(state_dict)
149+
self.batch.load_state_dict(state_dict["batch"])
153150

154151

155152
@dataclass
@@ -169,9 +166,9 @@ def reset_on_epoch(self) -> None:
169166
self.step.current.reset()
170167
self.zero_grad.current.reset()
171168

172-
def __setstate__(self, state: dict) -> None:
173-
self.step.__setstate__(state["step"])
174-
self.zero_grad.__setstate__(state["zero_grad"])
169+
def load_state_dict(self, state_dict: dict) -> None:
170+
self.step.load_state_dict(state_dict["step"])
171+
self.zero_grad.load_state_dict(state_dict["zero_grad"])
175172

176173

177174
@dataclass
@@ -200,9 +197,9 @@ def reset_on_epoch(self) -> None:
200197
self.optimizer.reset_on_epoch()
201198
self.scheduler.current.reset()
202199

203-
def __setstate__(self, state: dict) -> None:
204-
self.optimizer.__setstate__(state["optimizer"])
205-
self.scheduler.__setstate__(state["scheduler"])
200+
def load_state_dict(self, state_dict: dict) -> None:
201+
self.optimizer.load_state_dict(state_dict["optimizer"])
202+
self.scheduler.load_state_dict(state_dict["scheduler"])
206203

207204

208205
@dataclass
@@ -225,8 +222,8 @@ def reset_on_epoch(self) -> None:
225222
self.epoch.reset_on_epoch()
226223
self.epoch.current.reset()
227224

228-
def __setstate__(self, state: dict) -> None:
229-
self.epoch.__setstate__(state["epoch"])
225+
def load_state_dict(self, state_dict: dict) -> None:
226+
self.epoch.load_state_dict(state_dict["epoch"])
230227

231228

232229
@dataclass
@@ -245,10 +242,10 @@ class TrainingEpochProgress(EpochProgress):
245242
optim: OptimizationProgress = field(default_factory=OptimizationProgress)
246243
val: EpochLoopProgress = field(default_factory=EpochLoopProgress)
247244

248-
def __setstate__(self, state: dict) -> None:
249-
super().__setstate__(state)
250-
self.optim.__setstate__(state["optim"])
251-
self.val.__setstate__(state["val"])
245+
def load_state_dict(self, state_dict: dict) -> None:
246+
super().load_state_dict(state_dict)
247+
self.optim.load_state_dict(state_dict["optim"])
248+
self.val.load_state_dict(state_dict["val"])
252249

253250

254251
@dataclass

tests/trainer/test_progress.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from copy import deepcopy
15+
1416
import pytest
1517

1618
from pytorch_lightning.trainer.progress import (
@@ -135,14 +137,17 @@ def test_optimizer_progress_default_factory():
135137

136138
def test_fit_loop_progress_serialization():
137139
fit_loop = FitLoopProgress()
140+
_ = deepcopy(fit_loop)
141+
fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super`
142+
138143
state_dict = fit_loop.state_dict()
139144
# yapf: disable
140145
assert state_dict == {
141146
'epoch': {
142147
# number of epochs across `fit` calls
143-
'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
148+
'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
144149
# number of epochs this `fit` call
145-
'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
150+
'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
146151
'batch': {
147152
# number of batches across `fit` calls
148153
'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
@@ -191,13 +196,16 @@ def test_fit_loop_progress_serialization():
191196
}
192197
}
193198
# yapf: enable
199+
194200
new_loop = FitLoopProgress.from_state_dict(state_dict)
195201
assert fit_loop == new_loop
196202

197203

198204
def test_epoch_loop_progress_serialization():
199205
loop = EpochLoopProgress()
206+
_ = deepcopy(loop)
200207
state_dict = loop.state_dict()
208+
201209
# yapf: disable
202210
assert state_dict == {
203211
'epoch': {
@@ -214,5 +222,6 @@ def test_epoch_loop_progress_serialization():
214222
}
215223
}
216224
# yapf: enable
225+
217226
new_loop = EpochLoopProgress.from_state_dict(state_dict)
218227
assert loop == new_loop

0 commit comments

Comments
 (0)