diff --git a/.gitignore b/.gitignore index cb8fd278c5c4f..aa9778a6c6738 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,5 @@ mnist/ # pl tests ml-runs/ *.zip -pytorch\ lightning \ No newline at end of file +pytorch\ lightning +test-reports/ \ No newline at end of file diff --git a/.run_local_tests.sh b/.run_local_tests.sh index 20fe84ff22fcf..83012a3932a79 100644 --- a/.run_local_tests.sh +++ b/.run_local_tests.sh @@ -14,3 +14,6 @@ rm -rf ./tests/tests/* rm -rf ./lightning_logs python -m coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules --flake8 python -m coverage report -m + +# specific file +# python -m coverage run --source pytorch_lightning -m py.test -k test_trainer.py --flake8 diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f893a298b177..5d56f4a9287c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862)) +- Allow passing model hyperparameters as a complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896)) + ### Deprecated - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917)) diff --git a/docs/source/hyperparameters.rst b/docs/source/hyperparameters.rst index 5b2dd343fb622..d931b8ab52fbf 100644 --- a/docs/source/hyperparameters.rst +++ b/docs/source/hyperparameters.rst @@ -75,7 +75,7 @@ Now in your main trainer file, add the Trainer args, the program args, and add t # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli parser = Trainer.add_argparse_args(parser) - hparams = parser.parse_args() + args = parser.parse_args() Now you can run your program like so Finally, make sure to start the training like so: .. code-block:: python - # YES - model = LitModel(hparams) - trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...) + # init the trainer like this + trainer = Trainer.from_argparse_args(args, early_stopping_callback=...) + + # NOT like this + trainer = Trainer(gpus=hparams.gpus, ...) + + # init the model with Namespace directly + model = LitModel(args) + + # or init the model with all the key-value pairs + dict_args = vars(args) + model = LitModel(**dict_args) - # NO - # model = LitModel(learning_rate=hparams.learning_rate, ...) - # trainer = Trainer(gpus=hparams.gpus, ...) +LightningModule hyperparameters +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -LightningModule hparams -^^^^^^^^^^^^^^^^^^^^^^^ +.. warning:: The use of `hparams` is no longer recommended (but still supported) -Normally, we don't hard-code the values to a model. We usually use the command line to -modify the network and read those values in the LightningModule +A LightningModule is just an nn.Module; you can use it as you normally would. However, there are +some best practices to improve readability and reproducibility. + +1. It's more readable to specify all the arguments that go into a module (with default values). +This helps users of your module know everything that is required to run it. ..
testcode:: class LitMNIST(LightningModule): - def __init__(self, hparams): + def __init__(self, layer_1_dim=128, layer_2_dim=256, learning_rate=1e-4, batch_size=32, **kwargs): super().__init__() + self.layer_1_dim = layer_1_dim + self.layer_2_dim = layer_2_dim + self.learning_rate = learning_rate + self.batch_size = batch_size - # do this to save all arguments in any logger (tensorboard) - self.hparams = hparams - - self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim) - self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim) - self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10) + self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim) + self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.layer_2_dim) + self.layer_3 = torch.nn.Linear(self.layer_2_dim, 10) def train_dataloader(self): - return DataLoader(mnist_train, batch_size=self.hparams.batch_size) + return DataLoader(mnist_train, batch_size=self.batch_size) def configure_optimizers(self): - return Adam(self.parameters(), lr=self.hparams.learning_rate) + return Adam(self.parameters(), lr=self.learning_rate) @staticmethod def add_model_specific_args(parent_parser): @@ -130,20 +141,59 @@ modify the network and read those values in the LightningModule parser.add_argument('--learning_rate', type=float, default=0.002) return parser -Now pass in the params when you init your model +2. You can also pass in a dict or Namespace, but this obscures the parameters your module is looking +for. The user would have to search the file to find what is parametrized. + +.. code-block:: python + + # using a argparse.Namespace + class LitMNIST(LightningModule): + + def __init__(self, hparams, *args, **kwargs): + super().__init__() + self.hparams = hparams + + self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim) + self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim) + self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10) + + def train_dataloader(self): + return DataLoader(mnist_train, batch_size=self.hparams.batch_size) + +One way to get around this is to convert a Namespace or dict into key-value pairs using `**` .. code-block:: python parser = ArgumentParser() parser = LitMNIST.add_model_specific_args(parser) - hparams = parser.parse_args() - model = LitMNIST(hparams) + args = parser.parse_args() + dict_args = vars(args) + model = LitMNIST(**dict_args) + +Within any LightningModule all the arguments you pass into your `__init__` will be stored in +the checkpoint so that you know all the values that went into creating this model. + +We will also add all of those values to the TensorBoard hparams tab (unless it's an object which +we won't). We also will store those values into checkpoints for you which you can use to init your +models. + +.. code-block:: python -The line `self.hparams = hparams` is very special. This line assigns your hparams to the LightningModule. -This does two things: + class LitMNIST(LightningModule): + + def __init__(self, layer_1_dim, some_other_param): + super().__init__() + self.layer_1_dim = layer_1_dim + self.some_other_param = some_other_param + + self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim) + + self.layer_2 = torch.nn.Linear(self.layer_1_dim, self.some_other_param) + self.layer_3 = torch.nn.Linear(self.some_other_param, 10) + + + model = LitMNIST(10, 20) -1. It adds them automatically to TensorBoard logs under the hparams tab. -2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly. 
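A minimal sketch of the behaviour described above (illustrative only, not part of this diff): `CKPT_PATH` stands in for a checkpoint produced by `trainer.fit`, and the skipping of non-primitive values follows the `PRIMITIVE_TYPES` filter introduced later in this diff in `training_io.py`.

.. code-block:: python

    class LitMNIST(LightningModule):

        def __init__(self, layer_1_dim=128, loss_module=None):
            super().__init__()
            self.layer_1_dim = layer_1_dim    # primitive value: logged and checkpointed
            self.loss_module = loss_module    # arbitrary object: skipped
            self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_dim)

    model = LitMNIST(layer_1_dim=64, loss_module=torch.nn.CrossEntropyLoss())

    # after training, only the primitive init arguments appear in the checkpoint
    checkpoint = torch.load(CKPT_PATH)
    print(checkpoint['module_arguments'])  # {'layer_1_dim': 64}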
Trainer args ^^^^^^^^^^^^ @@ -171,13 +221,13 @@ polluting the main.py file, the LightningModule lets you define arguments for ea class LitMNIST(LightningModule): - def __init__(self, hparams): + def __init__(self, layer_1_dim, **kwargs): super().__init__() - self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim) + self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim) @staticmethod def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser]) + parser = ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument('--layer_1_dim', type=int, default=128) return parser @@ -185,13 +235,13 @@ polluting the main.py file, the LightningModule lets you define arguments for ea class GoodGAN(LightningModule): - def __init__(self, hparams): + def __init__(self, encoder_layers, **kwargs): super().__init__() - self.encoder = Encoder(layers=hparams.encoder_layers) + self.encoder = Encoder(layers=encoder_layers) @staticmethod def add_model_specific_args(parent_parser): - parser = ArgumentParser(parents=[parent_parser]) + parser = ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument('--encoder_layers', type=int, default=12) return parser @@ -201,14 +251,14 @@ Now we can allow each model to inject the arguments it needs in the ``main.py`` .. code-block:: python def main(args): + dict_args = vars(args) # pick model if args.model_name == 'gan': - model = GoodGAN(hparams=args) + model = GoodGAN(**dict_args) elif args.model_name == 'mnist': - model = LitMNIST(hparams=args) + model = LitMNIST(**dict_args) - model = LitMNIST(hparams=args) trainer = Trainer.from_argparse_args(args) trainer.fit(model) diff --git a/docs/source/lr_finder.rst b/docs/source/lr_finder.rst index b92088fcae2f8..308cc216e1343 100755 --- a/docs/source/lr_finder.rst +++ b/docs/source/lr_finder.rst @@ -36,18 +36,17 @@ hyperparameters of the model. # default: no automatic learning rate finder trainer = Trainer(auto_lr_find=False) -When the ``lr`` or ``learning_rate`` key in hparams exists, this flag sets your learning_rate. -In both cases, if the respective fields are not found, an error will be thrown. - +This flag sets your learning rate which can be accessed via ``self.lr`` or ``self.learning_rate``. + .. testcode:: class LitModel(LightningModule): - def __init__(self, hparams): - self.hparams = hparams + def __init__(self, learning_rate): + self.learning_rate = learning_rate def configure_optimizers(self): - return Adam(self.parameters(), lr=self.hparams.lr|self.hparams.learning_rate) + return Adam(self.parameters(), lr=(self.lr or self.learning_rate)) # finds learning rate automatically # sets hparams.lr or hparams.learning_rate to that learning rate @@ -97,7 +96,7 @@ of this would look like # update hparams of the model model.hparams.lr = new_lr - + # Fit model trainer.fit(model) diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst index b748465eec014..53cb95bf9f029 100644 --- a/docs/source/training_tricks.rst +++ b/docs/source/training_tricks.rst @@ -67,7 +67,7 @@ a binary search. .. code-block:: python def train_dataloader(self): - return DataLoader(train_dataset, batch_size=self.hparams.batch_size) + return DataLoader(train_dataset, batch_size=self.batch_size) .. 
warning:: diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst index 64a6950738ef1..11844678397a9 100644 --- a/docs/source/weights_loading.rst +++ b/docs/source/weights_loading.rst @@ -59,24 +59,20 @@ Or disable it by passing trainer = Trainer(checkpoint_callback=False) -The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule init. +The Lightning checkpoint also saves the arguments passed into the LightningModule init +under the `module_arguments` key in the checkpoint. -.. note:: hparams is a `Namespace `_. - -.. testcode:: - - from argparse import Namespace +.. code-block:: python - # usually these come from command line args - args = Namespace(learning_rate=0.001) + class MyLightningModule(LightningModule): - # define you module to have hparams as the first arg - # this means your checkpoint will have everything that went into making - # this model (in this case, learning rate) - class MyLightningModule(LightningModule): + def __init__(self, learning_rate, *args, **kwargs): + super().__init__() - def __init__(self, hparams, *args, **kwargs): - self.hparams = hparams + # all init args were saved to the checkpoint + checkpoint = torch.load(CKPT_PATH) + print(checkpoint['module_arguments']) + # {'learning_rate': the_value} Manual saving ^^^^^^^^^^^^^ @@ -92,37 +88,42 @@ You can manually save checkpoints and restore your model from the checkpointed s Checkpoint Loading ------------------ -To load a model along with its weights, biases and hyperparameters use following method. +To load a model along with its weights, biases and `module_arguments` use the following method. .. code-block:: python model = MyLightingModule.load_from_checkpoint(PATH) - model.eval() - y_hat = model(x) - -The above only works if you used `hparams` in your model definition -.. testcode:: - - class LitModel(LightningModule): + print(model.learning_rate) + # prints the learning_rate you used in this checkpoint - def __init__(self, hparams): - self.hparams = hparams - self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim) + model.eval() + y_hat = model(x) -But if you don't and instead pass individual parameters +But if you don't want to use the values saved in the checkpoint, pass in your own values here: .. testcode:: class LitModel(LightningModule): def __init__(self, in_dim, out_dim): - self.l1 = nn.Linear(in_dim, out_dim) + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.l1 = nn.Linear(self.in_dim, self.out_dim) you can restore the model like this .. code-block:: python + # if you train and save the model like this, it will use these values when loading + # the weights.
But you can overwrite this + LitModel(in_dim=32, out_dim=10) + + # uses in_dim=32, out_dim=10 + model = LitModel.load_from_checkpoint(PATH) + + # uses in_dim=128, out_dim=10 model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 42a0a936d9e34..4371c869450a3 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -148,10 +148,24 @@ class TransferLearningModel(pl.LightningModule): dl_path: Path where the data will be downloaded """ def __init__(self, - hparams: argparse.Namespace, - dl_path: Union[str, Path]) -> None: + dl_path: Union[str, Path], + backbone: str = 'resnet50', + train_bn: bool = True, + milestones: tuple = (5, 10), + batch_size: int = 8, + lr: float = 1e-2, + lr_scheduler_gamma: float = 1e-1, + num_workers: int = 6, **kwargs) -> None: super().__init__() - self.hparams = hparams + self.dl_path = dl_path + self.backbone = backbone + self.train_bn = train_bn + self.milestones = milestones + self.batch_size = batch_size + self.lr = lr + self.lr_scheduler_gamma = lr_scheduler_gamma + self.num_workers = num_workers + self.dl_path = dl_path self.__build_model() @@ -159,12 +173,12 @@ def __build_model(self): """Define model layers & loss.""" # 1. Load pre-trained network: - model_func = getattr(models, self.hparams.backbone) + model_func = getattr(models, self.backbone) backbone = model_func(pretrained=True) _layers = list(backbone.children())[:-1] self.feature_extractor = torch.nn.Sequential(*_layers) - freeze(module=self.feature_extractor, train_bn=self.hparams.train_bn) + freeze(module=self.feature_extractor, train_bn=self.train_bn) # 2. 
Classifier: _fc_layers = [torch.nn.Linear(2048, 256), @@ -194,29 +208,29 @@ def train(self, mode=True): super().train(mode=mode) epoch = self.current_epoch - if epoch < self.hparams.milestones[0] and mode: + if epoch < self.milestones[0] and mode: # feature extractor is frozen (except for BatchNorm layers) freeze(module=self.feature_extractor, - train_bn=self.hparams.train_bn) + train_bn=self.train_bn) - elif self.hparams.milestones[0] <= epoch < self.hparams.milestones[1] and mode: + elif self.milestones[0] <= epoch < self.milestones[1] and mode: # Unfreeze last two layers of the feature extractor freeze(module=self.feature_extractor, n=-2, - train_bn=self.hparams.train_bn) + train_bn=self.train_bn) def on_epoch_start(self): """Use `on_epoch_start` to unfreeze layers progressively.""" optimizer = self.trainer.optimizers[0] - if self.current_epoch == self.hparams.milestones[0]: + if self.current_epoch == self.milestones[0]: _unfreeze_and_add_param_group(module=self.feature_extractor[-2:], optimizer=optimizer, - train_bn=self.hparams.train_bn) + train_bn=self.train_bn) - elif self.current_epoch == self.hparams.milestones[1]: + elif self.current_epoch == self.milestones[1]: _unfreeze_and_add_param_group(module=self.feature_extractor[:-2], optimizer=optimizer, - train_bn=self.hparams.train_bn) + train_bn=self.train_bn) def training_step(self, batch, batch_idx): @@ -246,7 +260,7 @@ def training_epoch_end(self, outputs): for output in outputs]).mean() train_acc_mean = torch.stack([output['num_correct'] for output in outputs]).sum().float() - train_acc_mean /= (len(outputs) * self.hparams.batch_size) + train_acc_mean /= (len(outputs) * self.batch_size) return {'log': {'train_loss': train_loss_mean, 'train_acc': train_acc_mean, 'step': self.current_epoch}} @@ -273,7 +287,7 @@ def validation_epoch_end(self, outputs): for output in outputs]).mean() val_acc_mean = torch.stack([output['num_correct'] for output in outputs]).sum().float() - val_acc_mean /= (len(outputs) * self.hparams.batch_size) + val_acc_mean /= (len(outputs) * self.batch_size) return {'log': {'val_loss': val_loss_mean, 'val_acc': val_acc_mean, 'step': self.current_epoch}} @@ -281,11 +295,11 @@ def validation_epoch_end(self, outputs): def configure_optimizers(self): optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), - lr=self.hparams.lr) + lr=self.lr) scheduler = MultiStepLR(optimizer, - milestones=self.hparams.milestones, - gamma=self.hparams.lr_scheduler_gamma) + milestones=self.milestones, + gamma=self.lr_scheduler_gamma) return [optimizer], [scheduler] @@ -326,8 +340,8 @@ def __dataloader(self, train): _dataset = self.train_dataset if train else self.valid_dataset loader = DataLoader(dataset=_dataset, - batch_size=self.hparams.batch_size, - num_workers=self.hparams.num_workers, + batch_size=self.batch_size, + num_workers=self.num_workers, shuffle=True if train else False) return loader @@ -397,28 +411,28 @@ def add_model_specific_args(parent_parser): return parser -def main(hparams: argparse.Namespace) -> None: +def main(args: argparse.Namespace) -> None: """Train the model. Args: - hparams: Model hyper-parameters + args: Model hyper-parameters Note: For the sake of the example, the images dataset will be downloaded to a temporary directory. 
""" - with TemporaryDirectory(dir=hparams.root_data_path) as tmp_dir: + with TemporaryDirectory(dir=args.root_data_path) as tmp_dir: - model = TransferLearningModel(hparams, dl_path=tmp_dir) + model = TransferLearningModel(dl_path=tmp_dir, **vars(args)) trainer = pl.Trainer( weights_summary=None, show_progress_bar=True, num_sanity_val_steps=0, - gpus=hparams.gpus, - min_epochs=hparams.nb_epochs, - max_epochs=hparams.nb_epochs) + gpus=args.gpus, + min_epochs=args.nb_epochs, + max_epochs=args.nb_epochs) trainer.fit(model) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 99a57f1a0b96a..23049358395b9 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -72,13 +72,22 @@ def forward(self, img): class GAN(LightningModule): - def __init__(self, hparams): + def __init__(self, + latent_dim: int = 100, + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + batch_size: int = 64, **kwargs): super().__init__() - self.hparams = hparams + self.latent_dim = latent_dim + self.lr = lr + self.b1 = b1 + self.b2 = b2 + self.batch_size = batch_size # networks mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=hparams.latent_dim, img_shape=mnist_shape) + self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape) self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images @@ -98,7 +107,7 @@ def training_step(self, batch, batch_idx, optimizer_idx): # train generator if optimizer_idx == 0: # sample noise - z = torch.randn(imgs.shape[0], self.hparams.latent_dim) + z = torch.randn(imgs.shape[0], self.latent_dim) z = z.type_as(imgs) # generate images @@ -152,9 +161,9 @@ def training_step(self, batch, batch_idx, optimizer_idx): return output def configure_optimizers(self): - lr = self.hparams.lr - b1 = self.hparams.b1 - b2 = self.hparams.b2 + lr = self.lr + b1 = self.b1 + b2 = self.b2 opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) @@ -164,10 +173,10 @@ def train_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) - return DataLoader(dataset, batch_size=self.hparams.batch_size) + return DataLoader(dataset, batch_size=self.batch_size) def on_epoch_end(self): - z = torch.randn(8, self.hparams.latent_dim) + z = torch.randn(8, self.latent_dim) z = z.type_as(self.last_imgs) # log sampled images diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index c274cec90ddbb..e6584d76554fb 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -29,13 +29,26 @@ class ImageNetLightningModel(LightningModule): - def __init__(self, hparams): + def __init__(self, + arch, + pretrained, + lr: float, + momentum: float, + weight_decay: int, + data_path: str, + batch_size: int, **kwargs): """ TODO: add docstring here """ super().__init__() - self.hparams = hparams - self.model = models.__dict__[self.hparams.arch](pretrained=self.hparams.pretrained) + self.arch = arch + self.pretrained = pretrained + self.lr = lr + self.momentum = momentum + self.weight_decay = weight_decay + self.data_path = data_path + self.batch_size = batch_size + self.model = 
models.__dict__[self.arch](pretrained=self.pretrained) def forward(self, x): return self.model(x) @@ -112,9 +125,9 @@ def __accuracy(cls, output, target, topk=(1,)): def configure_optimizers(self): optimizer = optim.SGD( self.parameters(), - lr=self.hparams.lr, - momentum=self.hparams.momentum, - weight_decay=self.hparams.weight_decay + lr=self.lr, + momentum=self.momentum, + weight_decay=self.weight_decay ) scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1) return [optimizer], [scheduler] @@ -125,7 +138,7 @@ def train_dataloader(self): std=[0.229, 0.224, 0.225], ) - train_dir = os.path.join(self.hparams.data_path, 'train') + train_dir = os.path.join(self.data_path, 'train') train_dataset = datasets.ImageFolder( train_dir, transforms.Compose([ @@ -142,7 +155,7 @@ def train_dataloader(self): train_loader = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=(train_sampler is None), num_workers=0, sampler=train_sampler @@ -154,7 +167,7 @@ def val_dataloader(self): mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ) - val_dir = os.path.join(self.hparams.data_path, 'val') + val_dir = os.path.join(self.data_path, 'val') val_loader = torch.utils.data.DataLoader( datasets.ImageFolder(val_dir, transforms.Compose([ transforms.Resize(256), @@ -162,7 +175,7 @@ def val_dataloader(self): transforms.ToTensor(), normalize, ])), - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=False, num_workers=0, ) diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index ff3f634da7817..95d6873a444db 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -190,22 +190,41 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): """ Basic DQN Model """ - def __init__(self, hparams: argparse.Namespace) -> None: + def __init__(self, + replay_size, + warm_start_steps: int, + gamma: float, + eps_start: int, + eps_end: int, + eps_last_frame: int, + sync_rate, + lr: float, + episode_length, + batch_size, **kwargs) -> None: super().__init__() - self.hparams = hparams - - self.env = gym.make(self.hparams.env) + self.replay_size = replay_size + self.warm_start_steps = warm_start_steps + self.gamma = gamma + self.eps_start = eps_start + self.eps_end = eps_end + self.eps_last_frame = eps_last_frame + self.sync_rate = sync_rate + self.lr = lr + self.episode_length = episode_length + self.batch_size = batch_size + + self.env = gym.make(self.env) obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n self.net = DQN(obs_size, n_actions) self.target_net = DQN(obs_size, n_actions) - self.buffer = ReplayBuffer(self.hparams.replay_size) + self.buffer = ReplayBuffer(self.replay_size) self.agent = Agent(self.env, self.buffer) self.total_reward = 0 self.episode_reward = 0 - self.populate(self.hparams.warm_start_steps) + self.populate(self.warm_start_steps) def populate(self, steps: int = 1000) -> None: """ @@ -250,7 +269,7 @@ def dqn_mse_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor next_state_values[dones] = 0.0 next_state_values = next_state_values.detach() - expected_state_action_values = next_state_values * self.hparams.gamma + rewards + expected_state_action_values = next_state_values * self.gamma + rewards return nn.MSELoss()(state_action_values, 
expected_state_action_values) @@ -267,8 +286,8 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O Training loss and log metrics """ device = self.get_device(batch) - epsilon = max(self.hparams.eps_end, self.hparams.eps_start - - self.global_step + 1 / self.hparams.eps_last_frame) + epsilon = max(self.eps_end, self.eps_start - + self.global_step + 1 / self.eps_last_frame) # step through environment with agent reward, done = self.agent.play_step(self.net, epsilon, device) @@ -282,7 +301,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O self.episode_reward = 0 # Soft update of target network - if self.global_step % self.hparams.sync_rate == 0: + if self.global_step % self.sync_rate == 0: self.target_net.load_state_dict(self.net.state_dict()) log = {'total_reward': torch.tensor(self.total_reward).to(device), @@ -293,16 +312,17 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O def configure_optimizers(self) -> List[Optimizer]: """Initialize Adam optimizer""" - optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr) + optimizer = optim.Adam(self.net.parameters(), lr=self.lr) return [optimizer] def __dataloader(self) -> DataLoader: """Initialize the Replay Buffer dataset used for retrieving experiences""" - dataset = RLDataset(self.buffer, self.hparams.episode_length) - dataloader = DataLoader(dataset=dataset, - batch_size=self.hparams.batch_size, - sampler=None - ) + dataset = RLDataset(self.buffer, self.episode_length) + dataloader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + sampler=None, + ) return dataloader def train_dataloader(self) -> DataLoader: @@ -314,8 +334,8 @@ def get_device(self, batch) -> str: return batch[0].device.index if self.on_gpu else 'cpu' -def main(hparams) -> None: - model = DQNLightning(hparams) +def main(args) -> None: + model = DQNLightning(**vars(args)) trainer = pl.Trainer( gpus=1, diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 9d98c799a7283..2f486c5b81827 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -1,5 +1,5 @@ import os -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import numpy as np import torch @@ -128,14 +128,23 @@ class SegModel(pl.LightningModule): Adam optimizer is used along with Cosine Annealing learning rate scheduler. 
""" - def __init__(self, hparams): + def __init__(self, + data_path: str, + batch_size: int, + lr: float, + num_layers: int, + features_start: int, + bilinear: bool, **kwargs): super().__init__() - self.hparams = hparams - self.data_path = hparams.data_path - self.batch_size = hparams.batch_size - self.learning_rate = hparams.lr - self.net = UNet(num_classes=19, num_layers=hparams.num_layers, - features_start=hparams.features_start, bilinear=hparams.bilinear) + self.data_path = data_path + self.batch_size = batch_size + self.lr = lr + self.num_layers = num_layers + self.features_start = features_start + self.bilinear = bilinear + + self.net = UNet(num_classes=19, num_layers=self.num_layers, + features_start=self.features_start, bilinear=self.bilinear) self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], @@ -181,11 +190,11 @@ def val_dataloader(self): return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False) -def main(hparams): +def main(hparams: Namespace): # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - model = SegModel(hparams) + model = SegModel(**vars(hparams)) # ------------------------ # 2 SET LOGGER diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py index 13b3bc67a912b..b309094254118 100644 --- a/pl_examples/models/lightning_template.py +++ b/pl_examples/models/lightning_template.py @@ -3,7 +3,6 @@ """ import os from argparse import ArgumentParser -from collections import OrderedDict import torch import torch.nn as nn @@ -34,25 +33,38 @@ class LightningTemplateModel(LightningModule): ... out_features=10, ... hidden_dim=1000, ... ) - >>> from argparse import Namespace - >>> hparams = Namespace(**params) - >>> model = LightningTemplateModel(hparams) + >>> model = LightningTemplateModel(**params) """ - def __init__(self, hparams): - """ - Pass in hyperparameters as a `argparse.Namespace` or a `dict` to the model. - """ + def __init__(self, + drop_prob: float = 0.2, + batch_size: int = 2, + in_features: int = 28 * 28, + learning_rate: float = 0.001 * 8, + optimizer_name: str = 'adam', + data_root: str = './datasets', + out_features: int = 10, + hidden_dim: int = 1000, + **kwargs + ) -> 'LightningTemplateModel': # init superclass super().__init__() - self.hparams = hparams - self.c_d1 = nn.Linear(in_features=self.hparams.in_features, - out_features=self.hparams.hidden_dim) - self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) - self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) - - self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, - out_features=self.hparams.out_features) + self.drop_prob = drop_prob + self.batch_size = batch_size + self.in_features = in_features + self.learning_rate = learning_rate + self.optimizer_name = optimizer_name + self.data_root = data_root + self.out_features = out_features + self.hidden_dim = hidden_dim + + self.c_d1 = nn.Linear(in_features=self.in_features, + out_features=self.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim) + self.c_d1_drop = nn.Dropout(self.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hidden_dim, + out_features=self.out_features) def forward(self, x): """ @@ -122,32 +134,32 @@ def configure_optimizers(self): Return whatever optimizers and learning rate schedulers you want here. At least one optimizer is required. 
""" - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) return [optimizer], [scheduler] def prepare_data(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) - self.mnist_train = MNIST(self.hparams.data_root, train=True, download=True, transform=transform) - self.mnist_test = MNIST(self.hparams.data_root, train=False, download=True, transform=transform) + self.mnist_train = MNIST(self.data_root, train=True, download=True, transform=transform) + self.mnist_test = MNIST(self.data_root, train=False, download=True, transform=transform) def train_dataloader(self): log.info('Training data loader called.') - return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size, num_workers=4) + return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4) def val_dataloader(self): log.info('Validation data loader called.') - return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=4) + return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4) def test_dataloader(self): log.info('Test data loader called.') - return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size, num_workers=4) + return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4) @staticmethod def add_model_specific_args(parent_parser, root_dir): # pragma: no-cover """ - Parameters you define here will be available to your model through `self.hparams`. + Define parameters that only apply to this model """ parser = ArgumentParser(parents=[parent_parser]) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 662c05a29338d..9eb3c5aa1288f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -17,7 +17,7 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams +from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -30,6 +30,8 @@ else: XLA_AVAILABLE = True +CHECKPOINT_KEY_MODULE_ARGS = 'module_arguments' + class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks): @@ -62,17 +64,21 @@ def __init__(self, *args, **kwargs): #: True if using ddp2 self.use_ddp2 = False + # True if on tpu + self.use_tpu = False + #: True if using amp self.use_amp = False - self.hparams = None - #: Current dtype self._dtype = torch.float #: device reference self._device = torch.device('cpu') + # register all params passed into the child module in __init__ + self._auto_collect_arguments() + @property def on_gpu(self): """ @@ -1158,7 +1164,7 @@ def optimizer_step(self, current_epoch, batch_idx, optimizer, if self.trainer.global_step < 500: lr_scale = min(1., float(self.trainer.global_step + 1) / 500.) 
for pg in optimizer.param_groups: - pg['lr'] = lr_scale * self.hparams.learning_rate + pg['lr'] = lr_scale * self.learning_rate # update params optimizer.step() @@ -1312,7 +1318,7 @@ def train_dataloader(self): download=True) loader = torch.utils.data.DataLoader( dataset=dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=True ) return loader @@ -1363,7 +1369,7 @@ def test_dataloader(self): download=True) loader = torch.utils.data.DataLoader( dataset=dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=False ) @@ -1408,7 +1414,7 @@ def val_dataloader(self): transform=transform, download=True) loader = torch.utils.data.DataLoader( dataset=dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, shuffle=False ) @@ -1448,46 +1454,13 @@ def load_from_checkpoint( map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, hparams_file: Optional[str] = None, tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0 - hparam_overrides: Optional[Dict] = None, **kwargs ) -> 'LightningModule': r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint - it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule` - with an argument called ``hparams`` which is an object of :class:`~dict` or - :class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args` - when parsing command line arguments). - If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`. - Any other arguments specified through \*args and \*\*kwargs will be passed to the model. - - Example: - .. code-block:: python - - # define hparams as Namespace - from argparse import Namespace - hparams = Namespace(**{'learning_rate': 0.1}) - - model = MyModel(hparams) - - class MyModel(LightningModule): - def __init__(self, hparams: Namespace): - self.learning_rate = hparams.learning_rate - - # ---------- - - # define hparams as dict - hparams = { - drop_prob: 0.2, - dataloader: { - batch_size: 32 - } - } - - model = MyModel(hparams) + it stores the arguments passed to `__init__` in the checkpoint under `module_arguments` - class MyModel(LightningModule): - def __init__(self, hparams: dict): - self.learning_rate = hparams['learning_rate'] + Any arguments specified through \*args and \*\*kwargs will override args stored in `module_arguments`. Args: checkpoint_path: Path to checkpoint. 
@@ -1556,15 +1529,8 @@ def __init__(self, hparams: dict): # override some of the params with new values MyLightningModule.load_from_checkpoint( PATH, - hparam_overrides={'num_layers': 128, 'pretrained_ckpt_path': NEW_PATH} - ) - - # or load passing whatever args the model takes to load - MyLightningModule.load_from_checkpoint( - 'path/to/checkpoint.ckpt', - learning_rate=0.1, # These arguments will be passed to the model using **kwargs - layers=2, - pretrained_model=some_model + num_layers=128, + pretrained_ckpt_path=NEW_PATH, ) # predict @@ -1594,46 +1560,23 @@ def __init__(self, hparams: dict): hparams['on_gpu'] = False # overwrite hparams by the given file - checkpoint['hparams'] = hparams + checkpoint[CHECKPOINT_KEY_MODULE_ARGS] = hparams - # override the hparam keys that were passed in - if hparam_overrides is not None: - update_hparams(hparams, hparam_overrides) + # override the module_arguments with values that were passed in + checkpoint[CHECKPOINT_KEY_MODULE_ARGS].update(kwargs) model = cls._load_model_state(checkpoint, *args, **kwargs) return model @classmethod def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs) -> 'LightningModule': - cls_takes_hparams = 'hparams' in inspect.signature(cls.__init__).parameters - ckpt_hparams = checkpoint.get('hparams') - - if cls_takes_hparams: - if ckpt_hparams is not None: - hparams_type = checkpoint.get('hparams_type', 'Namespace') - if hparams_type.lower() == 'dict': - hparams = ckpt_hparams - elif hparams_type.lower() == 'namespace': - hparams = Namespace(**ckpt_hparams) - else: - rank_zero_warn( - f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__" - " contains argument 'hparams'. Will pass in an empty Namespace instead." - " Did you forget to store your model hyperparameters in self.hparams?" - ) - hparams = {} - else: # The user's LightningModule does not define a hparams argument - if ckpt_hparams is None: - hparams = None - else: - raise MisconfigurationException( - f"Checkpoint contains hyperparameters but {cls.__name__}'s __init__ " - f"is missing the argument 'hparams'. Are you loading the correct checkpoint?"
- ) + + # pass in the values we saved automatically + if CHECKPOINT_KEY_MODULE_ARGS in checkpoint: + model_args = checkpoint[CHECKPOINT_KEY_MODULE_ARGS] + kwargs.update(**model_args) # load the state_dict on the model automatically - if cls_takes_hparams: - kwargs.update(hparams=hparams) model = cls(*args, **kwargs) model.load_state_dict(checkpoint['state_dict']) @@ -1757,3 +1700,49 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: rank_zero_warn("`get_tqdm_dict` was renamed to `get_progress_bar_dict` in v0.7.3" " and this method will be removed in v1.0.0", DeprecationWarning) return self.get_progress_bar_dict() + + def _auto_collect_arguments(self): + """Collect all arguments module arguments.""" + frame = inspect.currentframe() + + frame_args = _collect_init_args(frame.f_back, []) + child = _get_latest_child(frame) + + # set module_arguments in child + child._module_self_arguments = frame_args[-1] + child._module_parents_arguments = {} + for args in frame_args[:-1]: + child._module_parents_arguments.update(args) + + @property + def module_arguments(self) -> dict: + """Aggregate this module and all parents arguments.""" + args = dict(self._module_parents_arguments) + args.update(self._module_self_arguments) + return args + + +def _collect_init_args(frame, path_args: list) -> list: + """Recursive search for all children.""" + if '__class__' in frame.f_locals: + local_args = dict(frame.f_locals) + local_args.update(local_args.get('kwargs', {})) + local_args = {k: v for k, v in local_args.items() + if k not in ('args', 'kwargs', 'self', '__class__', 'frame', 'frame_args')} + # if 'hparams' in local_args: + # # back compatible hparams as single argument + # hparams = local_args.get('hparams') + # local_args.update(vars(hparams) if isinstance(hparams, Namespace) else hparams) + # recursive update + path_args.append(local_args) + return _collect_init_args(frame.f_back, path_args) + else: + return path_args + + +def _get_latest_child(frame, child: object = None) -> object: + """Recursive search for lowest child.""" + if 'self' in frame.f_locals: + return _get_latest_child(frame.f_back, frame.f_locals['self']) + else: + return child diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 683baccafa858..fc684d143e4b8 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -98,8 +98,7 @@ from pytorch_lightning.profiler import Profiler, PassThroughProfiler class MyModel(LightningModule): - def __init__(self, hparams, profiler=None): - self.hparams = hparams + def __init__(self, profiler=None): self.profiler = profiler or PassThroughProfiler() def custom_processing_step(self, data): @@ -108,7 +107,7 @@ def custom_processing_step(self, data): return data profiler = Profiler() - model = MyModel(hparams, profiler) + model = MyModel(profiler) trainer = Trainer(profiler=profiler, max_epochs=1) """ diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 6a76523f17fb4..abc71ece2ac3d 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -2,7 +2,7 @@ Trainer Learning Rate Finder """ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Sequence import numpy as np import torch @@ -20,6 +20,8 @@ class TrainerLRFinderMixin(ABC): + default_root_dir: str + @abstractmethod def save_checkpoint(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @@ -35,17 
+37,17 @@ def _run_lr_finder_internally(self, model: LightningModule): # TODO: log lr.results to self.logger if isinstance(self.auto_lr_find, str): # Try to find requested field, may be nested - if _nested_hasattr(model.hparams, self.auto_lr_find): - _nested_setattr(model.hparams, self.auto_lr_find, lr) + if _nested_hasattr(model, self.auto_lr_find): + _nested_setattr(model, self.auto_lr_find, lr) else: raise MisconfigurationException( f'`auto_lr_find` was set to {self.auto_lr_find}, however' ' could not find this as a field in `model.hparams`.') else: - if hasattr(model.hparams, 'lr'): - model.hparams.lr = lr - elif hasattr(model.hparams, 'learning_rate'): - model.hparams.learning_rate = lr + if hasattr(model, 'lr'): + model.lr = lr + elif hasattr(model, 'learning_rate'): + model.learning_rate = lr else: raise MisconfigurationException( 'When auto_lr_find is set to True, expects that hparams' @@ -350,7 +352,7 @@ class _LRCallback(Callback): """ def __init__(self, num_training: int, early_stop_threshold: float = 4.0, - progress_bar_refresh_rate: bool = False, + progress_bar_refresh_rate: int = 0, beta: float = 0.98): self.num_training = num_training self.early_stop_threshold = early_stop_threshold @@ -414,6 +416,8 @@ class _LinearLR(_LRScheduler): last_epoch: the index of last epoch. Default: -1. """ + last_epoch: int + base_lrs: Sequence def __init__(self, optimizer: torch.optim.Optimizer, @@ -454,6 +458,8 @@ class _ExponentialLR(_LRScheduler): last_epoch: the index of last epoch. Default: -1. """ + last_epoch: int + base_lrs: Sequence def __init__(self, optimizer: torch.optim.Optimizer, diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e4eae0dedf143..3f51a168add75 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -35,6 +35,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn, parsing + try: from apex import amp except ImportError: @@ -288,7 +289,7 @@ def __init__( auto_lr_find: If set to True, will `initially` run a learning rate finder, trying to optimize initial learning for faster convergence. Sets learning - rate in self.hparams.lr | self.hparams.learning_rate in the lightning module. + rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string instead of True with the key name. replace_sampler_ddp: Explicitly enables or disables sampler replacement. @@ -303,7 +304,7 @@ def __init__( auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying to find the largest batch size that fits into memory. - The result will be stored in self.hparams.batch_size in the LightningModule. + The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either `power` that estimates the batch size through a power search or `binsearch` that estimates the batch size through a binary search. 
""" @@ -951,8 +952,7 @@ def run_pretrain_routine(self, model: LightningModule): # log hyper-parameters if self.logger is not None: # save exp to get started - if hasattr(ref_model, "hparams"): - self.logger.log_hyperparams(ref_model.hparams) + self.logger.log_hyperparams(ref_model.module_arguments) self.logger.save() diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 11771b21961ed..d036f5d9f670b 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -84,6 +84,7 @@ """ import os +import pickle import re import signal from abc import ABC @@ -95,7 +96,7 @@ import torch.distributed as torch_distrib from pytorch_lightning import _logger as log -from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.lightning import LightningModule, CHECKPOINT_KEY_MODULE_ARGS from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, @@ -119,6 +120,12 @@ else: HOROVOD_AVAILABLE = True +PRIMITIVE_TYPES = ( + bool, int, float, str, + list, tuple, set, dict, + Namespace, # for back compatibility +) + class TrainerIOMixin(ABC): @@ -141,6 +148,9 @@ class TrainerIOMixin(ABC): on_tpu: bool num_training_batches: int accumulate_grad_batches: int + use_amp: bool + use_native_amp: bool + scaler: ... def get_model(self): is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, @@ -263,12 +273,11 @@ def save_checkpoint(self, filepath, weights_only: bool = False): # do the actual save try: self._atomic_save(checkpoint, filepath) - except AttributeError as e: - if 'hparams' in checkpoint: - del checkpoint['hparams'] - rank_zero_warn('warning, `hparams` dropped from checkpoint.' - f' An attribute is not picklable {e}') - + except AttributeError as err: + if CHECKPOINT_KEY_MODULE_ARGS in checkpoint: + del checkpoint[CHECKPOINT_KEY_MODULE_ARGS] + rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.' + f' An attribute is not picklable {err}') self._atomic_save(checkpoint, filepath) def restore(self, checkpoint_path: str, on_gpu: bool): @@ -306,7 +315,15 @@ def restore(self, checkpoint_path: str, on_gpu: bool): # load training state (affects trainer only) self.restore_training_state(checkpoint) - def dump_checkpoint(self, weights_only: bool = False): + def dump_checkpoint(self, weights_only: bool = False) -> dict: + """Creating model checkpoint. 
+ + Args: + weights_only: saving model weights only + + Return: + structured dictionary + """ checkpoint = { 'epoch': self.current_epoch + 1, 'global_step': self.global_step + 1, @@ -338,28 +355,15 @@ def dump_checkpoint(self, weights_only: bool = False): if self.use_amp and self.use_native_amp: checkpoint['native_amp_scaling_state'] = self.scaler.state_dict() - # add the hparams and state_dict from the model + # add the module_arguments and state_dict from the model model = self.get_model() checkpoint['state_dict'] = model.state_dict() - if hasattr(model, "hparams") and model.hparams is not None: - parsing.clean_namespace(model.hparams) - if isinstance(model.hparams, dict): - checkpoint['hparams_type'] = 'dict' - checkpoint['hparams'] = model.hparams - elif isinstance(model.hparams, Namespace): - checkpoint['hparams_type'] = 'Namespace' - checkpoint['hparams'] = vars(model.hparams) - else: - raise ValueError( - 'The acceptable hparams type is dict or argparse.Namespace,', - f' not {checkpoint["hparams_type"]}' - ) - else: - rank_zero_warn( - "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters." - ) + if hasattr(model, CHECKPOINT_KEY_MODULE_ARGS) and model.module_arguments: + # add arguments to the checkpoint + checkpoint[CHECKPOINT_KEY_MODULE_ARGS] = {k: v for k, v in model.module_arguments.items() + if isinstance(v, PRIMITIVE_TYPES)} # give the model a chance to add a few things model.on_save_checkpoint(checkpoint) @@ -463,12 +467,11 @@ def hpc_save(self, folderpath: str, logger): # TODO: fix for anything with multiprocess DP, DDP, DDP2 try: self._atomic_save(checkpoint, filepath) - except AttributeError as e: - if 'hparams' in checkpoint: - del checkpoint['hparams'] - rank_zero_warn('warning, `hparams` dropped from checkpoint.' - f' An attribute is not picklable {e}') - + except AttributeError as err: + if CHECKPOINT_KEY_MODULE_ARGS in checkpoint: + del checkpoint[CHECKPOINT_KEY_MODULE_ARGS] + rank_zero_warn('warning, `module_arguments` dropped from checkpoint.' + f' An attribute is not picklable {err}') self._atomic_save(checkpoint, filepath) return filepath diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 2a9adaf568f90..d4a3b3864eb08 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -25,7 +25,9 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class gradient_clip_val: ... - precision: ... + precision: int + default_root_dir: str + progress_bar_callback: ... 
on_gpu: bool @abstractmethod @@ -133,7 +135,7 @@ def scale_batch_size(self, algorithm is terminated """ - if not hasattr(model.hparams, batch_arg_name): + if not hasattr(model, batch_arg_name): raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`') if hasattr(model.train_dataloader, 'patch_loader_code'): @@ -243,9 +245,9 @@ def _adjust_batch_size(trainer, """ model = trainer.get_model() - batch_size = getattr(model.hparams, batch_arg_name) + batch_size = getattr(model, batch_arg_name) if value: - setattr(model.hparams, batch_arg_name, value) + setattr(model, batch_arg_name, value) new_size = value if desc: log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') @@ -253,7 +255,7 @@ def _adjust_batch_size(trainer, new_size = int(batch_size * factor) if desc: log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}') - setattr(model.hparams, batch_arg_name, new_size) + setattr(model, batch_arg_name, new_size) return new_size diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 394ee69daee81..6386d925bdb13 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -4,12 +4,13 @@ class ConfigureOptimizersPool(ABC): + def configure_optimizers(self): """ return whatever optimizers we want here. :return: list of optimizers """ - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer def configure_optimizers__empty(self): @@ -20,7 +21,7 @@ def configure_optimizers__lbfgs(self): return whatever optimizers we want here. :return: list of optimizers """ - optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate) + optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate) return optimizer def configure_optimizers__multiple_optimizers(self): @@ -29,26 +30,26 @@ def configure_optimizers__multiple_optimizers(self): :return: list of optimizers """ # try no scheduler for this model (testing purposes) - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) + optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer1, optimizer2 def configure_optimizers__single_scheduler(self): - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) return [optimizer], [lr_scheduler] def configure_optimizers__multiple_schedulers(self): - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) + optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1) lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] def configure_optimizers__mixed_scheduling(self): - optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) + optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) 
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1) lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1) @@ -56,14 +57,14 @@ def configure_optimizers__mixed_scheduling(self): [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2] def configure_optimizers__reduce_lr_on_plateau(self): - optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) return [optimizer], [lr_scheduler] def configure_optimizers__param_groups(self): param_groups = [ - {'params': list(self.parameters())[:2], 'lr': self.hparams.learning_rate * 0.1}, - {'params': list(self.parameters())[2:], 'lr': self.hparams.learning_rate} + {'params': list(self.parameters())[:2], 'lr': self.learning_rate * 0.1}, + {'params': list(self.parameters())[2:], 'lr': self.learning_rate} ] optimizer = optim.Adam(param_groups) diff --git a/tests/base/model_template.py b/tests/base/model_template.py index d530fa4a97b12..782ace16193d3 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -37,16 +37,36 @@ class EvalModelTemplate( >>> model = EvalModelTemplate() """ - def __init__(self, hparams: object = None) -> object: - """Pass in parsed HyperOptArgumentParser to the model.""" - if hparams is None: - hparams = EvalModelTemplate.get_default_hparams() + + def __init__(self, + *args, + drop_prob: float = 0.2, + batch_size: int = 32, + in_features: int = 28 * 28, + learning_rate: float = 0.001 * 8, + optimizer_name: str = 'adam', + data_root: str = PATH_DATASETS, + out_features: int = 10, + hidden_dim: int = 1000, + b1: float = 0.5, + b2: float = 0.999, + **kwargs) -> object: # init superclass super().__init__() - self.hparams = Namespace(**hparams) if isinstance(hparams, dict) else hparams + self.drop_prob = drop_prob + self.batch_size = batch_size + self.in_features = in_features + self.learning_rate = learning_rate + self.optimizer_name = optimizer_name + self.data_root = data_root + self.out_features = out_features + self.hidden_dim = hidden_dim + self.b1 = b1 + self.b2 = b2 # if you specify an example input, the summary will show input/output for each layer - self.example_input_array = torch.rand(5, 28 * 28) + # TODO: to be fixed in #1773 + # self.example_input_array = torch.rand(5, 28 * 28) # build model self.__build_model() @@ -57,15 +77,15 @@ def __build_model(self): :return: """ self.c_d1 = nn.Linear( - in_features=self.hparams.in_features, - out_features=self.hparams.hidden_dim + in_features=self.in_features, + out_features=self.hidden_dim ) - self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) - self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) + self.c_d1_bn = nn.BatchNorm1d(self.hidden_dim) + self.c_d1_drop = nn.Dropout(self.drop_prob) self.c_d2 = nn.Linear( - in_features=self.hparams.hidden_dim, - out_features=self.hparams.out_features + in_features=self.hidden_dim, + out_features=self.out_features ) def forward(self, x): @@ -84,10 +104,10 @@ def loss(self, labels, logits): return nll def prepare_data(self): - _ = TrialMNIST(root=self.hparams.data_root, train=True, download=True) + _ = TrialMNIST(root=self.data_root, train=True, download=True) @staticmethod - def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> Namespace: + def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict: args = dict( drop_prob=0.2, batch_size=32, @@ -107,5 +127,4 @@ def 
get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0 hpc_exp_number=hpc_exp_number, ) - hparams = Namespace(**args) - return hparams + return args diff --git a/tests/base/model_utilities.py b/tests/base/model_utilities.py index e1a40f95b804f..ce34b39b162f8 100644 --- a/tests/base/model_utilities.py +++ b/tests/base/model_utilities.py @@ -7,11 +7,11 @@ class ModelTemplateData: hparams: ... def dataloader(self, train): - dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True) + dataset = TrialMNIST(root=self.data_root, train=train, download=True) loader = DataLoader( dataset=dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, # test and valid shall not be shuffled shuffle=train, ) diff --git a/tests/base/models.py b/tests/base/models.py index fed694891c291..77deb0766b9b6 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -18,7 +18,7 @@ class Generator(nn.Module): - def __init__(self, latent_dim, img_shape): + def __init__(self, latent_dim: tuple, img_shape: tuple): super().__init__() self.img_shape = img_shape @@ -45,7 +45,7 @@ def forward(self, z): class Discriminator(nn.Module): - def __init__(self, img_shape): + def __init__(self, img_shape: tuple): super().__init__() self.model = nn.Sequential( @@ -67,13 +67,16 @@ def forward(self, img): class TestGAN(LightningModule): """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" - def __init__(self, hparams): + def __init__(self, hidden_dim, learning_rate, b1, b2, **kwargs): super().__init__() - self.hparams = hparams + self.hidden_dim = hidden_dim + self.learning_rate = learning_rate + self.b1 = b1 + self.b2 = b2 # networks mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=hparams.hidden_dim, img_shape=mnist_shape) + self.generator = Generator(latent_dim=self.hidden_dim, img_shape=mnist_shape) self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images @@ -93,7 +96,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): # train generator if optimizer_idx == 0: # sample noise - z = torch.randn(imgs.shape[0], self.hparams.hidden_dim) + z = torch.randn(imgs.shape[0], self.hidden_dim) z = z.type_as(imgs) # generate images @@ -128,8 +131,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): fake = torch.zeros(imgs.size(0), 1) fake = fake.type_as(fake) - fake_loss = self.adversarial_loss( - self.discriminator(self.generated_imgs.detach()), fake) + fake_loss = self.adversarial_loss(self.discriminator(self.generated_imgs.detach()), fake) # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 @@ -142,9 +144,9 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): return output def configure_optimizers(self): - lr = self.hparams.learning_rate - b1 = self.hparams.b1 - b2 = self.hparams.b2 + lr = self.learning_rate + b1 = self.b1 + b2 = self.b2 opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 52c03ada3bd62..57b0b537b4dca 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -13,7 +13,7 @@ def test_trainer_callback_system(tmpdir): """Test the callback system.""" hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) def _check_args(trainer, 
pl_module): assert isinstance(trainer, Trainer) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 4d0b869a5d398..e8c8ead2501c3 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -61,7 +61,7 @@ def version(self): def test_custom_logger(tmpdir): hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) logger = CustomLogger() @@ -80,7 +80,7 @@ def test_custom_logger(tmpdir): def test_multiple_loggers(tmpdir): hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) logger1 = CustomLogger() logger2 = CustomLogger() diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index f9eb4b9e5810e..d4195d28dbb7e 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -69,9 +69,9 @@ def test_lbfgs_cpu_model(tmpdir): ) hparams = EvalModelTemplate.get_default_hparams() - setattr(hparams, 'optimizer_name', 'lbfgs') - setattr(hparams, 'learning_rate', 0.002) - model = EvalModelTemplate(hparams) + hparams.update(optimizer_name='lbfgs', + learning_rate=0.002) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__lbfgs tutils.run_model_test_without_loggers(trainer_options, model, min_acc=0.5) @@ -272,8 +272,8 @@ def __len__(self): return 1 class BpttTestModel(EvalModelTemplate): - def __init__(self, hparams): - super().__init__(hparams) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.test_hidden = None def training_step(self, batch, batch_idx, hiddens): @@ -303,12 +303,14 @@ def train_dataloader(self): ) hparams = EvalModelTemplate.get_default_hparams() - hparams.batch_size = batch_size - hparams.in_features = truncated_bptt_steps - hparams.hidden_dim = truncated_bptt_steps - hparams.out_features = truncated_bptt_steps + hparams.update( + batch_size=batch_size, + in_features=truncated_bptt_steps, + hidden_dim=truncated_bptt_steps, + out_features=truncated_bptt_steps + ) - model = BpttTestModel(hparams) + model = BpttTestModel(**hparams) # fit model trainer = Trainer( diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index f75b0a1f1a582..4746a494543c9 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -65,7 +65,7 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): def test_cpu_slurm_save_load(tmpdir): """Verify model save/load/checkpoint on CPU.""" hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) # logger file to get meta logger = tutils.get_default_logger(tmpdir) @@ -112,7 +112,7 @@ def test_cpu_slurm_save_load(tmpdir): logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir), ) - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) # set the epoch start hook so we can predict before the model does the full training def assert_pred_same(): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 90a0468f576af..568a8eae437c2 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -1,6 +1,6 @@ import pytest +import torch -import tests.base.utils as tutils from pytorch_lightning import Trainer from tests.base import EvalModelTemplate @@ -27,3 +27,44 @@ def on_before_zero_grad(self, optimizer): model.on_before_zero_grad_called = 0 trainer.test(model) assert 0 == model.on_before_zero_grad_called + + +def test_training_epoch_end_metrics_collection(tmpdir): + 
""" Test that progress bar metrics also get collected at the end of an epoch. """ + num_epochs = 3 + + class CurrentModel(EvalModelTemplate): + + def training_step(self, *args, **kwargs): + output = super().training_step(*args, **kwargs) + output['progress_bar'].update({'step_metric': torch.tensor(-1)}) + output['progress_bar'].update({'shared_metric': 100}) + return output + + def training_epoch_end(self, outputs): + epoch = self.current_epoch + # both scalar tensors and Python numbers are accepted + return { + 'progress_bar': { + f'epoch_metric_{epoch}': torch.tensor(epoch), # add a new metric key every epoch + 'shared_metric': 111, + } + } + + model = CurrentModel() + trainer = Trainer( + max_epochs=num_epochs, + default_root_dir=tmpdir, + overfit_pct=0.1, + ) + result = trainer.fit(model) + assert result == 1 + metrics = trainer.progress_bar_dict + + # metrics added in training step should be unchanged by epoch end method + assert metrics['step_metric'] == -1 + # a metric shared in both methods gets overwritten by epoch_end + assert metrics['shared_metric'] == 111 + # metrics are kept after each epoch + for i in range(num_epochs): + assert metrics[f'epoch_metric_{i}'] == i diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 6117bc8a0e264..4e5fe0ef81552 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -143,8 +143,7 @@ def validation_step(self, batch, *args, **kwargs): @pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") def test_horovod_multi_optimizer(tmpdir): - hparams = EvalModelTemplate.get_default_hparams() - model = TestGAN(hparams) + model = TestGAN(**EvalModelTemplate.get_default_hparams()) trainer_options = dict( default_root_dir=str(tmpdir), diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py new file mode 100644 index 0000000000000..c97dc61233fb0 --- /dev/null +++ b/tests/models/test_hparams.py @@ -0,0 +1,69 @@ +import os + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.core.lightning import CHECKPOINT_KEY_MODULE_ARGS +from tests.base import EvalModelTemplate + + +class SubClassEvalModel(EvalModelTemplate): + any_other_loss = torch.nn.CrossEntropyLoss() + + def __init__(self, *args, subclass_arg=1200, **kwargs): + super().__init__(*args, **kwargs) + self.subclass_arg = subclass_arg + + +class SubSubClassEvalModel(SubClassEvalModel): + pass + + +class AggSubClassEvalModel(SubClassEvalModel): + + def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs): + super().__init__(*args, **kwargs) + self.my_loss = my_loss + + +@pytest.mark.parametrize("cls", [EvalModelTemplate, + SubClassEvalModel, + SubSubClassEvalModel, + AggSubClassEvalModel]) +def test_collect_init_arguments(tmpdir, cls): + """ Test that the model automatically saves the arguments passed into the constructor """ + extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {} + + model = cls(**extra_args) + assert model.batch_size == 32 + model = cls(batch_size=179, **extra_args) + assert model.batch_size == 179 + + if isinstance(model, SubClassEvalModel): + assert model.subclass_arg == 1200 + + if isinstance(model, AggSubClassEvalModel): + assert isinstance(model.my_loss, torch.nn.CosineEmbeddingLoss) + + # verify that the checkpoint saved the correct values + trainer = Trainer(max_steps=5, 
default_root_dir=tmpdir) + trainer.fit(model) + raw_checkpoint_path = os.listdir(trainer.checkpoint_callback.dirpath) + raw_checkpoint_path = [x for x in raw_checkpoint_path if '.ckpt' in x][0] + raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path) + + raw_checkpoint = torch.load(raw_checkpoint_path) + assert CHECKPOINT_KEY_MODULE_ARGS in raw_checkpoint + assert raw_checkpoint[CHECKPOINT_KEY_MODULE_ARGS]['batch_size'] == 179 + + # verify that model loads correctly + model = cls.load_from_checkpoint(raw_checkpoint_path) + assert model.batch_size == 179 + + if isinstance(model, AggSubClassEvalModel): + assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss) + + # verify that we can overwrite whatever we want + model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) + assert model.batch_size == 99 diff --git a/tests/models/test_module_hooks.py b/tests/models/test_module_hooks.py deleted file mode 100644 index 8b855ba4a70d7..0000000000000 --- a/tests/models/test_module_hooks.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from pytorch_lightning import Trainer -from tests.base import EvalModelTemplate - -import tests.base.utils as tutils - - -def test_training_epoch_end_metrics_collection(tmpdir): - """ Test that progress bar metrics also get collected at the end of an epoch. """ - num_epochs = 3 - - class CurrentModel(EvalModelTemplate): - - def training_step(self, *args, **kwargs): - output = super().training_step(*args, **kwargs) - output['progress_bar'].update({'step_metric': torch.tensor(-1)}) - output['progress_bar'].update({'shared_metric': 100}) - return output - - def training_epoch_end(self, outputs): - epoch = self.current_epoch - # both scalar tensors and Python numbers are accepted - return { - 'progress_bar': { - f'epoch_metric_{epoch}': torch.tensor(epoch), # add a new metric key every epoch - 'shared_metric': 111, - } - } - - model = CurrentModel() - trainer = Trainer( - max_epochs=num_epochs, - default_root_dir=tmpdir, - overfit_pct=0.1, - ) - result = trainer.fit(model) - assert result == 1 - metrics = trainer.progress_bar_dict - - # metrics added in training step should be unchanged by epoch end method - assert metrics['step_metric'] == -1 - # a metric shared in both methods gets overwritten by epoch_end - assert metrics['shared_metric'] == 111 - # metrics are kept after each epoch - for i in range(num_epochs): - assert metrics[f'epoch_metric_{i}'] == i diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 73e7362655d2c..cae58cc8faa8f 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -8,7 +8,6 @@ import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -104,7 +103,7 @@ def test_running_test_pretrained_model_cpu(tmpdir): def test_load_model_from_checkpoint(tmpdir): """Verify test() on pretrained model.""" hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) trainer_options = dict( progress_bar_refresh_rate=0, @@ -128,8 +127,8 @@ def test_load_model_from_checkpoint(tmpdir): pretrained_model = EvalModelTemplate.load_from_checkpoint(last_checkpoint) # test that hparams loaded correctly - for k, v in vars(hparams).items(): - assert getattr(pretrained_model.hparams, k) == v + for k, v in hparams.items(): + 
assert getattr(pretrained_model, k) == v # assert weights are the same for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()): @@ -146,7 +145,7 @@ def test_load_model_from_checkpoint(tmpdir): def test_dp_resume(tmpdir): """Make sure DP continues training correctly.""" hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) trainer_options = dict( max_epochs=1, @@ -204,7 +203,7 @@ def assert_good_acc(): tutils.run_prediction(dataloader, dp_model, dp=True) # new model - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.on_train_start = assert_good_acc # fit new model which should load hpc weights @@ -270,42 +269,8 @@ def test_model_saving_loading(tmpdir): assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1 -def test_load_model_with_missing_hparams(tmpdir): - trainer_options = dict( - progress_bar_refresh_rate=0, - max_epochs=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_top_k=-1), - logger=False, - default_root_dir=tmpdir, - ) - - # fit model - trainer = Trainer(**trainer_options) - - class CurrentModelWithoutHparams(EvalModelTemplate): - def __init__(self): - super().__init__() - - class CurrentModelUnusedHparams(EvalModelTemplate): - def __init__(self, hparams): - super().__init__() +def test_model_pickle(tmpdir): + import pickle - model = CurrentModelWithoutHparams() - trainer.fit(model) - last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1] - - # try to load a checkpoint that has hparams but model is missing hparams arg - with pytest.raises(MisconfigurationException, match=r".*__init__ is missing the argument 'hparams'.*"): - CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint) - - # create a checkpoint without hyperparameters - # if the model does not take a hparams argument, it should not throw an error - ckpt = torch.load(last_checkpoint) - del(ckpt['hparams']) - torch.save(ckpt, last_checkpoint) - CurrentModelWithoutHparams.load_from_checkpoint(last_checkpoint) - - # load checkpoint without hparams again - # warn if user's model has hparams argument - with pytest.warns(UserWarning, match=r".*Will pass in an empty Namespace instead."): - CurrentModelUnusedHparams.load_from_checkpoint(last_checkpoint) + model = EvalModelTemplate() + pickle.dumps(model) diff --git a/tests/trainer/test_checks.py b/tests/trainer/test_checks.py index 4d03035b460fd..c3106abc2a94f 100755 --- a/tests/trainer/test_checks.py +++ b/tests/trainer/test_checks.py @@ -19,12 +19,12 @@ def test_wrong_train_setting(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.train_dataloader = None trainer.fit(model) with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.training_step = None trainer.fit(model) @@ -53,19 +53,19 @@ def test_wrong_validation_settings(tmpdir): # check val_dataloader -> val_step with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.validation_step = None trainer.fit(model) # check val_dataloader + val_step -> val_epoch_end with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.validation_epoch_end = None 
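# Illustrative sketch, not part of the patch: the checks above construct the template
# model by unpacking the default-hparams dict into keyword arguments. The same pattern
# in isolation (assumes the tests/base package of this repo is importable):

from tests.base import EvalModelTemplate

hparams = EvalModelTemplate.get_default_hparams()  # now returns a plain dict, not a Namespace
hparams.update(batch_size=64)                      # tweak individual values with normal dict ops
model = EvalModelTemplate(**hparams)               # expand the dict into keyword arguments
assert model.batch_size == 64                      # each value is stored as a model attribute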
trainer.fit(model) # check val_step -> val_dataloader with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.val_dataloader = None trainer.fit(model) @@ -84,7 +84,7 @@ def test_wrong_test_settigs(tmpdir): # if have test_dataloader should have test_step # ---------------- with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_step = None trainer.fit(model) @@ -92,7 +92,7 @@ def test_wrong_test_settigs(tmpdir): # if have test_dataloader and test_step recommend test_epoch_end # ---------------- with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_epoch_end = None trainer.test(model) @@ -100,7 +100,7 @@ def test_wrong_test_settigs(tmpdir): # if have test_step and NO test_dataloader passed in tell user to pass test_dataloader # ---------------- with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_dataloader = LightningModule.test_dataloader trainer.test(model) @@ -108,7 +108,7 @@ def test_wrong_test_settigs(tmpdir): # if have test_dataloader and NO test_step tell user to implement test_step # ---------------- with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_dataloader = LightningModule.test_dataloader model.test_step = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) @@ -117,18 +117,7 @@ def test_wrong_test_settigs(tmpdir): # if have test_dataloader and test_step but no test_epoch_end warn user # ---------------- with pytest.warns(RuntimeWarning): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_dataloader = LightningModule.test_dataloader model.test_epoch_end = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) - - # ---------------- - # if we are just testing, no need for train_dataloader, train_step, val_dataloader, and val_step - # ---------------- - model = EvalModelTemplate(hparams) - model.test_dataloader = LightningModule.test_dataloader - model.train_dataloader = None - model.train_step = None - model.val_dataloader = None - model.val_step = None - trainer.test(model, test_dataloaders=model.dataloader(train=False)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f7a197708d0db..7c0cf0bf95b79 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -423,14 +423,14 @@ def train_dataloader(self): dataset = Subset(dataloader.dataset, range(size)) dataloader = DataLoader( dataset, - batch_size=self.hparams.batch_size, + batch_size=self.batch_size, drop_last=False, ) return dataloader hparams = EvalModelTemplate.get_default_hparams() - hparams.batch_size = batch_size - model = CurrentTestModel(hparams) + hparams['batch_size'] = batch_size + model = CurrentTestModel(**hparams) trainer = Trainer( max_epochs=1, diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index fe4894c3e49de..67dd6a6f3d0bd 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -78,8 +78,8 @@ def test_trainer_reset_correctly(tmpdir): def test_trainer_arg_bool(tmpdir): """ Test that setting trainer arg to bool works """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) - before_lr = 
hparams.learning_rate + model = EvalModelTemplate(**hparams) + before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( @@ -89,18 +89,17 @@ def test_trainer_arg_bool(tmpdir): ) trainer.fit(model) - after_lr = model.hparams.learning_rate + after_lr = model.learning_rate assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' def test_trainer_arg_str(tmpdir): """ Test that setting trainer arg to string works """ - hparams = EvalModelTemplate.get_default_hparams() - hparams.__dict__['my_fancy_lr'] = 1.0 # update with non-standard field - model = EvalModelTemplate(hparams) + model = EvalModelTemplate() + model.my_fancy_lr = 1.0 # update with non-standard field - before_lr = hparams.my_fancy_lr + before_lr = model.my_fancy_lr # logger file to get meta trainer = Trainer( default_save_path=tmpdir, @@ -109,7 +108,7 @@ def test_trainer_arg_str(tmpdir): ) trainer.fit(model) - after_lr = model.hparams.my_fancy_lr + after_lr = model.my_fancy_lr assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' @@ -118,18 +117,18 @@ def test_call_to_trainer_method(tmpdir): """ Test that directly calling the trainer method works """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) - before_lr = hparams.learning_rate + before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=5 + max_epochs=5, ) lrfinder = trainer.lr_find(model, mode='linear') after_lr = lrfinder.suggestion() - model.hparams.learning_rate = after_lr + model.learning_rate = after_lr trainer.fit(model) assert before_lr != after_lr, \ @@ -141,9 +140,9 @@ def test_accumulation_and_early_stopping(tmpdir): accumulation also works for this feature """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) - before_lr = hparams.learning_rate + before_lr = hparams.get('learning_rate') # logger file to get meta trainer = Trainer( default_save_path=tmpdir, @@ -165,12 +164,12 @@ def test_suggestion_parameters_work(tmpdir): """ Test that default skipping does not alter results in basic case """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=10 + max_epochs=10, ) lrfinder = trainer.lr_find(model) diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 06ea784111153..fcd07fbb77b19 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -10,7 +10,7 @@ def test_optimizer_with_scheduling(tmpdir): """ Verify that learning rate scheduling is working """ hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__single_scheduler # fit model @@ -23,7 +23,7 @@ def test_optimizer_with_scheduling(tmpdir): results = trainer.fit(model) assert results == 1 - init_lr = hparams.learning_rate + init_lr = hparams.get('learning_rate') adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups] assert len(trainer.lr_schedulers) == 1, \ @@ -41,7 +41,7 @@ def test_multi_optimizer_with_scheduling(tmpdir): """ Verify that learning rate scheduling is working """ hparams = 
EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__multiple_schedulers # fit model @@ -54,7 +54,7 @@ def test_multi_optimizer_with_scheduling(tmpdir): results = trainer.fit(model) assert results == 1 - init_lr = hparams.learning_rate + init_lr = hparams.get('learning_rate') adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] @@ -76,7 +76,7 @@ def test_multi_optimizer_with_scheduling(tmpdir): def test_multi_optimizer_with_scheduling_stepping(tmpdir): hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__multiple_schedulers # fit model @@ -89,7 +89,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir): results = trainer.fit(model) assert results == 1 - init_lr = hparams.learning_rate + init_lr = hparams.get('learning_rate') adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] @@ -115,7 +115,7 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir): def test_reduce_lr_on_plateau_scheduling(tmpdir): hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__reduce_lr_on_plateau # fit model @@ -205,7 +205,7 @@ def test_none_optimizer_warning(): def test_none_optimizer(tmpdir): hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.configure_optimizers = model.configure_optimizers__empty # fit model diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1c2c169191564..74e01fba7a37e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -6,12 +6,12 @@ import pytest import torch -import yaml import tests.base.utils as tutils from pytorch_lightning import Callback, LightningModule from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.core.lightning import CHECKPOINT_KEY_MODULE_ARGS from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.trainer.logging import TrainerLoggingMixin @@ -19,32 +19,6 @@ from tests.base import EvalModelTemplate -def test_model_pickle(tmpdir): - import pickle - - model = EvalModelTemplate() - pickle.dumps(model) - - -def test_hparams_save_load(tmpdir): - model = EvalModelTemplate(vars(EvalModelTemplate.get_default_hparams())) - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - ) - # fit model - result = trainer.fit(model) - assert result == 1 - - # try to load the model now - pretrained_model = tutils.load_model_from_checkpoint( - trainer.checkpoint_callback.dirpath, - module_class=EvalModelTemplate - ) - assert pretrained_model - - def test_no_val_module(tmpdir): """Tests use case where trainer saves the model, and user loads it from tags independently.""" @@ -69,7 +43,7 @@ def test_no_val_module(tmpdir): # assert ckpt has hparams ckpt = torch.load(new_weights_path) - assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints' + assert 
CHECKPOINT_KEY_MODULE_ARGS in ckpt.keys(), 'module_arguments missing from checkpoints' # load new model hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) @@ -349,7 +323,7 @@ def test_resume_from_checkpoint_epoch_restored(tmpdir): def _new_model(): # Create a model that tracks epochs and batches seen - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.num_epochs_seen = 0 model.num_batches_seen = 0 model.num_on_load_checkpoint_called = 0 @@ -526,22 +500,22 @@ def test_testpass_overrides(tmpdir): # Misconfig when neither test_step or test_end is implemented with pytest.raises(MisconfigurationException, match='.*not implement `test_dataloader`.*'): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_dataloader = LightningModule.test_dataloader Trainer().test(model) # Misconfig when neither test_step or test_end is implemented with pytest.raises(MisconfigurationException): - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_step = LightningModule.test_step Trainer().test(model) # No exceptions when one or both of test_step or test_end are implemented - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) model.test_step_end = LightningModule.test_step_end Trainer().test(model) - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) Trainer().test(model) @@ -562,7 +536,7 @@ def validation_epoch_end(self, *args, **kwargs): return super().validation_epoch_end(*args, **kwargs) hparams = EvalModelTemplate.get_default_hparams() - model = CurrentModel(hparams) + model = CurrentModel(**hparams) trainer_options = dict( progress_bar_refresh_rate=0, @@ -584,7 +558,7 @@ def validation_epoch_end(self, *args, **kwargs): '`validation_epoch_end` should not run when `val_percent_check=0`' # check that val_percent_check has no influence when fast_dev_run is turned on - model = CurrentModel(hparams) + model = CurrentModel(**hparams) trainer_options.update(fast_dev_run=True) trainer = Trainer(**trainer_options) result = trainer.fit(model) diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 0b9b548c69c60..a66e8bbde8b7f 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -72,9 +72,9 @@ def test_trainer_arg(tmpdir, scale_arg): tutils.reset_seed() hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) - before_batch_size = hparams.batch_size + before_batch_size = hparams.get('batch_size') # logger file to get meta trainer = Trainer( default_save_path=tmpdir, @@ -83,7 +83,7 @@ def test_trainer_arg(tmpdir, scale_arg): ) trainer.fit(model) - after_batch_size = model.hparams.batch_size + after_batch_size = model.batch_size assert before_batch_size != after_batch_size, \ 'Batch size was not altered after running auto scaling of batch size' @@ -94,9 +94,9 @@ def test_call_to_trainer_method(tmpdir, scale_method): tutils.reset_seed() hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(hparams) + model = EvalModelTemplate(**hparams) - before_batch_size = hparams.batch_size + before_batch_size = hparams.get('batch_size') # logger file to get meta trainer = Trainer( default_save_path=tmpdir, @@ -104,7 +104,7 @@ def test_call_to_trainer_method(tmpdir, scale_method): ) after_batch_size = trainer.scale_batch_size(model, mode=scale_method, max_trials=5) - model.hparams.batch_size = after_batch_size + 
model.batch_size = after_batch_size trainer.fit(model) assert before_batch_size != after_batch_size, \