File tree Expand file tree Collapse file tree 2 files changed +7
-14
lines changed
pl_examples/basic_examples Expand file tree Collapse file tree 2 files changed +7
-14
lines changed Original file line number Diff line number Diff line change 1717To run:
1818python backbone_image_classifier.py --trainer.max_epochs=50
1919"""
20+ from typing import Optional
2021
2122import torch
2223from torch .nn import functional as F
@@ -66,11 +67,13 @@ class LitClassifier(pl.LightningModule):
6667
6768 def __init__ (
6869 self ,
69- backbone ,
70+ backbone : Optional [ Backbone ] = None ,
7071 learning_rate : float = 0.0001 ,
7172 ):
7273 super ().__init__ ()
73- self .save_hyperparameters ()
74+ self .save_hyperparameters (ignore = ['backbone' ])
75+ if backbone is None :
76+ backbone = Backbone ()
7477 self .backbone = backbone
7578
7679 def forward (self , x ):
@@ -124,18 +127,8 @@ def test_dataloader(self):
124127 return DataLoader (self .mnist_test , batch_size = self .batch_size )
125128
126129
127- class MyLightningCLI (LightningCLI ):
128-
129- def add_arguments_to_parser (self , parser ):
130- parser .add_class_arguments (Backbone , 'model.backbone' )
131-
132- def instantiate_model (self ):
133- self .config_init ['model' ]['backbone' ] = Backbone (** self .config ['model' ]['backbone' ])
134- super ().instantiate_model ()
135-
136-
137130def cli_main ():
138- cli = MyLightningCLI (LitClassifier , MyDataModule , seed_everything_default = 1234 )
131+ cli = LightningCLI (LitClassifier , MyDataModule , seed_everything_default = 1234 )
139132 result = cli .trainer .test (cli .model , datamodule = cli .datamodule )
140133 print (result )
141134
Original file line number Diff line number Diff line change @@ -7,4 +7,4 @@ torchtext>=0.5
77# onnx>=1.7.0
88onnxruntime>=1.3.0
99hydra-core>=1.0
10- jsonargparse[signatures]>=3.10.1
10+ jsonargparse[signatures]>=3.11.0
You can’t perform that action at this time.
0 commit comments