3737
3838
3939class Generator (nn .Module ):
40- def __init__ (self , latent_dim , img_shape ):
40+ """
41+ >>> Generator(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
42+ Generator(
43+ (model): Sequential(...)
44+ )
45+ """
46+ def __init__ (self , latent_dim : int = 100 , img_shape : tuple = (1 , 28 , 28 )):
4147 super ().__init__ ()
4248 self .img_shape = img_shape
4349
@@ -64,6 +70,12 @@ def forward(self, z):
6470
6571
6672class Discriminator (nn .Module ):
73+ """
74+ >>> Discriminator(img_shape=(1, 28, 28)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
75+ Discriminator(
76+ (model): Sequential(...)
77+ )
78+ """
6779 def __init__ (self , img_shape ):
6880 super ().__init__ ()
6981
@@ -83,6 +95,37 @@ def forward(self, img):
8395
8496
8597class GAN (LightningModule ):
98+ """
99+ >>> GAN(img_shape=(1, 8, 8)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
100+ GAN(
101+ (generator): Generator(
102+ (model): Sequential(...)
103+ )
104+ (discriminator): Discriminator(
105+ (model): Sequential(...)
106+ )
107+ )
108+ """
109+ def __init__ (
110+ self ,
111+ img_shape : tuple = (1 , 28 , 28 ),
112+ lr : float = 0.0002 ,
113+ b1 : float = 0.5 ,
114+ b2 : float = 0.999 ,
115+ latent_dim : int = 100 ,
116+ ):
117+ super ().__init__ ()
118+
119+ self .save_hyperparameters ()
120+
121+ # networks
122+ self .generator = Generator (latent_dim = self .hparams .latent_dim , img_shape = img_shape )
123+ self .discriminator = Discriminator (img_shape = img_shape )
124+
125+ self .validation_z = torch .randn (8 , self .hparams .latent_dim )
126+
127+ self .example_input_array = torch .zeros (2 , self .hparams .latent_dim )
128+
86129 @staticmethod
87130 def add_argparse_args (parent_parser : ArgumentParser ):
88131 parser = ArgumentParser (parents = [parent_parser ], add_help = False )
@@ -96,20 +139,6 @@ def add_argparse_args(parent_parser: ArgumentParser):
96139
97140 return parser
98141
99- def __init__ (self , hparams : Namespace ):
100- super ().__init__ ()
101-
102- self .hparams = hparams
103-
104- # networks
105- mnist_shape = (1 , 28 , 28 )
106- self .generator = Generator (latent_dim = self .hparams .latent_dim , img_shape = mnist_shape )
107- self .discriminator = Discriminator (img_shape = mnist_shape )
108-
109- self .validation_z = torch .randn (8 , self .hparams .latent_dim )
110-
111- self .example_input_array = torch .zeros (2 , self .hparams .latent_dim )
112-
113142 def forward (self , z ):
114143 return self .generator (z )
115144
@@ -180,6 +209,10 @@ def on_epoch_end(self):
180209
181210
182211class MNISTDataModule (LightningDataModule ):
212+ """
213+ >>> MNISTDataModule() # doctest: +ELLIPSIS
214+ <...generative_adversarial_net.MNISTDataModule object at ...>
215+ """
183216 def __init__ (self , batch_size : int = 64 , data_path : str = os .getcwd (), num_workers : int = 4 ):
184217 super ().__init__ ()
185218 self .batch_size = batch_size
0 commit comments