
Commit 2868d99

Authored by timh, patrickvonplaten, and pcuenca
dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16 (#1618)
* dreambooth: fix #1566: maintain fp32 wrapper when saving a checkpoint to avoid crash when running fp16
* dreambooth: guard against passing keep_fp32_wrapper arg to older versions of accelerate. part of fix for #1566
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
* Update examples/dreambooth/train_dreambooth.py
  Co-authored-by: Patrick von Platen <[email protected]>

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 0c18d02 commit 2868d99

File tree

1 file changed: +12 −2 lines


examples/dreambooth/train_dreambooth.py

Lines changed: 12 additions & 2 deletions
@@ -1,5 +1,6 @@
 import argparse
 import hashlib
+import inspect
 import itertools
 import math
 import os
@@ -690,10 +691,19 @@ def main(args):

            if global_step % args.save_steps == 0:
                if accelerator.is_main_process:
+                    # When `keep_fp32_wrapper` is `False` (the default), then the models are
+                    # unwrapped and the mixed precision hooks are removed, so training crashes
+                    # when the unwrapped models are used for further training.
+                    # This is only supported in newer versions of `accelerate`.
+                    # TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions
+                    accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
+                        inspect.signature(accelerator.unwrap_model).parameters.keys()
+                    )
+                    extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
                    pipeline = DiffusionPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
-                        unet=accelerator.unwrap_model(unet),
-                        text_encoder=accelerator.unwrap_model(text_encoder),
+                        unet=accelerator.unwrap_model(unet, **extra_args),
+                        text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
                        revision=args.revision,
                    )
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
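The guard in this diff is a general feature-detection pattern: probe a callable's signature before passing a keyword argument that only newer library versions accept, instead of comparing version strings. Below is a minimal standalone sketch of the same idea, assuming an `accelerator` object from `accelerate`; the helper name `unwrap_keeping_fp32_wrapper` is hypothetical and not part of this commit.

import inspect

def unwrap_keeping_fp32_wrapper(accelerator, model):
    # Probe the signature: older versions of `accelerate` do not accept
    # `keep_fp32_wrapper`, so only pass it when the parameter exists.
    params = inspect.signature(accelerator.unwrap_model).parameters
    extra_args = {"keep_fp32_wrapper": True} if "keep_fp32_wrapper" in params else {}
    # With the fp32 wrapper kept, the mixed precision hooks stay attached,
    # so the unwrapped model can still be used for further fp16 training.
    return accelerator.unwrap_model(model, **extra_args)

Inspecting the signature at the call site keeps the script compatible with older `accelerate` releases, which is why the commit leaves a TODO to drop the check once a newer version can be required.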
