Skip to content

Commit cf5265a

Browse files
Allow selecting precision to make Dreambooth class images (#1832)
* allow selecting precision to make DB class images addresses #1831 * add prior_generation_precision argument * correct prior_generation_precision's description Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Suraj Patil <[email protected]>
1 parent 8874027 commit cf5265a

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,16 @@ def parse_args(input_args=None):
247247
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
248248
),
249249
)
250+
parser.add_argument(
251+
"--prior_generation_precision",
252+
type=str,
253+
default=None,
254+
choices=["no", "fp32", "fp16", "bf16"],
255+
help=(
256+
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
257+
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
258+
),
259+
)
250260
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
251261
parser.add_argument(
252262
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -436,6 +446,12 @@ def main(args):
436446

437447
if cur_class_images < args.num_class_images:
438448
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
449+
if args.prior_generation_precision == "fp32":
450+
torch_dtype = torch.float32
451+
elif args.prior_generation_precision == "fp16":
452+
torch_dtype = torch.float16
453+
elif args.prior_generation_precision == "bf16":
454+
torch_dtype = torch.bfloat16
439455
pipeline = DiffusionPipeline.from_pretrained(
440456
args.pretrained_model_name_or_path,
441457
torch_dtype=torch_dtype,

0 commit comments

Comments
 (0)