Skip to content

Commit f852a4f

Browse files
mauvilsacarmocca
andauthored
Changed basic_examples to use LightningCLI (#6862)
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent f645df5 commit f852a4f

File tree

15 files changed

+203
-233
lines changed

15 files changed

+203
-233
lines changed

azure-pipelines.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ jobs:
116116
set -e
117117
python -m pytest pl_examples -v --maxfail=2 --durations=0
118118
pip install . --user --quiet
119-
bash pl_examples/run_examples-args.sh --gpus 1 --max_epochs 1 --batch_size 64 --limit_train_batches 5 --limit_val_batches 3
120-
bash pl_examples/run_ddp-examples.sh --max_epochs 1 --batch_size 32 --limit_train_batches 2 --limit_val_batches 2
119+
bash pl_examples/run_examples-args.sh --trainer.gpus 1 --trainer.max_epochs 1 --data.batch_size 64 --trainer.limit_train_batches 5 --trainer.limit_val_batches 3
120+
bash pl_examples/run_ddp-examples.sh --trainer.max_epochs 1 --data.batch_size 32 --trainer.limit_train_batches 2 --trainer.limit_val_batches 2
121121
# cd pl_examples/basic_examples
122122
# bash submit_ddp_job.sh
123123
# bash submit_ddp2_job.sh

docs/source/common/lightning_cli.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ datamodule class. However, there are many cases in which the objective is to eas
224224
multiple models and datasets. For these cases the tool can be configured such that a model and/or a datamodule is
225225
specified by an import path and init arguments. For example, with a tool implemented as:
226226

227-
.. testcode::
227+
.. code-block:: python
228228
229229
from pytorch_lightning.utilities.cli import LightningCLI
230230

pl_examples/basic_examples/README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ Trains MNIST where the model is defined inside the `LightningModule`.
88
python simple_image_classifier.py
99

1010
# gpus (any number)
11-
python simple_image_classifier.py --gpus 2
11+
python simple_image_classifier.py --trainer.gpus 2
1212

1313
# dataparallel
14-
python simple_image_classifier.py --gpus 2 --distributed_backend 'dp'
14+
python simple_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
1515
```
1616

1717
---
@@ -30,10 +30,10 @@ Generic image classifier with an arbitrary backbone (ie: a simple system)
3030
python backbone_image_classifier.py
3131

3232
# gpus (any number)
33-
python backbone_image_classifier.py --gpus 2
33+
python backbone_image_classifier.py --trainer.gpus 2
3434

3535
# dataparallel
36-
python backbone_image_classifier.py --gpus 2 --distributed_backend 'dp'
36+
python backbone_image_classifier.py --trainer.gpus 2 --trainer.accelerator 'dp'
3737
```
3838

3939
---
@@ -44,10 +44,10 @@ Showing the power of a system... arbitrarily complex training loops
4444
python autoencoder.py
4545

4646
# gpus (any number)
47-
python autoencoder.py --gpus 2
47+
python autoencoder.py --trainer.gpus 2
4848

4949
# dataparallel
50-
python autoencoder.py --gpus 2 --distributed_backend 'dp'
50+
python autoencoder.py --trainer.gpus 2 --trainer.accelerator 'dp'
5151
```
5252
---
5353
# Multi-node example

pl_examples/basic_examples/autoencoder.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
MNIST autoencoder example.
1416
15-
from argparse import ArgumentParser
17+
To run:
18+
python autoencoder.py --trainer.max_epochs=50
19+
"""
1620

1721
import torch
1822
import torch.nn.functional as F
@@ -21,6 +25,7 @@
2125

2226
import pytorch_lightning as pl
2327
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
28+
from pytorch_lightning.utilities.cli import LightningCLI
2429
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
2530

2631
if _TORCHVISION_AVAILABLE:
@@ -87,44 +92,31 @@ def configure_optimizers(self):
8792
return optimizer
8893

8994

95+
class MyDataModule(pl.LightningDataModule):
96+
97+
def __init__(
98+
self,
99+
batch_size: int = 32,
100+
):
101+
super().__init__()
102+
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
103+
self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
104+
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
105+
self.batch_size = batch_size
106+
107+
def train_dataloader(self):
108+
return DataLoader(self.mnist_train, batch_size=self.batch_size)
109+
110+
def val_dataloader(self):
111+
return DataLoader(self.mnist_val, batch_size=self.batch_size)
112+
113+
def test_dataloader(self):
114+
return DataLoader(self.mnist_test, batch_size=self.batch_size)
115+
116+
90117
def cli_main():
91-
pl.seed_everything(1234)
92-
93-
# ------------
94-
# args
95-
# ------------
96-
parser = ArgumentParser()
97-
parser.add_argument('--batch_size', default=32, type=int)
98-
parser.add_argument('--hidden_dim', type=int, default=64)
99-
parser = pl.Trainer.add_argparse_args(parser)
100-
args = parser.parse_args()
101-
102-
# ------------
103-
# data
104-
# ------------
105-
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
106-
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
107-
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
108-
109-
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
110-
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
111-
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
112-
113-
# ------------
114-
# model
115-
# ------------
116-
model = LitAutoEncoder(args.hidden_dim)
117-
118-
# ------------
119-
# training
120-
# ------------
121-
trainer = pl.Trainer.from_argparse_args(args)
122-
trainer.fit(model, train_loader, val_loader)
123-
124-
# ------------
125-
# testing
126-
# ------------
127-
result = trainer.test(test_dataloaders=test_loader)
118+
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
119+
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
128120
print(result)
129121

130122

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 44 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,20 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
"""
15+
MNIST backbone image classifier example.
1416
15-
from argparse import ArgumentParser
17+
To run:
18+
python backbone_image_classifier.py --trainer.max_epochs=50
19+
"""
1620

1721
import torch
1822
from torch.nn import functional as F
1923
from torch.utils.data import DataLoader, random_split
2024

2125
import pytorch_lightning as pl
2226
from pl_examples import _DATASETS_PATH, _TORCHVISION_MNIST_AVAILABLE, cli_lightning_logo
27+
from pytorch_lightning.utilities.cli import LightningCLI
2328
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
2429

2530
if _TORCHVISION_AVAILABLE:
@@ -59,7 +64,11 @@ class LitClassifier(pl.LightningModule):
5964
)
6065
"""
6166

62-
def __init__(self, backbone, learning_rate=1e-3):
67+
def __init__(
68+
self,
69+
backbone,
70+
learning_rate: float = 0.0001,
71+
):
6372
super().__init__()
6473
self.save_hyperparameters()
6574
self.backbone = backbone
@@ -92,52 +101,42 @@ def configure_optimizers(self):
92101
# self.hparams available because we called self.save_hyperparameters()
93102
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
94103

95-
@staticmethod
96-
def add_model_specific_args(parent_parser):
97-
parser = parent_parser.add_argument_group("LitClassifier")
98-
parser.add_argument('--learning_rate', type=float, default=0.0001)
99-
return parent_parser
104+
105+
class MyDataModule(pl.LightningDataModule):
106+
107+
def __init__(
108+
self,
109+
batch_size: int = 32,
110+
):
111+
super().__init__()
112+
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
113+
self.mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
114+
self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])
115+
self.batch_size = batch_size
116+
117+
def train_dataloader(self):
118+
return DataLoader(self.mnist_train, batch_size=self.batch_size)
119+
120+
def val_dataloader(self):
121+
return DataLoader(self.mnist_val, batch_size=self.batch_size)
122+
123+
def test_dataloader(self):
124+
return DataLoader(self.mnist_test, batch_size=self.batch_size)
125+
126+
127+
class MyLightningCLI(LightningCLI):
128+
129+
def add_arguments_to_parser(self, parser):
130+
parser.add_class_arguments(Backbone, 'model.backbone')
131+
132+
def instantiate_model(self):
133+
self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
134+
super().instantiate_model()
100135

101136

102137
def cli_main():
103-
pl.seed_everything(1234)
104-
105-
# ------------
106-
# args
107-
# ------------
108-
parser = ArgumentParser()
109-
parser.add_argument('--batch_size', default=32, type=int)
110-
parser.add_argument('--hidden_dim', type=int, default=128)
111-
parser = pl.Trainer.add_argparse_args(parser)
112-
parser = LitClassifier.add_model_specific_args(parser)
113-
args = parser.parse_args()
114-
115-
# ------------
116-
# data
117-
# ------------
118-
dataset = MNIST(_DATASETS_PATH, train=True, download=True, transform=transforms.ToTensor())
119-
mnist_test = MNIST(_DATASETS_PATH, train=False, download=True, transform=transforms.ToTensor())
120-
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
121-
122-
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
123-
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
124-
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
125-
126-
# ------------
127-
# model
128-
# ------------
129-
model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)
130-
131-
# ------------
132-
# training
133-
# ------------
134-
trainer = pl.Trainer.from_argparse_args(args)
135-
trainer.fit(model, train_loader, val_loader)
136-
137-
# ------------
138-
# testing
139-
# ------------
140-
result = trainer.test(test_dataloaders=test_loader)
138+
cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
139+
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
141140
print(result)
142141

143142

0 commit comments

Comments
 (0)