File tree: 1 file changed, +9 −0 lines changed

Diff hunk: @@ -120,6 +120,11 @@ def parse_args():
120120 default = 1 ,
121121 help = "Number of updates steps to accumulate before performing a backward/update pass." ,
122122 )
123+ parser .add_argument (
124+ "--gradient_checkpointing" ,
125+ action = "store_true" ,
126+ help = "Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass." ,
127+ )
123128 parser .add_argument (
124129 "--learning_rate" ,
125130 type = float ,
@@ -388,10 +393,14 @@ def main():
388393 args .pretrained_model_name_or_path , subfolder = "unet" , use_auth_token = args .use_auth_token
389394 )
390395
396+ if args .gradient_checkpointing :
397+ unet .enable_gradient_checkpointing ()
398+
391399 if args .scale_lr :
392400 args .learning_rate = (
393401 args .learning_rate * args .gradient_accumulation_steps * args .train_batch_size * accelerator .num_processes
394402 )
403+
395404 optimizer = torch .optim .AdamW (
396405 unet .parameters (), # only optimize unet
397406 lr = args .learning_rate ,
0 commit comments