Commit 247b5fe

[dreambooth] low precision guard (#1916)
* [dreambooth] low precision guard
* fix
* add docs to cli args
* Update examples/dreambooth/train_dreambooth.py
* style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 7101c73 commit 247b5fe

File tree

1 file changed: +20 -2 lines changed


examples/dreambooth/train_dreambooth.py

Lines changed: 20 additions & 2 deletions
@@ -70,7 +70,10 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help="Revision of pretrained model identifier from huggingface.co/models.",
+        help=(
+            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+            " float32 precision."
+        ),
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -140,7 +143,11 @@
     parser.add_argument(
         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
     )
-    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -671,6 +678,17 @@ def main(args):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. copy of the weights should still be float32."
+    )
+
+    if unet.dtype != torch.float32:
+        raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}")
+
+    if args.train_text_encoder and text_encoder.dtype != torch.float32:
+        raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}")
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
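For readers unfamiliar with why the guard is needed: under mixed precision training, autocast handles the low-precision compute, but the weights being optimized (and the optimizer state derived from them) must remain float32, otherwise gradient updates can underflow and training degrades silently. Below is a minimal sketch of how the new check trips when a checkpoint is loaded in half precision; the model id and the standalone setup are illustrative assumptions, not part of the commit.

```python
import torch
from diffusers import UNet2DConditionModel

# Illustrative setup: a user loads the UNet in fp16 to save memory.
# The checkpoint id here is an assumption for the sketch.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# The same guard the commit adds to train_dreambooth.py: trainable weights
# must start out in full float32 precision, even for mixed precision runs.
if unet.dtype != torch.float32:
    raise ValueError(f"Unet loaded as datatype {unet.dtype}.")  # raises: torch.float16
```

Memory savings should instead come from requesting mixed precision via the script's `--mixed_precision` option: as the surrounding context shows, the frozen text encoder is cast to `weight_dtype` only when `--train_text_encoder` is not set, while the weights being trained stay float32.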
