Skip to content

Commit b0cd9da

Browse files
mauvilsacarmoccaawaelchliananthsub
authored
Simplify backbone_image_classifier example (#7246)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: ananthsub <[email protected]>
1 parent 7a48db5 commit b0cd9da

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

pl_examples/basic_examples/backbone_image_classifier.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
To run:
1818
python backbone_image_classifier.py --trainer.max_epochs=50
1919
"""
20+
from typing import Optional
2021

2122
import torch
2223
from torch.nn import functional as F
@@ -66,11 +67,13 @@ class LitClassifier(pl.LightningModule):
6667

6768
def __init__(
6869
self,
69-
backbone,
70+
backbone: Optional[Backbone] = None,
7071
learning_rate: float = 0.0001,
7172
):
7273
super().__init__()
73-
self.save_hyperparameters()
74+
self.save_hyperparameters(ignore=['backbone'])
75+
if backbone is None:
76+
backbone = Backbone()
7477
self.backbone = backbone
7578

7679
def forward(self, x):
@@ -124,18 +127,8 @@ def test_dataloader(self):
124127
return DataLoader(self.mnist_test, batch_size=self.batch_size)
125128

126129

127-
class MyLightningCLI(LightningCLI):
128-
129-
def add_arguments_to_parser(self, parser):
130-
parser.add_class_arguments(Backbone, 'model.backbone')
131-
132-
def instantiate_model(self):
133-
self.config_init['model']['backbone'] = Backbone(**self.config['model']['backbone'])
134-
super().instantiate_model()
135-
136-
137130
def cli_main():
138-
cli = MyLightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
131+
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
139132
result = cli.trainer.test(cli.model, datamodule=cli.datamodule)
140133
print(result)
141134

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ torchtext>=0.5
77
# onnx>=1.7.0
88
onnxruntime>=1.3.0
99
hydra-core>=1.0
10-
jsonargparse[signatures]>=3.10.1
10+
jsonargparse[signatures]>=3.11.0

0 commit comments

Comments
 (0)