@@ -75,7 +75,7 @@ Now in your main trainer file, add the Trainer args, the program args, and add t
7575 # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
7676 parser = Trainer.add_argparse_args(parser)
7777
78- hparams = parser.parse_args()
78+ args = parser.parse_args()
7979
8080Now you can run your program like so
8181
@@ -87,39 +87,50 @@ Finally, make sure to start the training like so:
8787
8888.. code-block:: python
8989
90- # YES
91- model = LitModel(hparams)
92-    trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
90+ # init the trainer like this
91+    trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
92+
93+ # NOT like this
94+    trainer = Trainer(gpus=args.gpus, ...)
95+
96+    # init the model with Namespace directly
97+    model = LitModel(args)
98+
99+    # or init the model with all the key-value pairs
100+    dict_args = vars(args)
101+    model = LitModel(**dict_args)
93102
94- # NO
95- # model = LitModel(learning_rate=hparams.learning_rate, ...)
96- # trainer = Trainer(gpus=hparams.gpus, ...)
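Putting the pieces together, here is a minimal end-to-end sketch of the flow described above. The `--some_program_arg` flag is purely illustrative, and it assumes `LitModel` defines an `add_model_specific_args` staticmethod like the `LitMNIST` example shown later in this document:

.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()

    # program level args (illustrative name)
    parser.add_argument('--some_program_arg', type=str, default='some_value')

    # model specific args (assumes LitModel implements add_model_specific_args)
    parser = LitModel.add_model_specific_args(parser)

    # trainer args: --gpus, --num_nodes, --fast_dev_run, ...
    parser = Trainer.add_argparse_args(parser)

    args = parser.parse_args()

    model = LitModel(args)
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)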
103+ LightningModule hyperparameters
104+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
97105
98- LightningModule hparams
99- ^^^^^^^^^^^^^^^^^^^^^^^
106+ .. warning:: The use of `hparams` is no longer recommended (but still supported).
100107
101- Normally, we don't hard-code the values to a model. We usually use the command line to
102- modify the network and read those values in the LightningModule
108+ A LightningModule is just an nn.Module, so you can use it as you normally would. However, there
109+ are some best practices to improve readability and reproducibility.
110+
111+ 1. It's more readable to specify all the arguments that go into a module (with default values).
112+ This helps users of your module know exactly what it needs in order to run.
103113
104114.. testcode::
105115
106116 class LitMNIST(LightningModule):
107117
108-        def __init__(self, hparams):
118+        def __init__(self, layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32, **kwargs):
109119 super().__init__()
120+ self.layer_1_dim = layer_1_dim
121+ self.layer_2_dim = layer_2_dim
122+ self.learning_rate = learning_rate
123+ self.batch_size = batch_size
110124
111- # do this to save all arguments in any logger (tensorboard)
112- self.hparams = hparams
113-
114- self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
115- self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
116- self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)
125+ self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
126+ self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim)
127+ self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10)
117128
118129 def train_dataloader(self):
119-            return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
130+ return DataLoader(mnist_train, batch_size=self.batch_size)
120131
121132 def configure_optimizers(self):
122-            return Adam(self.parameters(), lr=self.hparams.learning_rate)
133+ return Adam(self.parameters(), lr=self.learning_rate)
123134
124135 @staticmethod
125136 def add_model_specific_args(parent_parser):
@@ -130,20 +141,59 @@ modify the network and read those values in the LightningModule
130141 parser.add_argument('--learning_rate', type=float, default=0.002)
131142 return parser
132143
133- Now pass in the params when you init your model
144+ 2. You can also pass in a dict or Namespace, but this obscures the parameters your module is looking
145+ for. The user would have to search the file to find what is parametrized.
146+
147+ .. code-block:: python
148+
149+     # using an argparse.Namespace
150+     class LitMNIST(LightningModule):
151+
152+         def __init__(self, hparams, *args, **kwargs):
153+             super().__init__()
154+             self.hparams = hparams
155+
156+             self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
157+             self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
158+             self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
159+
160+         def train_dataloader(self):
161+             return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
162+
163+ One way to get around this is to convert a Namespace or dict into key-value pairs using `**`
134164
135165.. code-block:: python
136166
137167 parser = ArgumentParser()
138168 parser = LitMNIST.add_model_specific_args(parser)
139- hparams = parser.parse_args()
140- model = LitMNIST(hparams)
169+ args = parser.parse_args()
170+    dict_args = vars(args)
171+    model = LitMNIST(**dict_args)
172+
173+ Within any LightningModule, all the arguments you pass into `__init__` will be stored in
174+ the checkpoint so that you know all the values that went into creating this model.
175+
176+ We will also add all of those values to the TensorBoard hparams tab (unless a value is an
177+ object, which we skip). The values stored in the checkpoint can later be used to re-initialize
178+ your model.
179+
180+ .. code-block:: python
141181
142- The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
143- This does two things:
182+     class LitMNIST(LightningModule):
183+
184+         def __init__(self, layer_1_dim, some_other_param):
185+             super().__init__()
186+             self.layer_1_dim = layer_1_dim
187+             self.some_other_param = some_other_param
188+
189+             self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
190+
191+             self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param)
192+             self.layer_3 = torch.nn.Linear(self.some_other_param, 10)
193+
194+
195+     model = LitMNIST(10, 20)
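Because those `__init__` values are stored in the checkpoint, a trained model can later be restored without repeating them. A minimal sketch, assuming a checkpoint file exists at the hypothetical path `path/to/checkpoint.ckpt` from a previous training run:

.. code-block:: python

    # the checkpoint already contains layer_1_dim and some_other_param,
    # so the model is rebuilt with the same values
    model = LitMNIST.load_from_checkpoint('path/to/checkpoint.ckpt')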
144196
145- 1. It adds them automatically to TensorBoard logs under the hparams tab.
146- 2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
147197
148198 Trainer args
149199^^^^^^^^^^^^
@@ -171,27 +221,27 @@ polluting the main.py file, the LightningModule lets you define arguments for ea
171221
172222 class LitMNIST(LightningModule):
173223
174-        def __init__(self, hparams):
224+        def __init__(self, layer_1_dim, **kwargs):
175225 super().__init__()
176-            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
226+ self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)
177227
178228 @staticmethod
179229 def add_model_specific_args(parent_parser):
180- parser = ArgumentParser(parents=[parent_parser])
230+            parser = ArgumentParser(parents=[parent_parser], add_help=False)
181231 parser.add_argument('--layer_1_dim', type=int, default=128)
182232 return parser
183233
184234.. testcode::
185235
186236 class GoodGAN(LightningModule):
187237
188-        def __init__(self, hparams):
238+        def __init__(self, encoder_layers, **kwargs):
189239 super().__init__()
190-            self.encoder = Encoder(layers=hparams.encoder_layers)
240+ self.encoder = Encoder(layers=encoder_layers)
191241
192242 @staticmethod
193243 def add_model_specific_args(parent_parser):
194- parser = ArgumentParser(parents=[parent_parser])
244+            parser = ArgumentParser(parents=[parent_parser], add_help=False)
195245 parser.add_argument('--encoder_layers', type=int, default=12)
196246 return parser
197247
@@ -201,14 +251,14 @@ Now we can allow each model to inject the arguments it needs in the ``main.py``
201251.. code-block:: python
202252
203253    def main(args):
254+        dict_args = vars(args)
204255
205256 # pick model
206257        if args.model_name == 'gan':
207-            model = GoodGAN(hparams=args)
258+            model = GoodGAN(**dict_args)
208259        elif args.model_name == 'mnist':
209-            model = LitMNIST(hparams=args)
260+            model = LitMNIST(**dict_args)
210261
211-        model = LitMNIST(hparams=args)
212262 trainer = Trainer.from_argparse_args(args)
213263 trainer.fit(model)
214264
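For completeness, a minimal sketch of how the parsing in main.py could be wired so that only the chosen model's arguments are added. The use of `parse_known_args` to peek at `--model_name` before adding the model specific args is an assumed pattern for illustration, not part of the diff above:

.. code-block:: python

    if __name__ == '__main__':
        parser = ArgumentParser()
        parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')

        # peek at the chosen model so only its specific args get added (assumed pattern)
        temp_args, _ = parser.parse_known_args()

        if temp_args.model_name == 'gan':
            parser = GoodGAN.add_model_specific_args(parser)
        elif temp_args.model_name == 'mnist':
            parser = LitMNIST.add_model_specific_args(parser)

        # add all the available Trainer options to argparse
        parser = Trainer.add_argparse_args(parser)

        args = parser.parse_args()
        main(args)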