@@ -39,17 +39,17 @@ class LitAutoEncoder(pl.LightningModule):
3939 )
4040 """
4141
42- def __init__ (self ):
42+ def __init__ (self , hidden_dim : int = 64 ):
4343 super ().__init__ ()
4444 self .encoder = nn .Sequential (
45- nn .Linear (28 * 28 , 64 ),
45+ nn .Linear (28 * 28 , hidden_dim ),
4646 nn .ReLU (),
47- nn .Linear (64 , 3 ),
47+ nn .Linear (hidden_dim , 3 ),
4848 )
4949 self .decoder = nn .Sequential (
50- nn .Linear (3 , 64 ),
50+ nn .Linear (3 , hidden_dim ),
5151 nn .ReLU (),
52- nn .Linear (64 , 28 * 28 ),
52+ nn .Linear (hidden_dim , 28 * 28 ),
5353 )
5454
5555 def forward (self , x ):
@@ -94,7 +94,7 @@ def cli_main():
9494 # ------------
9595 parser = ArgumentParser ()
9696 parser .add_argument ('--batch_size' , default = 32 , type = int )
97- parser .add_argument ('--hidden_dim' , type = int , default = 128 )
97+ parser .add_argument ('--hidden_dim' , type = int , default = 64 )
9898 parser = pl .Trainer .add_argparse_args (parser )
9999 args = parser .parse_args ()
100100
@@ -112,7 +112,7 @@ def cli_main():
112112 # ------------
113113 # model
114114 # ------------
115- model = LitAutoEncoder ()
115+ model = LitAutoEncoder (args . hidden_dim )
116116
117117 # ------------
118118 # training
0 commit comments