Skip to content

Commit 63940e4

Browse files
committed
Update pl_examples
1 parent d9fffb4 commit 63940e4

File tree

7 files changed

+23
-15
lines changed

7 files changed

+23
-15
lines changed

pl_examples/basic_examples/autoencoder.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,12 @@ def predict_dataloader(self):
109109

110110

111111
def cli_main():
112-
cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
113-
cli.trainer.test(cli.model, datamodule=cli.datamodule)
114-
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
112+
cli = LightningCLI(
113+
LitAutoEncoder, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
114+
)
115+
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
116+
cli.trainer.test(ckpt_path="best")
117+
predictions = cli.trainer.predict(ckpt_path="best")
115118
print(predictions[0])
116119

117120

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,10 @@ def predict_dataloader(self):
124124

125125

126126
def cli_main():
127-
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
128-
cli.trainer.test(cli.model, datamodule=cli.datamodule)
129-
predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
127+
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
128+
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
129+
cli.trainer.test(ckpt_path="best")
130+
predictions = cli.trainer.predict(ckpt_path="best")
130131
print(predictions[0])
131132

132133

pl_examples/basic_examples/dali_image_classifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ def cli_main():
198198
if not _DALI_AVAILABLE:
199199
return
200200

201-
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
202-
cli.trainer.test(cli.model, datamodule=cli.datamodule)
201+
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False)
202+
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
203+
cli.trainer.test(ckpt_path="best")
203204

204205

205206
if __name__ == "__main__":

pl_examples/basic_examples/simple_image_classifier.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,11 @@ def configure_optimizers(self):
7272

7373

7474
def cli_main():
75-
cli = LightningCLI(LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True)
76-
cli.trainer.test(cli.model, datamodule=cli.datamodule)
75+
cli = LightningCLI(
76+
LitClassifier, MNISTDataModule, seed_everything_default=1234, save_config_overwrite=True, run=False
77+
)
78+
cli.trainer.fit(cli.model, datamodule=cli.datamodule)
79+
cli.trainer.test(ckpt_path="best")
7780

7881

7982
if __name__ == "__main__":

pl_examples/domain_templates/computer_vision_fine_tuning.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,10 @@ def add_arguments_to_parser(self, parser):
277277
}
278278
)
279279

280-
def instantiate_trainer(self):
281-
finetuning_callback = MilestonesFinetuning(**self.config_init["finetuning"])
280+
def instantiate_trainer(self, *args):
281+
finetuning_callback = MilestonesFinetuning(**self._get(self.config_init, "finetuning"))
282282
self.trainer_defaults["callbacks"] = [finetuning_callback]
283-
super().instantiate_trainer()
283+
return super().instantiate_trainer(*args)
284284

285285

286286
def cli_main():

pl_examples/run_examples.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
set -ex
33

44
dir_path=$(dirname "${BASH_SOURCE[0]}")
5-
args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2"
5+
args="--trainer.max_epochs=1 --data.batch_size=32 --trainer.limit_train_batches=2 --trainer.limit_val_batches=2 --limit_test_batches=2"
66

77
python "${dir_path}/basic_examples/simple_image_classifier.py" ${args} "$@"
88
python "${dir_path}/basic_examples/backbone_image_classifier.py" ${args} "$@"

tests/special_tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ fi
8787
# report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"
8888

8989
# test that a user can manually launch individual processes
90-
args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1"
90+
args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.max_epochs=1 --trainer.limit_train_batches=1 --trainer.limit_val_batches=1 --limit_test_batches=1"
9191
MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &
9292
MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=0 python pl_examples/basic_examples/simple_image_classifier.py ${args}
9393
report+="Ran\tmanual ddp launch test\n"

0 commit comments

Comments
 (0)