Commit 0dfc6a1

Call any trainer function from the LightningCLI (#7508)
1 parent 045c879 commit 0dfc6a1
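
For orientation, here is a minimal sketch of the subcommand-style usage this commit introduces. The script, model, and datamodule names are placeholders, the exact set of subcommands is an assumption based on the commit title and the CHANGELOG entry below (trainer functions such as fit/validate/test/predict exposed as positional arguments), and the `LightningCLI` import path is assumed from this era of the codebase:

    # cli_demo.py -- hypothetical script, not part of this commit
    from pytorch_lightning.utilities.cli import LightningCLI

    from my_project import MyDataModule, MyModel  # hypothetical user code

    if __name__ == "__main__":
        # With the default run=True, the first positional argument selects which
        # trainer function to call, for example:
        #   python cli_demo.py fit --trainer.max_epochs=1 --data.batch_size=32
        #   python cli_demo.py test --trainer.limit_test_batches=2
        LightningCLI(MyModel, MyDataModule)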

File tree

14 files changed: +597 -250 lines

.azure-pipelines/gpu-tests.yml

Lines changed: 2 additions & 0 deletions
@@ -108,6 +108,8 @@ jobs:
       bash pl_examples/run_examples.sh --trainer.gpus=1
       bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp
       bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=ddp --trainer.precision=16
+      bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp
+      bash pl_examples/run_examples.sh --trainer.gpus=2 --trainer.accelerator=dp --trainer.precision=16
     env:
       PL_USE_MOCKED_MNIST: "1"
     displayName: 'Testing: examples'

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))
 
 
-- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
+- `LightningCLI` additions:
+    * Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
+    * Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
 
 
 - Fault-tolerant training:
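
The `run=False` mode referenced above is what the example scripts further down in this commit switch to: the CLI still parses the configuration and instantiates the trainer, model, and datamodule, but the script decides which trainer functions to call and in what order. A minimal sketch of that pattern, with placeholder model/datamodule names and an import path assumed from this era of the codebase:

    from pytorch_lightning.utilities.cli import LightningCLI

    from my_project import MyDataModule, MyModel  # hypothetical user code

    # run=False: parse arguments and build the objects, but do not dispatch to a
    # subcommand automatically.
    cli = LightningCLI(MyModel, MyDataModule, run=False)

    # The script stays in control, mirroring the updated pl_examples in this commit.
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    cli.trainer.test(ckpt_path="best")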

docs/source/common/lightning_cli.rst

Lines changed: 116 additions & 70 deletions
Large diffs are not rendered by default.

pl_examples/basic_examples/autoencoder.py

Lines changed: 6 additions & 3 deletions
@@ -109,9 +109,12 @@ def predict_dataloader(self):
 
 
 def cli_main():
-    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(
+        LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
+    )
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
+    predictions = cli.trainer.predict(ckpt_path="best")
     print(predictions[0])

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 4 additions & 3 deletions
@@ -124,9 +124,10 @@ def predict_dataloader(self):
 
 
 def cli_main():
-    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
-    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
+    predictions = cli.trainer.predict(ckpt_path="best")
     print(predictions[0])

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 3 additions & 2 deletions
@@ -198,8 +198,9 @@ def cli_main():
     if not _DALI_AVAILABLE:
         return
 
-    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
 
 
 if __name__ == "__main__":

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 5 additions & 2 deletions
@@ -72,8 +72,11 @@ def configure_optimizers(self):
 
 
 def cli_main():
-    cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
-    cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    cli = LightningCLI(
+        LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
+    )
+    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
+    cli.trainer.test(ckpt_path="best")
 
 
 if __name__ == "__main__":

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 3 additions & 3 deletions
@@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser):
             }
         )
 
-    def instantiate_trainer(self):
-        finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"])
+    def instantiate_trainer(self, *args):
+        finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
         self.trainer_defaults["callbacks"] = [finetuning_callback]
-        super().instantiate_trainer()
+        return super().instantiate_trainer(*args)
 
 
 def cli_main():

pl_examples/run_examples.sh

File mode changed: 100644 → 100755
Lines changed: 8 additions & 1 deletion

@@ -2,7 +2,14 @@
 set -ex
 
 dir_path=$(dirname "${BASH_SOURCE[0]}")
-args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"
+args="
+  --data.batch_size=32
+  --trainer.max_epochs=1
+  --trainer.limit_train_batches=2
+  --trainer.limit_val_batches=2
+  --trainer.limit_test_batches=2
+  --trainer.limit_predict_batches=2
+"
 
 python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
 python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"

pl_examples/test_examples.py

Lines changed: 4 additions & 47 deletions
@@ -11,70 +11,27 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import importlib
-import platform
 from unittest import mock
 
 import pytest
-import torch
 
 from pl_examples import _DALI_AVAILABLE
+from tests.helpers.runif import RunIf
 
 ARGS_DEFAULT = (
     "--trainer.default_root_dir %(tmpdir)s "
     "--trainer.max_epochs 1 "
     "--trainer.limit_train_batches 2 "
     "--trainer.limit_val_batches 2 "
+    "--trainer.limit_test_batches 2 "
+    "--trainer.limit_predict_batches 2 "
    "--data.batch_size 32 "
 )
 ARGS_GPU = ARGS_DEFAULT + "--trainer.gpus 1 "
-ARGS_DP = ARGS_DEFAULT + "--trainer.gpus 2 --trainer.accelerator dp "
-ARGS_AMP = "--trainer.precision 16 "
-
-
-@pytest.mark.parametrize(
-    "import_cli",
-    [
-        "pl_examples.basic_examples.simple_image_classifier",
-        "pl_examples.basic_examples.backbone_image_classifier",
-        "pl_examples.basic_examples.autoencoder",
-    ],
-)
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.parametrize("cli_args", [ARGS_DP, ARGS_DP + ARGS_AMP])
-def test_examples_dp(tmpdir, import_cli, cli_args):
-
-    module = importlib.import_module(import_cli)
-    # update the temp dir
-    cli_args = cli_args % {"tmpdir": tmpdir}
-
-    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
-        module.cli_main()
-
-
-@pytest.mark.parametrize(
-    "import_cli",
-    [
-        "pl_examples.basic_examples.simple_image_classifier",
-        "pl_examples.basic_examples.backbone_image_classifier",
-        "pl_examples.basic_examples.autoencoder",
-    ],
-)
-@pytest.mark.parametrize("cli_args", [ARGS_DEFAULT])
-def test_examples_cpu(tmpdir, import_cli, cli_args):
-
-    module = importlib.import_module(import_cli)
-    # update the temp dir
-    cli_args = cli_args % {"tmpdir": tmpdir}
-
-    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args.strip().split()):
-        module.cli_main()
 
 
 @pytest.mark.skipif(not _DALI_AVAILABLE, reason="Nvidia DALI required")
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-@pytest.mark.skipif(platform.system() != "Linux", reason="Only applies to Linux platform.")
+@RunIf(min_gpus=1, skip_windows=True)
 @pytest.mark.parametrize("cli_args", [ARGS_GPU])
 def test_examples_mnist_dali(tmpdir, cli_args):
     from pl_examples.basic_examples.dali_image_classifier import cli_main

0 commit comments
