Commit caa9c67

williamFalcon, Borda, and awaelchli authored
replace Hparams by init args (#1896)
* remove the need for hparams; replace self.hparams with explicit init args across modules, tests, and examples * save init args to checkpoints and restore them on load * finished moco example * update docs, tests, and CHANGELOG * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> * cleaning, transforms, and example fixes * flake8 Co-authored-by: Jirka <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent a20db4e commit caa9c67

38 files changed, +687 -542 lines changed
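
In practice the commit swaps the single `hparams` Namespace for explicit constructor arguments. A minimal before/after sketch of the pattern (the `LitModelOld`/`LitModelNew` classes and their argument names are hypothetical, used only to illustrate the migration):

    import torch
    import pytorch_lightning as pl

    # Before: everything hides inside a single Namespace
    class LitModelOld(pl.LightningModule):
        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams
            self.layer = torch.nn.Linear(28 * 28, hparams.hidden_dim)

    # After: arguments are explicit (and defaulted), so the checkpoint records
    # exactly what went into building the model
    class LitModelNew(pl.LightningModule):
        def __init__(self, hidden_dim=128, learning_rate=1e-3, **kwargs):
            super().__init__()
            self.hidden_dim = hidden_dim
            self.learning_rate = learning_rate
            self.layer = torch.nn.Linear(28 * 28, self.hidden_dim)

Because the arguments are now named in the signature, Lightning can record them in the checkpoint and surface them in the TensorBoard hparams tab, as the documentation changes below describe.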

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -133,4 +133,5 @@ mnist/
 # pl tests
 ml-runs/
 *.zip
-pytorch\ lightning
+pytorch\ lightning
+test-reports/

.run_local_tests.sh

Lines changed: 3 additions & 0 deletions
@@ -14,3 +14,6 @@ rm -rf ./tests/tests/*
 rm -rf ./lightning_logs
 python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8
 python -m coverage report -m
+
+# specific file
+# python -m coverage run --source pytorch_lightning -m py.test -k test_trainer.py --flake8

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))
 
+- Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))
+
 ### Deprecated
 
 - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))

docs/source/hyperparameters.rst

Lines changed: 86 additions & 36 deletions
@@ -75,7 +75,7 @@ Now in your main trainer file, add the Trainer args, the program args, and add t
     # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
     parser = Trainer.add_argparse_args(parser)
 
-    hparams = parser.parse_args()
+    args = parser.parse_args()
 
 Now you can call run your program like so
 
@@ -87,39 +87,50 @@ Finally, make sure to start the training like so:
 
 .. code-block:: python
 
-    # YES
-    model = LitModel(hparams)
-    trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
+    # init the trainer like this
+    trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)
+
+    # NOT like this
+    trainer = Trainer(gpus=hparams.gpus, ...)
+
+    # init the model with Namespace directly
+    model = LitModel(args)
+
+    # or init the model with all the key-value pairs
+    dict_args = vars(args)
+    model = LitModel(**dict_args)
 
-    # NO
-    # model = LitModel(learning_rate=hparams.learning_rate, ...)
-    # trainer = Trainer(gpus=hparams.gpus, ...)
+LightningModule hyperparameters
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-LightningModule hparams
-^^^^^^^^^^^^^^^^^^^^^^^
+.. warning:: The use of `hparams` is no longer recommended (but still supported)
 
-Normally, we don't hard-code the values to a model. We usually use the command line to
-modify the network and read those values in the LightningModule
+LightningModule is just an nn.Module, you can use it as you normally would. However, there are
+some best practices to improve readability and reproducibility.
+
+1. It's more readable to specify all the arguments that go into a module (with default values).
+This helps users of your module know everything that is required to run this.
 
 .. testcode::
 
     class LitMNIST(LightningModule):
 
-        def __init__(self, hparams):
+        def __init__(self, layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32, **kwargs):
             super().__init__()
+            self.layer_1_dim = layer_1_dim
+            self.layer_2_dim = layer_2_dim
+            self.learning_rate = learning_rate
+            self.batch_size = batch_size
 
-            # do this to save all arguments in any logger (tensorboard)
-            self.hparams = hparams
-
-            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
-            self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
-            self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)
+            self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
+            self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim)
+            self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10)
 
         def train_dataloader(self):
-            return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
+            return DataLoader(mnist_train, batch_size=self.batch_size)
 
         def configure_optimizers(self):
-            return Adam(self.parameters(), lr=self.hparams.learning_rate)
+            return Adam(self.parameters(), lr=self.learning_rate)
 
         @staticmethod
         def add_model_specific_args(parent_parser):
@@ -130,20 +141,59 @@ modify the network and read those values in the LightningModule
             parser.add_argument('--learning_rate', type=float, default=0.002)
             return parser
 
-Now pass in the params when you init your model
+2. You can also pass in a dict or Namespace, but this obscures the parameters your module is looking
+for. The user would have to search the file to find what is parametrized.
+
+.. code-block:: python
+
+    # using a argparse.Namespace
+    class LitMNIST(LightningModule):
+
+        def __init__(self, hparams, *args, **kwargs):
+            super().__init__()
+            self.hparams = hparams
+
+            self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
+            self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
+            self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
+
+        def train_dataloader(self):
+            return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
+
+One way to get around this is to convert a Namespace or dict into key-value pairs using `**`
 
 .. code-block:: python
 
     parser = ArgumentParser()
     parser = LitMNIST.add_model_specific_args(parser)
-    hparams = parser.parse_args()
-    model = LitMNIST(hparams)
+    args = parser.parse_args()
+    dict_args = vars(args)
+    model = LitMNIST(**dict_args)
+
+Within any LightningModule all the arguments you pass into your `__init__` will be stored in
+the checkpoint so that you know all the values that went into creating this model.
+
+We will also add all of those values to the TensorBoard hparams tab (unless it's an object which
+we won't). We also will store those values into checkpoints for you which you can use to init your
+models.
+
+.. code-block:: python
 
-The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule.
-This does two things:
+    class LitMNIST(LightningModule):
+
+        def __init__(self, layer_1_dim, some_other_param):
+            super().__init__()
+            self.layer_1_dim = layer_1_dim
+            self.some_other_param = some_other_param
+
+            self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)
+
+            self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param)
+            self.layer_3 = torch.nn.Linear(self.some_other_param, 10)
+
+
+    model = LitMNIST(10, 20)
 
-1. It adds them automatically to TensorBoard logs under the hparams tab.
-2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
 
 Trainer args
 ^^^^^^^^^^^^
@@ -171,27 +221,27 @@ polluting the main.py file, the LightningModule lets you define arguments for ea
 
     class LitMNIST(LightningModule):
 
-        def __init__(self, hparams):
+        def __init__(self, layer_1_dim, **kwargs):
             super().__init__()
-            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
+            self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)
 
         @staticmethod
        def add_model_specific_args(parent_parser):
-            parser = ArgumentParser(parents=[parent_parser])
+            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--layer_1_dim', type=int, default=128)
            return parser
 
 .. testcode::
 
     class GoodGAN(LightningModule):
 
-        def __init__(self, hparams):
+        def __init__(self, encoder_layers, **kwargs):
             super().__init__()
-            self.encoder = Encoder(layers=hparams.encoder_layers)
+            self.encoder = Encoder(layers=encoder_layers)
 
         @staticmethod
         def add_model_specific_args(parent_parser):
-            parser = ArgumentParser(parents=[parent_parser])
+            parser = ArgumentParser(parents=[parent_parser], add_help=False)
             parser.add_argument('--encoder_layers', type=int, default=12)
             return parser
 
@@ -201,14 +251,14 @@ Now we can allow each model to inject the arguments it needs in the ``main.py``
 .. code-block:: python
 
     def main(args):
+        dict_args = vars(args)
 
         # pick model
         if args.model_name == 'gan':
-            model = GoodGAN(hparams=args)
+            model = GoodGAN(**dict_args)
         elif args.model_name == 'mnist':
-            model = LitMNIST(hparams=args)
+            model = LitMNIST(**dict_args)
 
-        model = LitMNIST(hparams=args)
         trainer = Trainer.from_argparse_args(args)
         trainer.fit(model)
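
Read together, the updated page wires a training script roughly like the sketch below; `LitMNIST` is the class defined in the docs above, while the `--data_dir` argument and the overall script layout are assumptions added only for illustration:

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    def main():
        parser = ArgumentParser()
        parser.add_argument('--data_dir', type=str, default='./data')  # hypothetical program-level arg

        # let the module add its own arguments
        parser = LitMNIST.add_model_specific_args(parser)

        # add all Trainer flags (--gpus, --max_epochs, ...)
        parser = Trainer.add_argparse_args(parser)
        args = parser.parse_args()

        # unpack the Namespace into explicit init args
        dict_args = vars(args)
        model = LitMNIST(**dict_args)  # LitMNIST as defined in the docs above

        trainer = Trainer.from_argparse_args(args)
        trainer.fit(model)

    if __name__ == '__main__':
        main()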

docs/source/lr_finder.rst

Lines changed: 6 additions & 7 deletions
@@ -36,18 +36,17 @@ hyperparameters of the model.
     # default: no automatic learning rate finder
     trainer = Trainer(auto_lr_find=False)
 
-When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate.
-In both cases, if the respective fields are not found, an error will be thrown.
-
+This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``.
+
 .. testcode::
 
     class LitModel(LightningModule):
 
-        def __init__(self, hparams):
-            self.hparams = hparams
+        def __init__(self, learning_rate):
+            self.learning_rate = learning_rate
 
         def configure_optimizers(self):
-            return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate)
+            return Adam(self.parameters(), lr=(self.lr or self.learning_rate))
 
     # finds learning rate automatically
     # sets hparams.lr or hparams.learning_rate to that learning rate
@@ -97,7 +96,7 @@ of this would look like
 
     # update hparams of the model
     model.hparams.lr = new_lr
-
+
     # Fit model
     trainer.fit(model)

docs/source/training_tricks.rst

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ a binary search.
 .. code-block:: python
 
     def train_dataloader(self):
-        return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
+        return DataLoader(train_dataset, batch_size=self.batch_size)
 
 .. warning::
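
The one-line change above assumes `batch_size` is a plain attribute the batch-size finder can overwrite. A short sketch under that assumption; `train_dataset` is a placeholder, and the `auto_scale_batch_size` flag is the binary-search feature this page discusses:

    from torch.utils.data import DataLoader
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size  # the scaler overwrites this attribute

        def train_dataloader(self):
            # train_dataset is assumed to be defined elsewhere
            return DataLoader(train_dataset, batch_size=self.batch_size)

    # binary-search for the largest batch size that fits in memory
    trainer = pl.Trainer(auto_scale_batch_size='binsearch')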

docs/source/weights_loading.rst

Lines changed: 28 additions & 27 deletions
@@ -59,24 +59,20 @@ Or disable it by passing
     trainer = Trainer(checkpoint_callback=False)
 
 
-The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init.
+The Lightning checkpoint also saves the arguments passed into the LightningModule init
+under the `module_arguments` key in the checkpoint.
 
-.. note:: hparams is a `Namespace <https://docs.python.org/2/library/argparse.html#argparse.Namespace>`_.
-
-.. testcode::
-
-    from argparse import Namespace
+.. code-block:: python
 
-    # usually these come from command line args
-    args = Namespace(learning_rate=0.001)
+    class MyLightningModule(LightningModule):
 
-    # define you module to have hparams as the first arg
-    # this means your checkpoint will have everything that went into making
-    # this model (in this case, learning rate)
-    class MyLightningModule(LightningModule):
+        def __init__(self, learning_rate, *args, **kwargs):
+            super().__init__()
 
-        def __init__(self, hparams, *args, **kwargs):
-            self.hparams = hparams
+    # all init args were saved to the checkpoint
+    checkpoint = torch.load(CKPT_PATH)
+    print(checkpoint['module_arguments'])
+    # {'learning_rate': the_value}
 
 Manual saving
 ^^^^^^^^^^^^^
@@ -92,37 +88,42 @@ You can manually save checkpoints and restore your model from the checkpointed s
 Checkpoint Loading
 ------------------
 
-To load a model along with its weights, biases and hyperparameters use following method.
+To load a model along with its weights, biases and `module_arguments` use following method.
 
 .. code-block:: python
 
     model = MyLightingModule.load_from_checkpoint(PATH)
-    model.eval()
-    y_hat = model(x)
-
-The above only works if you used `hparams` in your model definition
 
-.. testcode::
-
-    class LitModel(LightningModule):
+    print(model.learning_rate)
+    # prints the learning_rate you used in this checkpoint
 
-        def __init__(self, hparams):
-            self.hparams = hparams
-            self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
+    model.eval()
+    y_hat = model(x)
 
-But if you don't and instead pass individual parameters
+But if you don't want to use the values saved in the checkpoint, pass in your own here
 
 .. testcode::
 
     class LitModel(LightningModule):
 
         def __init__(self, in_dim, out_dim):
-            self.l1 = nn.Linear(in_dim, out_dim)
+            super().__init__()
+            self.in_dim = in_dim
+            self.out_dim = out_dim
+            self.l1 = nn.Linear(self.in_dim, self.out_dim)
 
 you can restore the model like this
 
 .. code-block:: python
 
+    # if you train and save the model like this it will use these values when loading
+    # the weights. But you can overwrite this
+    LitModel(in_dim=32, out_dim=10)
+
+    # uses in_dim=32, out_dim=10
+    model = LitModel.load_from_checkpoint(PATH)
+
+    # uses in_dim=128, out_dim=10
     model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
 
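
As a usage note, the new checkpoint behaviour documented above amounts to the round trip sketched below; it assumes the `LitModel(in_dim, out_dim)` class from the docs and a checkpoint already saved at `CKPT_PATH`:

    import torch

    # inspect what was stored alongside the weights
    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])  # e.g. {'in_dim': 32, 'out_dim': 10}

    # rebuild the model with the init args stored in the checkpoint
    model = LitModel.load_from_checkpoint(CKPT_PATH)

    # or overwrite any of them at load time
    model = LitModel.load_from_checkpoint(CKPT_PATH, in_dim=128)

    model.eval()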
