Skip to content

Commit cbca6cd

Browse files
bmahlbrandcarmocca
andauthored
fix: update example autoencoder.py to reflect args (#6638)
* fix: update example autoencoder.py to reflect args * Update pl_examples/basic_examples/autoencoder.py Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 70beddf commit cbca6cd

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pl_examples/basic_examples/autoencoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ class LitAutoEncoder(pl.LightningModule):
3939
)
4040
"""
4141

42-
def __init__(self):
42+
def __init__(self, hidden_dim: int = 64):
4343
super().__init__()
4444
self.encoder = nn.Sequential(
45-
nn.Linear(28 * 28, 64),
45+
nn.Linear(28 * 28, hidden_dim),
4646
nn.ReLU(),
47-
nn.Linear(64, 3),
47+
nn.Linear(hidden_dim, 3),
4848
)
4949
self.decoder = nn.Sequential(
50-
nn.Linear(3, 64),
50+
nn.Linear(3, hidden_dim),
5151
nn.ReLU(),
52-
nn.Linear(64, 28 * 28),
52+
nn.Linear(hidden_dim, 28 * 28),
5353
)
5454

5555
def forward(self, x):
@@ -94,7 +94,7 @@ def cli_main():
9494
# ------------
9595
parser = ArgumentParser()
9696
parser.add_argument('--batch_size', default=32, type=int)
97-
parser.add_argument('--hidden_dim', type=int, default=128)
97+
parser.add_argument('--hidden_dim', type=int, default=64)
9898
parser = pl.Trainer.add_argparse_args(parser)
9999
args = parser.parse_args()
100100

@@ -112,7 +112,7 @@ def cli_main():
112112
# ------------
113113
# model
114114
# ------------
115-
model = LitAutoEncoder()
115+
model = LitAutoEncoder(args.hidden_dim)
116116

117117
# ------------
118118
# training

0 commit comments

Comments
 (0)