Skip to content

Commit 41309da

Browse files
committed
Fixes
1 parent 2db8975 commit 41309da

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

tests/loops/test_loop_state_dict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import os
15+
from unittest import mock
1416
from unittest.mock import Mock
1517

1618
import pytest
@@ -37,6 +39,7 @@ def test_loops_state_dict():
3739
assert fit_loop.state_dict() == new_fit_loop.state_dict()
3840

3941

42+
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
4043
def test_loops_state_dict_structure():
4144
trainer = Trainer()
4245
trainer.train_dataloader = Mock()

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def on_pretrain_routine_end(self) -> None:
251251
trainer.fit(TestModel(), ckpt_path=ckpt)
252252
assert trainer.current_epoch == max_epochs
253253
# TODO(@carmocca): should not need `+1`
254-
# assert trainer.global_step == max_epochs * train_batches + 1
254+
assert trainer.global_step == max_epochs * train_batches + 1
255255

256256

257257
def test_fit_twice(tmpdir):

0 commit comments

Comments
 (0)