Skip to content

Commit 34b733b

Browse files
authored
Fix manual optimization in pl_example (#6373)
* Fix automatic_optimization * Fix automatic_optimization * Uncomment fairscale
1 parent facfda8 commit 34b733b

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

pl_examples/basic_examples/conv_sequential_example.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
102102
nn.Linear(512, 10)
103103
)
104104
self._example_input_array = torch.randn((1, 3, 32, 32))
105-
self._manual_optimization = manual_optimization
106-
if self._manual_optimization:
105+
106+
if manual_optimization:
107+
self.automatic_optimization = False
107108
self.training_step = self.training_step_manual
108109

109110
def forward(self, x):
@@ -165,10 +166,6 @@ def configure_optimizers(self):
165166
}
166167
}
167168

168-
@property
169-
def automatic_optimization(self) -> bool:
170-
return not self._manual_optimization
171-
172169

173170
#################################
174171
# Instantiate Data Module #
@@ -189,6 +186,7 @@ def instantiate_datamodule(args):
189186
])
190187

191188
cifar10_dm = pl_bolts.datamodules.CIFAR10DataModule(
189+
data_dir=args.data_dir,
192190
batch_size=args.batch_size,
193191
train_transforms=train_transforms,
194192
test_transforms=test_transforms,
@@ -206,6 +204,7 @@ def instantiate_datamodule(args):
206204

207205
parser = ArgumentParser(description="Pipe Example")
208206
parser.add_argument("--use_rpc_sequential", action="store_true")
207+
parser.add_argument("--manual_optimization", action="store_true")
209208
parser = Trainer.add_argparse_args(parser)
210209
parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
211210
args = parser.parse_args()
@@ -216,7 +215,7 @@ def instantiate_datamodule(args):
216215
if args.use_rpc_sequential:
217216
plugins = RPCSequentialPlugin()
218217

219-
model = LitResnet(batch_size=args.batch_size, manual_optimization=not args.automatic_optimization)
218+
model = LitResnet(batch_size=args.batch_size, manual_optimization=args.manual_optimization)
220219

221220
trainer = pl.Trainer.from_argparse_args(args, plugins=[plugins] if plugins else None)
222221
trainer.fit(model, cifar10_dm)

0 commit comments

Comments
 (0)