From 7a8ba35e5ee3a5a7e24477738d5ddb829a187a60 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 08:39:22 +0100 Subject: [PATCH 1/9] define tests --- pl_examples/basic_examples/autoencoder.py | 7 ++ .../backbone_image_classifier.py | 6 ++ .../basic_examples/conv_sequential_example.py | 3 + .../basic_examples/dali_image_classifier.py | 2 + .../basic_examples/mnist_datamodule.py | 2 + .../basic_examples/simple_image_classifier.py | 3 + pl_examples/bug_report_model.py | 6 ++ .../computer_vision_fine_tuning.py | 32 +++++---- .../generative_adversarial_net.py | 12 ++++ pl_examples/domain_templates/imagenet.py | 3 + .../domain_templates/reinforce_learn_Qnet.py | 70 ++++++++++++------- .../domain_templates/semantic_segmentation.py | 21 +++--- pl_examples/domain_templates/unet.py | 21 ++++-- 13 files changed, 135 insertions(+), 53 deletions(-) diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index 72bfcb17c0872..91f7ac0a1569d 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -31,6 +31,13 @@ class LitAutoEncoder(pl.LightningModule): + """ + >>> LitAutoEncoder() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitAutoEncoder( + (encoder): ... + (decoder): ... + ) + """ def __init__(self): super().__init__() diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index b0ca2efd5d76b..188c0efd37dc9 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -29,6 +29,9 @@ class Backbone(torch.nn.Module): + """ + >>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, hidden_dim=128): super().__init__() self.l1 = torch.nn.Linear(28 * 28, hidden_dim) @@ -42,6 +45,9 @@ def forward(self, x): class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, backbone, learning_rate=1e-3): super().__init__() self.save_hyperparameters() diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 06fddd689260f..83a062ef7d33f 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -55,6 +55,9 @@ def forward(self, x): class LitResnet(pl.LightningModule): + """ + >>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False): super().__init__() diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index 9f3ba5e08b37e..d0893b509eaa6 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -42,6 +42,8 @@ class ExternalMNISTInputIterator(object): """ This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches + + >>> ExternalMNISTInputIterator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, mnist_ds, batch_size): diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index eb1415cf8b981..a31b668239941 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -29,6 +29,8 @@ class MNISTDataModule(LightningDataModule): """ Standard MNIST, train, val, test splits and transforms + + >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ name = "mnist" diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 6b8457e0e4897..4d306f17cf15a 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -24,6 +24,9 @@ class LitClassifier(pl.LightningModule): + """ + >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, hidden_dim=128, learning_rate=1e-3): super().__init__() self.save_hyperparameters() diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index e2201db12f894..bfe2ed0ea290b 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -28,6 +28,9 @@ class RandomDataset(Dataset): + """ + >>> RandomDataset((10, 5), 20) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, size, length): self.len = length self.data = torch.randn(length, size) @@ -40,6 +43,9 @@ def __len__(self): class BoringModel(LightningModule): + """ + >>>BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self): """ diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 1c60e3aa6d23f..76864f5f6ddc7 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -159,20 +159,26 @@ def _unfreeze_and_add_param_group(module: Module, class TransferLearningModel(pl.LightningModule): """Transfer Learning with pre-trained ResNet50. - Args: - hparams: Model hyperparameters - dl_path: Path where the data will be downloaded + >>> TransferLearningModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ - def __init__(self, - dl_path: Union[str, Path], - backbone: str = 'resnet50', - train_bn: bool = True, - milestones: tuple = (5, 10), - batch_size: int = 8, - lr: float = 1e-2, - lr_scheduler_gamma: float = 1e-1, - num_workers: int = 6, **kwargs) -> None: - super().__init__() + def __init__( + self, + dl_path: Union[str, Path], + backbone: str = 'resnet50', + train_bn: bool = True, + milestones: tuple = (5, 10), + batch_size: int = 8, + lr: float = 1e-2, + lr_scheduler_gamma: float = 1e-1, + num_workers: int = 6, + **kwargs, + ) -> None: + """ + Args: + hparams: Model hyperparameters + dl_path: Path where the data will be downloaded + """ + super().__init__(**kwargs) self.dl_path = dl_path self.backbone = backbone self.train_bn = train_bn diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 210a80721d9a9..84d5537c48ae9 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -37,6 +37,9 @@ class Generator(nn.Module): + """ + >>> Generator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, latent_dim, img_shape): super().__init__() self.img_shape = img_shape @@ -64,6 +67,9 @@ def forward(self, z): class Discriminator(nn.Module): + """ + >>> Discriminator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, img_shape): super().__init__() @@ -83,6 +89,9 @@ def forward(self, img): class GAN(LightningModule): + """ + >>> GAN() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ @staticmethod def add_argparse_args(parent_parser: ArgumentParser): parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -180,6 +189,9 @@ def on_epoch_end(self): class MNISTDataModule(LightningDataModule): + """ + >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): super().__init__() self.batch_size = batch_size diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index b1eea307478f9..58e132f46dbba 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -50,6 +50,9 @@ class ImageNetLightningModel(LightningModule): + """ + >>> ImageNetLightningModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ # pull out resnet names from torchvision models MODEL_NAMES = sorted( name for name in models.__dict__ diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index a8b9db095f377..eba2b23fa02d7 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -53,13 +53,16 @@ class DQN(nn.Module): """ Simple MLP network - Args: - obs_size: observation/state size of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers + >>> DQN() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): + """ + Args: + obs_size: observation/state size of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super(DQN, self).__init__() self.net = nn.Sequential( nn.Linear(obs_size, hidden_size), @@ -81,11 +84,14 @@ class ReplayBuffer: """ Replay Buffer for storing past experiences allowing the agent to learn from them - Args: - capacity: size of the buffer + >>> ReplayBuffer() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, capacity: int) -> None: + """ + Args: + capacity: size of the buffer + """ self.buffer = deque(maxlen=capacity) def __len__(self) -> int: @@ -113,12 +119,15 @@ class RLDataset(IterableDataset): Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training - Args: - buffer: replay buffer - sample_size: number of experiences to sample at a time + >>> RLDataset() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: + """ + Args: + buffer: replay buffer + sample_size: number of experiences to sample at a time + """ self.buffer = buffer self.sample_size = sample_size @@ -132,12 +141,15 @@ class Agent: """ Base Agent class handling the interaction with the environment - Args: - env: training environment - replay_buffer: replay buffer storing experiences + >>> Agent() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: + """ + Args: + env: training environment + replay_buffer: replay buffer storing experiences + """ self.env = env self.replay_buffer = replay_buffer self.reset() @@ -204,20 +216,26 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): - """ Basic DQN Model """ - - def __init__(self, - replay_size, - warm_start_steps: int, - gamma: float, - eps_start: int, - eps_end: int, - eps_last_frame: int, - sync_rate, - lr: float, - episode_length, - batch_size, **kwargs) -> None: - super().__init__() + """ Basic DQN Model + + >>> DQNLightning() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ + + def __init__( + self, + replay_size, + warm_start_steps: int, + gamma: float, + eps_start: int, + eps_end: int, + eps_last_frame: int, + sync_rate, + lr: float, + episode_length, + batch_size, + **kwargs, + ) -> None: + super().__init__(**kwargs) self.replay_size = replay_size self.warm_start_steps = warm_start_steps self.gamma = gamma diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 08bdc1140916a..15d406d954aa8 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -141,16 +141,21 @@ class SegModel(pl.LightningModule): It uses the FCN ResNet50 model as an example. Adam optimizer is used along with Cosine Annealing learning rate scheduler. + + >>> SegModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ - def __init__(self, - data_path: str, - batch_size: int, - lr: float, - num_layers: int, - features_start: int, - bilinear: bool, **kwargs): - super().__init__() + def __init__( + self, + data_path: str, + batch_size: int, + lr: float, + num_layers: int, + features_start: int, + bilinear: bool, + **kwargs, + ): + super().__init__(**kwargs) self.data_path = data_path self.batch_size = batch_size self.lr = lr diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index 20b4bdb2a4bf9..49666793861e6 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -22,12 +22,7 @@ class UNet(nn.Module): Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation Link - https://arxiv.org/abs/1505.04597 - Parameters: - num_classes: Number of output classes required (default 19 for KITTI dataset) - num_layers: Number of layers in each side of U-net - features_start: Number of features in first layer - bilinear: Whether to use bilinear interpolation or transposed - convolutions for upsampling. + >>> UNet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__( @@ -36,6 +31,14 @@ def __init__( features_start: int = 64, bilinear: bool = False ): + """ + Args: + num_classes: Number of output classes required (default 19 for KITTI dataset) + num_layers: Number of layers in each side of U-net + features_start: Number of features in first layer + bilinear: Whether to use bilinear interpolation or transposed + convolutions for upsampling. + """ super().__init__() self.num_layers = num_layers @@ -69,6 +72,8 @@ class DoubleConv(nn.Module): """ Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2 + + >>> DoubleConv() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, in_ch: int, out_ch: int): @@ -89,6 +94,8 @@ def forward(self, x): class Down(nn.Module): """ Combination of MaxPool2d and DoubleConv in series + + >>> Down() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, in_ch: int, out_ch: int): @@ -107,6 +114,8 @@ class Up(nn.Module): Upsampling (by either bilinear interpolation or transpose convolutions) followed by concatenation of feature map from contracting path, followed by double 3x3 convolution. + + >>> Up() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): From b7d524f661ca9352036e24acd260d2eb60705b52 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 09:16:05 +0100 Subject: [PATCH 2/9] fix basic --- pl_examples/basic_examples/backbone_image_classifier.py | 9 ++++++++- pl_examples/basic_examples/conv_sequential_example.py | 3 +++ pl_examples/basic_examples/dali_image_classifier.py | 2 -- pl_examples/basic_examples/mnist_datamodule.py | 1 + pl_examples/basic_examples/simple_image_classifier.py | 4 ++++ 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py index 188c0efd37dc9..bb1daad301d08 100644 --- a/pl_examples/basic_examples/backbone_image_classifier.py +++ b/pl_examples/basic_examples/backbone_image_classifier.py @@ -31,6 +31,10 @@ class Backbone(torch.nn.Module): """ >>> Backbone() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Backbone( + (l1): Linear(...) + (l2): Linear(...) + ) """ def __init__(self, hidden_dim=128): super().__init__() @@ -46,7 +50,10 @@ def forward(self, x): class LitClassifier(pl.LightningModule): """ - >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> LitClassifier(Backbone()) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (backbone): ... + ) """ def __init__(self, backbone, learning_rate=1e-3): super().__init__() diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 83a062ef7d33f..39634084860c2 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -57,6 +57,9 @@ def forward(self, x): class LitResnet(pl.LightningModule): """ >>> LitResnet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitResnet( + (sequential_module): Sequential(...) + ) """ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False): super().__init__() diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py index d0893b509eaa6..9f3ba5e08b37e 100644 --- a/pl_examples/basic_examples/dali_image_classifier.py +++ b/pl_examples/basic_examples/dali_image_classifier.py @@ -42,8 +42,6 @@ class ExternalMNISTInputIterator(object): """ This iterator class wraps torchvision's MNIST dataset and returns the images and labels in batches - - >>> ExternalMNISTInputIterator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__(self, mnist_ds, batch_size): diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index a31b668239941..df1b571aac93d 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -31,6 +31,7 @@ class MNISTDataModule(LightningDataModule): Standard MNIST, train, val, test splits and transforms >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ name = "mnist" diff --git a/pl_examples/basic_examples/simple_image_classifier.py b/pl_examples/basic_examples/simple_image_classifier.py index 4d306f17cf15a..894eeea619ba9 100644 --- a/pl_examples/basic_examples/simple_image_classifier.py +++ b/pl_examples/basic_examples/simple_image_classifier.py @@ -26,6 +26,10 @@ class LitClassifier(pl.LightningModule): """ >>> LitClassifier() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + LitClassifier( + (l1): Linear(...) + (l2): Linear(...) + ) """ def __init__(self, hidden_dim=128, learning_rate=1e-3): super().__init__() From d668a585c46ba2a75c32846e80903342bf629690 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 09:33:11 +0100 Subject: [PATCH 3/9] fix gans --- .../computer_vision_fine_tuning.py | 7 ++- .../generative_adversarial_net.py | 57 +++++++++++++------ pl_examples/domain_templates/imagenet.py | 19 ++++--- 3 files changed, 56 insertions(+), 27 deletions(-) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index 76864f5f6ddc7..ca5925f0f56de 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -159,7 +159,12 @@ def _unfreeze_and_add_param_group(module: Module, class TransferLearningModel(pl.LightningModule): """Transfer Learning with pre-trained ResNet50. - >>> TransferLearningModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> with TemporaryDirectory(dir='.') as tmp_dir: + ... TransferLearningModel(tmp_dir) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + TransferLearningModel( + (feature_extractor): Sequential(...) + (fc): Sequential(...) + ) """ def __init__( self, diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 84d5537c48ae9..e130d18294235 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -38,9 +38,12 @@ class Generator(nn.Module): """ - >>> Generator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Generator( + (model): Sequential(...) + ) """ - def __init__(self, latent_dim, img_shape): + def __init__(self, latent_dim: int = 100, img_shape: tuple = (1, 28, 28)): super().__init__() self.img_shape = img_shape @@ -68,7 +71,10 @@ def forward(self, z): class Discriminator(nn.Module): """ - >>> Discriminator() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Discriminator( + (model): Sequential(...) + ) """ def __init__(self, img_shape): super().__init__() @@ -90,8 +96,36 @@ def forward(self, img): class GAN(LightningModule): """ - >>> GAN() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + GAN( + (generator): Generator( + (model): Sequential(...) + ) + (discriminator): Discriminator( + (model): Sequential(...) + ) + ) """ + def __init__( + self, + img_shape: tuple = (1, 28, 28), + lr: float = 0.0002, + b1: float = 0.5, + b2: float = 0.999, + latent_dim: int = 100, + ): + super().__init__() + + self.save_hyperparameters() + + # networks + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=img_shape) + self.discriminator = Discriminator(img_shape=img_shape) + + self.validation_z = torch.randn(8, self.hparams.latent_dim) + + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) + @staticmethod def add_argparse_args(parent_parser: ArgumentParser): parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -105,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser): return parser - def __init__(self, hparams: Namespace): - super().__init__() - - self.hparams = hparams - - # networks - mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape) - self.discriminator = Discriminator(img_shape=mnist_shape) - - self.validation_z = torch.randn(8, self.hparams.latent_dim) - - self.example_input_array = torch.zeros(2, self.hparams.latent_dim) - def forward(self, z): return self.generator(z) @@ -191,6 +211,7 @@ def on_epoch_end(self): class MNISTDataModule(LightningDataModule): """ >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + """ def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): super().__init__() diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index 58e132f46dbba..cc36f3542a1c8 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -51,7 +51,10 @@ class ImageNetLightningModel(LightningModule): """ - >>> ImageNetLightningModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ImageNetLightningModel( + (model): ResNet(...) + ) """ # pull out resnet names from torchvision models MODEL_NAMES = sorted( @@ -61,14 +64,14 @@ class ImageNetLightningModel(LightningModule): def __init__( self, - arch: str, - pretrained: bool, - lr: float, - momentum: float, - weight_decay: int, data_path: str, - batch_size: int, - workers: int, + arch: str = 'resnet18', + pretrained: bool = False, + lr: float = 0.1, + momentum: float = 0.9, + weight_decay: float = 1e-4, + batch_size: int = 4, + workers: int = 2, **kwargs, ): super().__init__() From 910d2ac85175ead33065aa0505fd5a3942c5f2e0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 10:44:51 +0100 Subject: [PATCH 4/9] unet --- .../basic_examples/mnist_datamodule.py | 2 +- .../generative_adversarial_net.py | 2 +- .../domain_templates/reinforce_learn_Qnet.py | 98 +++++++++++-------- .../domain_templates/semantic_segmentation.py | 50 ++++------ pl_examples/domain_templates/unet.py | 40 ++++++-- 5 files changed, 115 insertions(+), 77 deletions(-) diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index df1b571aac93d..6599b7d594aa8 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -30,7 +30,7 @@ class MNISTDataModule(LightningDataModule): """ Standard MNIST, train, val, test splits and transforms - >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> MNISTDataModule() # doctest: +ELLIPSIS """ diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index e130d18294235..0abeb71516480 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -210,7 +210,7 @@ def on_epoch_end(self): class MNISTDataModule(LightningDataModule): """ - >>> MNISTDataModule() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> MNISTDataModule() # doctest: +ELLIPSIS """ def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index eba2b23fa02d7..2166286a044c2 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -53,7 +53,10 @@ class DQN(nn.Module): """ Simple MLP network - >>> DQN() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> DQN(10, 5) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQN( + (net): Sequential(...) + ) """ def __init__(self, obs_size: int, n_actions: int, hidden_size: int = 128): @@ -84,7 +87,8 @@ class ReplayBuffer: """ Replay Buffer for storing past experiences allowing the agent to learn from them - >>> ReplayBuffer() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> ReplayBuffer(5) # doctest: +ELLIPSIS + """ def __init__(self, capacity: int) -> None: @@ -119,7 +123,8 @@ class RLDataset(IterableDataset): Iterable Dataset containing the ExperienceBuffer which will be updated with new experiences during training - >>> RLDataset() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS + """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: @@ -141,7 +146,10 @@ class Agent: """ Base Agent class handling the interaction with the environment - >>> Agent() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> env = gym.make("CartPole-v0") + >>> buffer = ReplayBuffer(10) + >>> Agent(env, buffer) # doctest: +ELLIPSIS + """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: @@ -218,21 +226,29 @@ def play_step(self, net: nn.Module, epsilon: float = 0.0, device: str = 'cpu') - class DQNLightning(pl.LightningModule): """ Basic DQN Model - >>> DQNLightning() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> DQNLightning(env="CartPole-v0") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DQNLightning( + (net): DQN( + (net): Sequential(...) + ) + (target_net): DQN( + (net): Sequential(...) + ) + ) """ - def __init__( self, - replay_size, - warm_start_steps: int, - gamma: float, - eps_start: int, - eps_end: int, - eps_last_frame: int, - sync_rate, - lr: float, - episode_length, - batch_size, + env: str, + replay_size: int = 200, + warm_start_steps: int = 200, + gamma: float = 0.99, + eps_start: float = 1.0, + eps_end: float = 0.01, + eps_last_frame: int = 200, + sync_rate: int = 10, + lr: float = 1e-2, + episode_length: int = 50, + batch_size: int = 4, **kwargs, ) -> None: super().__init__(**kwargs) @@ -247,7 +263,7 @@ def __init__( self.episode_length = episode_length self.batch_size = batch_size - self.env = gym.make(self.env) + self.env = gym.make(env) obs_size = self.env.observation_space.shape[0] n_actions = self.env.action_space.n @@ -320,8 +336,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], nb_batch) -> O Training loss and log metrics """ device = self.get_device(batch) - epsilon = max(self.eps_end, self.eps_start - - self.global_step + 1 / self.eps_last_frame) + epsilon = max(self.eps_end, self.eps_start - self.global_step + 1 / self.eps_last_frame) # step through environment with agent reward, done = self.agent.play_step(self.net, epsilon, device) @@ -367,6 +382,30 @@ def get_device(self, batch) -> str: """Retrieve device currently being used by minibatch""" return batch[0].device.index if self.on_gpu else 'cpu' + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = argparse.ArgumentParser(parents=[parent_parser]) + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") + parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") + parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + parser.add_argument("--sync_rate", type=int, default=10, + help="how many frames do we update the target network") + parser.add_argument("--replay_size", type=int, default=1000, + help="capacity of the replay buffer") + parser.add_argument("--warm_start_size", type=int, default=1000, + help="how many samples do we use to fill our buffer at the start of training") + parser.add_argument("--eps_last_frame", type=int, default=1000, + help="what frame should epsilon stop decaying") + parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") + parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") + parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") + parser.add_argument("--max_episode_reward", type=int, default=200, + help="max episode reward in the environment") + parser.add_argument("--warm_start_steps", type=int, default=1000, + help="max episode reward in the environment") + return parser + def main(args) -> None: model = DQNLightning(**vars(args)) @@ -386,26 +425,7 @@ def main(args) -> None: np.random.seed(0) parser = argparse.ArgumentParser() - parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") - parser.add_argument("--lr", type=float, default=1e-2, help="learning rate") - parser.add_argument("--env", type=str, default="CartPole-v0", help="gym environment tag") - parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") - parser.add_argument("--sync_rate", type=int, default=10, - help="how many frames do we update the target network") - parser.add_argument("--replay_size", type=int, default=1000, - help="capacity of the replay buffer") - parser.add_argument("--warm_start_size", type=int, default=1000, - help="how many samples do we use to fill our buffer at the start of training") - parser.add_argument("--eps_last_frame", type=int, default=1000, - help="what frame should epsilon stop decaying") - parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") - parser.add_argument("--eps_end", type=float, default=0.01, help="final value of epsilon") - parser.add_argument("--episode_length", type=int, default=200, help="max length of an episode") - parser.add_argument("--max_episode_reward", type=int, default=200, - help="max episode reward in the environment") - parser.add_argument("--warm_start_steps", type=int, default=1000, - help="max episode reward in the environment") - + parser = DQNLightning.add_model_specific_args(parser) args = parser.parse_args() main(args) diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 15d406d954aa8..6ca93a9c356f3 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -142,17 +142,16 @@ class SegModel(pl.LightningModule): Adam optimizer is used along with Cosine Annealing learning rate scheduler. - >>> SegModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> SegModel(data_path='.') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ - def __init__( self, data_path: str, - batch_size: int, - lr: float, - num_layers: int, - features_start: int, - bilinear: bool, + batch_size: int = 4, + lr: float = 1e-3, + num_layers: int = 3, + features_start: int = 64, + bilinear: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -209,6 +208,18 @@ def train_dataloader(self): def val_dataloader(self): return DataLoader(self.validset, batch_size=self.batch_size, shuffle=False) + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = ArgumentParser(parents=[parent_parser]) + parser.add_argument("--data_path", type=str, help="path where dataset is stored") + parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") + parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") + parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") + parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") + parser.add_argument("--bilinear", action='store_true', default=False, + help="whether to use bilinear interpolation or transposed") + return parser + def main(hparams: Namespace): # ------------------------ @@ -229,14 +240,7 @@ def main(hparams: Namespace): # ------------------------ # 3 INIT TRAINER # ------------------------ - trainer = pl.Trainer( - gpus=hparams.gpus, - logger=logger, - max_epochs=hparams.epochs, - accumulate_grad_batches=hparams.grad_batches, - accelerator=hparams.accelerator, - precision=16 if hparams.use_amp else 32, - ) + trainer = pl.Trainer.from_argparse_args(hparams) # ------------------------ # 5 START TRAINING @@ -247,21 +251,7 @@ def main(hparams: Namespace): if __name__ == '__main__': cli_lightning_logo() parser = ArgumentParser() - parser.add_argument("--data_path", type=str, help="path where dataset is stored") - parser.add_argument("--gpus", type=int, default=-1, help="number of available GPUs") - parser.add_argument('--distributed-backend', type=str, default='dp', choices=('dp', 'ddp', 'ddp2'), - help='supports three options dp, ddp, ddp2') - parser.add_argument('--use_amp', action='store_true', help='if true uses 16 bit precision') - parser.add_argument("--batch_size", type=int, default=4, help="size of the batches") - parser.add_argument("--lr", type=float, default=0.001, help="adam: learning rate") - parser.add_argument("--num_layers", type=int, default=5, help="number of layers on u-net") - parser.add_argument("--features_start", type=float, default=64, help="number of features in first layer") - parser.add_argument("--bilinear", action='store_true', default=False, - help="whether to use bilinear interpolation or transposed") - parser.add_argument("--grad_batches", type=int, default=1, help="number of batches to accumulate") - parser.add_argument("--epochs", type=int, default=20, help="number of epochs to train") - parser.add_argument("--log_wandb", action='store_true', help="log training on Weights & Biases") - + parser = SegModel.add_model_specific_args(parser) hparams = parser.parse_args() main(hparams) diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index 49666793861e6..bb4ebba3cbba5 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -22,14 +22,25 @@ class UNet(nn.Module): Architecture based on U-Net: Convolutional Networks for Biomedical Image Segmentation Link - https://arxiv.org/abs/1505.04597 - >>> UNet() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> UNet(num_classes=2, num_layers=3) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + UNet( + (layers): ModuleList( + (0): DoubleConv(...) + (1): Down(...) + (2): Down(...) + (3): Up(...) + (4): Up(...) + (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) + ) + ) """ def __init__( - self, num_classes: int = 19, + self, + num_classes: int = 19, num_layers: int = 5, features_start: int = 64, - bilinear: bool = False + bilinear: bool = False, ): """ Args: @@ -73,7 +84,10 @@ class DoubleConv(nn.Module): Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2 - >>> DoubleConv() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> DoubleConv(4, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + DoubleConv( + (net): Sequential(...) + ) """ def __init__(self, in_ch: int, out_ch: int): @@ -95,7 +109,15 @@ class Down(nn.Module): """ Combination of MaxPool2d and DoubleConv in series - >>> Down() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Down( + (net): Sequential( + (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) + (1): DoubleConv( + (net): Sequential(...) + ) + ) + ) """ def __init__(self, in_ch: int, out_ch: int): @@ -115,7 +137,13 @@ class Up(nn.Module): followed by concatenation of feature map from contracting path, followed by double 3x3 convolution. - >>> Up() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + Up( + (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2)) + (conv): DoubleConv( + (net): Sequential(...) + ) + ) """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): From d9fe9cd05629bf480bdccb827eb528ced1a60b82 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 11:09:52 +0100 Subject: [PATCH 5/9] test --- pl_examples/bug_report_model.py | 15 +++++++++------ .../domain_templates/semantic_segmentation.py | 2 ++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index bfe2ed0ea290b..89969fad730f5 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -29,7 +29,8 @@ class RandomDataset(Dataset): """ - >>> RandomDataset((10, 5), 20) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS + """ def __init__(self, size, length): self.len = length @@ -44,7 +45,10 @@ def __len__(self): class BoringModel(LightningModule): """ - >>>BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + >>> BoringModel() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + BoringModel( + (layer): Linear(...) + ) """ def __init__(self): @@ -75,7 +79,7 @@ def loss(self, batch, prediction): # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) - def step(self, x): + def _step(self, x): x = self.layer(x) out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) return out @@ -119,10 +123,9 @@ def configure_optimizers(self): # parser = ArgumentParser() # args = parser.parse_args(opt) -def run_test(): +def test_run(): class TestModel(BoringModel): - def on_train_epoch_start(self) -> None: print('override any method to prove your bug') @@ -146,4 +149,4 @@ def on_train_epoch_start(self) -> None: if __name__ == '__main__': cli_lightning_logo() - run_test() + test_run() diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 6ca93a9c356f3..1f05abeee2725 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -53,6 +53,8 @@ class KITTI(Dataset): In the `get_item` function, images and masks are resized to the given `img_size`, masks are encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). + + >>> KITTI() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ IMAGE_PATH = os.path.join('training', 'image_2') MASK_PATH = os.path.join('training', 'semantic') From e82148c9b2a92e5cd9c787fc1354b3f93017c26e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 11:36:41 +0100 Subject: [PATCH 6/9] drop --- pl_examples/domain_templates/semantic_segmentation.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pl_examples/domain_templates/semantic_segmentation.py b/pl_examples/domain_templates/semantic_segmentation.py index 1f05abeee2725..7bcad597a9a68 100644 --- a/pl_examples/domain_templates/semantic_segmentation.py +++ b/pl_examples/domain_templates/semantic_segmentation.py @@ -53,8 +53,6 @@ class KITTI(Dataset): In the `get_item` function, images and masks are resized to the given `img_size`, masks are encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only (mask does not usually require transforms, but they can be implemented in a similar way). - - >>> KITTI() # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ IMAGE_PATH = os.path.join('training', 'image_2') MASK_PATH = os.path.join('training', 'semantic') @@ -143,8 +141,6 @@ class SegModel(pl.LightningModule): It uses the FCN ResNet50 model as an example. Adam optimizer is used along with Cosine Annealing learning rate scheduler. - - >>> SegModel(data_path='.') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE """ def __init__( self, From 39e4a703b01a1199aa6f73b019107e4913ea3e81 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 11:40:47 +0100 Subject: [PATCH 7/9] format --- pl_examples/domain_templates/computer_vision_fine_tuning.py | 1 - pl_examples/domain_templates/unet.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index ca5925f0f56de..4392ac47e837f 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -180,7 +180,6 @@ def __init__( ) -> None: """ Args: - hparams: Model hyperparameters dl_path: Path where the data will be downloaded """ super().__init__(**kwargs) diff --git a/pl_examples/domain_templates/unet.py b/pl_examples/domain_templates/unet.py index bb4ebba3cbba5..2314e19ddbfc9 100644 --- a/pl_examples/domain_templates/unet.py +++ b/pl_examples/domain_templates/unet.py @@ -47,8 +47,7 @@ def __init__( num_classes: Number of output classes required (default 19 for KITTI dataset) num_layers: Number of layers in each side of U-net features_start: Number of features in first layer - bilinear: Whether to use bilinear interpolation or transposed - convolutions for upsampling. + bilinear: Whether to use bilinear interpolation or transposed convolutions for upsampling. """ super().__init__() self.num_layers = num_layers From 54f92005e7df29a3d00c45645a15aedb806f4d80 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Fri, 11 Dec 2020 11:53:36 +0100 Subject: [PATCH 8/9] fix --- pl_examples/basic_examples/mnist_datamodule.py | 2 +- pl_examples/bug_report_model.py | 2 +- pl_examples/domain_templates/generative_adversarial_net.py | 2 +- pl_examples/domain_templates/reinforce_learn_Qnet.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 6599b7d594aa8..95e20d22e1fdd 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -31,7 +31,7 @@ class MNISTDataModule(LightningDataModule): Standard MNIST, train, val, test splits and transforms >>> MNISTDataModule() # doctest: +ELLIPSIS - + <...mnist_datamodule.MNISTDataModule object at ...> """ name = "mnist" diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index 89969fad730f5..f56b883b4ce02 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -30,7 +30,7 @@ class RandomDataset(Dataset): """ >>> RandomDataset(size=10, length=20) # doctest: +ELLIPSIS - + <...bug_report_model.RandomDataset object at ...> """ def __init__(self, size, length): self.len = length diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 0abeb71516480..b0c324c193574 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -211,7 +211,7 @@ def on_epoch_end(self): class MNISTDataModule(LightningDataModule): """ >>> MNISTDataModule() # doctest: +ELLIPSIS - + <...generative_adversarial_net.MNISTDataModule object at ...> """ def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): super().__init__() diff --git a/pl_examples/domain_templates/reinforce_learn_Qnet.py b/pl_examples/domain_templates/reinforce_learn_Qnet.py index 2166286a044c2..6aee8bb6038c1 100644 --- a/pl_examples/domain_templates/reinforce_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforce_learn_Qnet.py @@ -88,7 +88,7 @@ class ReplayBuffer: Replay Buffer for storing past experiences allowing the agent to learn from them >>> ReplayBuffer(5) # doctest: +ELLIPSIS - + <...reinforce_learn_Qnet.ReplayBuffer object at ...> """ def __init__(self, capacity: int) -> None: @@ -124,7 +124,7 @@ class RLDataset(IterableDataset): which will be updated with new experiences during training >>> RLDataset(ReplayBuffer(5)) # doctest: +ELLIPSIS - + <...reinforce_learn_Qnet.RLDataset object at ...> """ def __init__(self, buffer: ReplayBuffer, sample_size: int = 200) -> None: @@ -149,7 +149,7 @@ class Agent: >>> env = gym.make("CartPole-v0") >>> buffer = ReplayBuffer(10) >>> Agent(env, buffer) # doctest: +ELLIPSIS - + <...reinforce_learn_Qnet.Agent object at ...> """ def __init__(self, env: gym.Env, replay_buffer: ReplayBuffer) -> None: From 6c00b9ec56f07aaef8235557ed745ffaa1d11283 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 12 Dec 2020 00:19:16 +0100 Subject: [PATCH 9/9] revert Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pl_examples/bug_report_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_examples/bug_report_model.py b/pl_examples/bug_report_model.py index f56b883b4ce02..30345122e251f 100644 --- a/pl_examples/bug_report_model.py +++ b/pl_examples/bug_report_model.py @@ -79,7 +79,7 @@ def loss(self, batch, prediction): # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction)) - def _step(self, x): + def step(self, x): x = self.layer(x) out = torch.nn.functional.mse_loss(x, torch.ones_like(x)) return out