Merged

Changes from all commits (33 commits)
02f5f85  resolve bug (tchaton, Dec 16, 2020)
069dcd1  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 20, 2020)
b60a0ad  update code (tchaton, Dec 20, 2020)
c42149e  add set -e (tchaton, Dec 20, 2020)
96e167e  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 21, 2020)
9e433aa  Update pytorch_lightning/callbacks/model_checkpoint.py (tchaton, Dec 21, 2020)
725d38d  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 21, 2020)
17cb6a1  update test (tchaton, Dec 23, 2020)
4106da2  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 23, 2020)
014f79c  Update tests/checkpointing/test_trainer_checkpoint.py (tchaton, Dec 23, 2020)
f266c6d  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 23, 2020)
f09af65  Update tests/checkpointing/test_trainer_checkpoint.py (tchaton, Dec 28, 2020)
a2a9fa0  update on comments (tchaton, Dec 28, 2020)
01aa10c  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 28, 2020)
4914735  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Dec 28, 2020)
bd3d2da  resolve test (tchaton, Dec 28, 2020)
64e3b5b  convert to set (tchaton, Dec 28, 2020)
ed84588  update (tchaton, Jan 4, 2021)
41355d5  add error triggering (tchaton, Jan 4, 2021)
53455af  update (Jan 4, 2021)
fa8d952  update on comments (tchaton, Jan 4, 2021)
0cf5e5c  Merge branch 'bugfix/5091_resume_from_checkpoint_test' of https://git… (tchaton, Jan 4, 2021)
ef75de5  update (tchaton, Jan 4, 2021)
d85662d  resolve import (tchaton, Jan 4, 2021)
a09ec3d  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Jan 4, 2021)
6c4948c  update (tchaton, Jan 4, 2021)
a893c7f  Merge branch 'bugfix/5091_resume_from_checkpoint_test' of https://git… (tchaton, Jan 4, 2021)
3197762  update (tchaton, Jan 4, 2021)
e44a328  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Jan 4, 2021)
b8f64bf  Update pytorch_lightning/plugins/rpc_plugin.py (tchaton, Jan 4, 2021)
34986bb  update (tchaton, Jan 4, 2021)
9252a06  add _module_available (tchaton, Jan 4, 2021)
16ccc66  Merge branch 'master' into bugfix/5091_resume_from_checkpoint_test (tchaton, Jan 5, 2021)
1 change: 1 addition & 0 deletions .drone.yml
@@ -30,6 +30,7 @@ steps:
MKL_THREADING_LAYER: GNU

commands:
- set -e
- python --version
- pip --version
- nvidia-smi
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.1.3rc] - 2020-12-29
## [1.1.3] - 2021-01-05

### Added

@@ -25,12 +25,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Skip restore from `resume_from_checkpoint` while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161))

- Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333))

- Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277))



## [1.1.2] - 2020-12-23

### Added
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -208,6 +208,7 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
"best_model_score": self.best_model_score,
"best_model_path": self.best_model_path,
"current_score": self.current_score,
"dirpath": self.dirpath
}

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
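The change above persists the callback's `dirpath` alongside the existing best-model fields. A minimal sketch of how that saved state could be inspected, assuming a hypothetical checkpoint file `example.ckpt` and the 1.1.x layout where callback states live under the checkpoint's `callbacks` key:

```python
# Sketch only: inspect the ModelCheckpoint state returned by on_save_checkpoint.
# "example.ckpt" is a hypothetical path; the "callbacks" key layout is assumed from PL 1.1.x.
import torch

ckpt = torch.load("example.ckpt", map_location="cpu")
for callback_type, state in ckpt.get("callbacks", {}).items():
    # After this PR the state should also carry "dirpath", next to best_model_path/score.
    print(callback_type, state.get("best_model_path"), state.get("dirpath"))
```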
11 changes: 8 additions & 3 deletions pytorch_lightning/plugins/rpc_plugin.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Optional
from typing import Optional

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE
from pytorch_lightning.utilities import _module_available, RPC_AVAILABLE

DEFAULT_RPC_TIMEOUT_SEC = 60.
if RPC_AVAILABLE:
    from torch.distributed import rpc
    if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"):
        from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC


class RPCPlugin(DDPPlugin):
@@ -33,7 +36,8 @@ class RPCPlugin(DDPPlugin):
that need to be addressed when using RPC communication when building custom RPC Plugins.
"""

    def __init__(self, **kwargs):
    def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs):
        self.rpc_timeout_sec = rpc_timeout_sec
        self.rpc_initialized = False
        super().__init__(**kwargs)

@@ -42,6 +46,7 @@ def init_rpc_connection(self,
                            world_size: int) -> None:
        os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
        rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
        rpc._set_rpc_timeout(self.rpc_timeout_sec)
        self.rpc_initialized = True

    def rpc_save_model(self,
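The new `rpc_timeout_sec` argument defaults to torch's `DEFAULT_RPC_TIMEOUT_SEC` when it is importable and is applied via `rpc._set_rpc_timeout` after `init_rpc`. A hedged usage sketch, mirroring the updated sequential-plugin test further down; the 2-GPU setup and the `ddp_sequential_plugin` import path are assumptions for this release:

```python
# Sketch only: raise the RPC timeout when building a pipelined trainer.
# Assumes 2 GPUs, fairscale installed, and the PL 1.1.x import path for DDPSequentialPlugin.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin

trainer = Trainer(
    gpus=2,
    distributed_backend="ddp",
    plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],  # 5-minute RPC timeout
)
```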
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,6 +21,7 @@

import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType, APEX_AVAILABLE, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -63,7 +64,7 @@ def restore_weights(self, model: LightningModule) -> None:
            rank_zero_info(f'restored hpc model from: {checkpoint_path}')

        # 2. Attempt to restore states from `resume_from_checkpoint` file
        elif self.trainer.resume_from_checkpoint is not None:
        elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
            self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

        # wait for all to catch up
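The added guard means `resume_from_checkpoint` is only honored by `.fit()`; while testing, the trainer skips the restore entirely. A rough sketch of the resulting behavior, paraphrasing the updated assertions in the checkpointing tests below (the checkpoint path is hypothetical):

```python
# Sketch only: resume_from_checkpoint is ignored while testing.
# "some.ckpt" stands in for a checkpoint produced by an earlier fit() run.
from pytorch_lightning import Trainer
from tests.base import BoringModel

model = BoringModel()
trainer = Trainer(resume_from_checkpoint="some.ckpt", max_epochs=2)

trainer.test(model)   # no restore: trainer.global_step and trainer.current_epoch stay at 0
trainer.fit(model)    # training state is still restored from some.ckpt, as before this PR
```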
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -181,7 +181,7 @@ def setup_training(self, model: LightningModule):
        # if cluster resets state, the model will update with the saved weights
        self.trainer.model = model

        # restore training and model before hpc is called
        # restore training state and model weights before hpc is called
        self.trainer.checkpoint_connector.restore_weights(model)

        # on pretrain routine end
13 changes: 7 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
@@ -11,20 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import Namespace
import os
from pathlib import Path
import pickle
import platform
import re
from argparse import Namespace
from pathlib import Path
from unittest import mock
from unittest.mock import Mock

import cloudpickle
from omegaconf import Container, OmegaConf
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf

import pytorch_lightning as pl
import tests.base.develop_utils as tutils
@@ -34,6 +34,7 @@
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel
import tests.base.develop_utils as tutils


class LogInTwoMethods(BoringModel):
@@ -760,9 +761,9 @@ def assert_checkpoint_log_dir(idx):
    model = ExtendedBoringModel()
    trainer.test(model)
    assert not trainer.checkpoint_connector.has_trained
    assert trainer.global_step == epochs * limit_train_batches
    assert trainer.current_epoch == epochs

    # resume_from_checkpoint is resumed when calling `.fit`
    assert trainer.global_step == 0
    assert trainer.current_epoch == 0
    trainer.fit(model)
    assert not trainer.checkpoint_connector.has_trained
    assert trainer.global_step == epochs * limit_train_batches
87 changes: 87 additions & 0 deletions tests/checkpointing/test_trainer_checkpoint.py
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import os

import torch

import pytorch_lightning as pl
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.cloud_io import load as pl_load
from tests.base import BoringModel


def test_finetuning_with_resume_from_checkpoint(tmpdir):
    """
    This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test
    """

    seed_everything(3)

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)

    class ExtendedBoringModel(BoringModel):

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            self.log("val_loss", loss, on_epoch=True, prog_bar=True)

    model = ExtendedBoringModel()
    model.validation_epoch_end = None
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=12,
        limit_val_batches=6,
        limit_test_batches=12,
        callbacks=[checkpoint_callback],
        logger=False,
    )
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    best_model_paths = [checkpoint_callback.best_model_path]
    results = []

    for idx in range(3, 6):
        # load from checkpoint
        trainer = pl.Trainer(
            default_root_dir=tmpdir,
            max_epochs=idx,
            limit_train_batches=12,
            limit_val_batches=12,
            limit_test_batches=12,
            resume_from_checkpoint=best_model_paths[-1],
            progress_bar_refresh_rate=0,
        )
        trainer.fit(model)
        trainer.test()
        results.append(deepcopy(trainer.callback_metrics))
        best_model_paths.append(trainer.checkpoint_callback.best_model_path)

    for idx in range(len(results) - 1):
        assert results[idx]["val_loss"] > results[idx + 1]["val_loss"]

    for idx, best_model_path in enumerate(best_model_paths):
        if idx == 0:
            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
        else:
            assert f"epoch={idx + 1}" in best_model_path
3 changes: 2 additions & 1 deletion tests/plugins/test_ddp_sequential_plugin.py
@@ -47,7 +47,8 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
        limit_test_batches=2,
        gpus=2,
        distributed_backend="ddp",
        plugins=[DDPSequentialPlugin(balance=[2, 1])],
        plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)],
        enable_pl_optimizer=True,
    )

    trainer.fit(model)
1 change: 1 addition & 0 deletions tests/special_tests.sh
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Running special tests
set -e
Review comment (Collaborator): what does this do?

Reply (Contributor Author): I noticed special_tests.sh doesn't return an error when a test fails. Adding `set -e` should fix that, since it makes the script exit as soon as any command exits with a non-zero status.

export PL_RUNNING_SPECIAL_TESTS=1
DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no"
python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp