Skip to content

Commit 1e07b6b

Browse files
authored
[Flax SD finetune] Fix dtype (#1038)
fix jnp dtype
1 parent fb38bb1 commit 1e07b6b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/text_to_image/train_text_to_image_flax.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,11 @@ def collate_fn(examples):
371371
train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True
372372
)
373373

374-
weight_dtype = torch.float32
374+
weight_dtype = jnp.float32
375375
if args.mixed_precision == "fp16":
376-
weight_dtype = torch.float16
376+
weight_dtype = jnp.float16
377377
elif args.mixed_precision == "bf16":
378-
weight_dtype = torch.bfloat16
378+
weight_dtype = jnp.bfloat16
379379

380380
# Load models and create wrapper for stable diffusion
381381
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")

0 commit comments

Comments
 (0)