@@ -102,8 +102,9 @@ def __init__(self, lr=0.05, batch_size=32, manual_optimization=False):
102102 nn .Linear (512 , 10 )
103103 )
104104 self ._example_input_array = torch .randn ((1 , 3 , 32 , 32 ))
105- self ._manual_optimization = manual_optimization
106- if self ._manual_optimization :
105+
106+ if manual_optimization :
107+ self .automatic_optimization = False
107108 self .training_step = self .training_step_manual
108109
109110 def forward (self , x ):
@@ -165,10 +166,6 @@ def configure_optimizers(self):
165166 }
166167 }
167168
168- @property
169- def automatic_optimization (self ) -> bool :
170- return not self ._manual_optimization
171-
172169
173170#################################
174171# Instantiate Data Module #
@@ -189,6 +186,7 @@ def instantiate_datamodule(args):
189186 ])
190187
191188 cifar10_dm = pl_bolts .datamodules .CIFAR10DataModule (
189+ data_dir = args .data_dir ,
192190 batch_size = args .batch_size ,
193191 train_transforms = train_transforms ,
194192 test_transforms = test_transforms ,
@@ -206,6 +204,7 @@ def instantiate_datamodule(args):
206204
207205 parser = ArgumentParser (description = "Pipe Example" )
208206 parser .add_argument ("--use_rpc_sequential" , action = "store_true" )
207+ parser .add_argument ("--manual_optimization" , action = "store_true" )
209208 parser = Trainer .add_argparse_args (parser )
210209 parser = pl_bolts .datamodules .CIFAR10DataModule .add_argparse_args (parser )
211210 args = parser .parse_args ()
@@ -216,7 +215,7 @@ def instantiate_datamodule(args):
216215 if args .use_rpc_sequential :
217216 plugins = RPCSequentialPlugin ()
218217
219- model = LitResnet (batch_size = args .batch_size , manual_optimization = not args .automatic_optimization )
218+ model = LitResnet (batch_size = args .batch_size , manual_optimization = args .manual_optimization )
220219
221220 trainer = pl .Trainer .from_argparse_args (args , plugins = [plugins ] if plugins else None )
222221 trainer .fit (model , cifar10_dm )
0 commit comments