From 7deec938951329a6672874b8be1383b1d8b1477b Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 4 Jan 2023 11:15:57 -0800 Subject: [PATCH 1/5] [dreambooth] low precision guard --- examples/dreambooth/train_dreambooth.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 8b752b45c534..a629d272e550 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -671,6 +671,17 @@ def main(args): if not args.train_text_encoder: text_encoder.to(accelerator.device, dtype=weight_dtype) + low_precision_error_string = ( + "Training on low precision datatypes is not supported. Even When doing mixed precision training, the master" + " copy of the weights should still be float32." + ) + + if unet.dtype != torch.float32 or (args.train_text_encoder and text_encoder.dtype != torch.float32): + raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}") + + if args.train_text_encoder and text_encoder.dtype != torch.float32: + raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}") + # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: From 36125238608156f8d4ef1e396486592aa2c4dbc1 Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 4 Jan 2023 11:21:22 -0800 Subject: [PATCH 2/5] fix --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a629d272e550..8c9a24cf995c 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -676,7 +676,7 @@ def main(args): " copy of the weights should still be float32." 
) - if unet.dtype != torch.float32 or (args.train_text_encoder and text_encoder.dtype != torch.float32): + if unet.dtype != torch.float32: raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}") if args.train_text_encoder and text_encoder.dtype != torch.float32: From 8237e21d5c81b0cafe9c0393a71aa4eb11d2f6eb Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 4 Jan 2023 11:34:24 -0800 Subject: [PATCH 3/5] add docs to cli args --- examples/dreambooth/train_dreambooth.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 8c9a24cf995c..d9d3b35b5ad3 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -70,7 +70,10 @@ def parse_args(input_args=None): type=str, default=None, required=False, - help="Revision of pretrained model identifier from huggingface.co/models.", + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." + ), ) parser.add_argument( "--tokenizer_name", @@ -140,7 +143,11 @@ def parse_args(input_args=None): parser.add_argument( "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" ) - parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 
) From 3faf0f4e60bb730f3ee9454c0413566a1d70359e Mon Sep 17 00:00:00 2001 From: Will Berman Date: Wed, 4 Jan 2023 11:50:25 -0800 Subject: [PATCH 4/5] Update examples/dreambooth/train_dreambooth.py Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d9d3b35b5ad3..659bbb3e5f5e 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -679,7 +679,7 @@ def main(args): text_encoder.to(accelerator.device, dtype=weight_dtype) low_precision_error_string = ( - "Training on low precision datatypes is not supported. Even When doing mixed precision training, the master" + "Please make sure to always have all model weights in full float32 precision when starting training - even if doing mixed precision training." " copy of the weights should still be float32." ) if unet.dtype != torch.float32: From 568ceb7ed249b4983afe4e0ed7793e21d28a1e25 Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 4 Jan 2023 11:52:25 -0800 Subject: [PATCH 5/5] style --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 659bbb3e5f5e..84351fb84d9b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -679,8 +679,8 @@ def main(args): text_encoder.to(accelerator.device, dtype=weight_dtype) low_precision_error_string = ( - "Please make sure to always have all model weights in full float32 precision when starting training - even if doing mixed precision training." - " copy of the weights should still be float32." + "Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training. The master copy of the weights should still be float32." 
) if unet.dtype != torch.float32: