@@ -70,7 +70,10 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=False,
-        help="Revision of pretrained model identifier from huggingface.co/models.",
+        help=(
+            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+            " float32 precision."
+        ),
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -140,7 +143,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
     )
-    parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -671,6 +678,17 @@ def main(args):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)

+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. A copy of the weights should still be float32."
+    )
+
+    if unet.dtype != torch.float32:
+        raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}")
+
+    if args.train_text_encoder and text_encoder.dtype != torch.float32:
+        raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}")
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps: