@@ -247,6 +247,16 @@ def parse_args(input_args=None):
247247 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
248248 ),
249249 )
+    parser.add_argument(
+        "--prior_generation_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp32", "fp16", "bf16"],
+        help=(
+            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+        ),
+    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -436,6 +446,12 @@ def main(args):
 
         if cur_class_images < args.num_class_images:
             torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+            if args.prior_generation_precision == "fp32":
+                torch_dtype = torch.float32
+            elif args.prior_generation_precision == "fp16":
+                torch_dtype = torch.float16
+            elif args.prior_generation_precision == "bf16":
+                torch_dtype = torch.bfloat16
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
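
For context, this is the precedence the second hunk implements: the prior-generation dtype defaults to fp16 on CUDA and fp32 otherwise, and `--prior_generation_precision` overrides that default when set. A minimal standalone sketch (the `resolve_prior_dtype` helper is hypothetical, not part of the script):

```python
import torch


def resolve_prior_dtype(prior_generation_precision, device_type):
    # Hypothetical helper mirroring the diff: pick a device-based default,
    # then let --prior_generation_precision override it.
    torch_dtype = torch.float16 if device_type == "cuda" else torch.float32
    if prior_generation_precision == "fp32":
        torch_dtype = torch.float32
    elif prior_generation_precision == "fp16":
        torch_dtype = torch.float16
    elif prior_generation_precision == "bf16":
        torch_dtype = torch.bfloat16
    return torch_dtype


# Leaving the flag unset (None) keeps the device default; note the "no"
# choice also falls through to the default, since the elif chain only
# handles "fp32", "fp16", and "bf16".
assert resolve_prior_dtype(None, "cuda") is torch.float16
assert resolve_prior_dtype(None, "cpu") is torch.float32
assert resolve_prior_dtype("bf16", "cuda") is torch.bfloat16
```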