Skip to content

Commit 33c4874

Browse files
authored
Fix padding in dreambooth (#1030)
1 parent 5cd29d6 commit 33c4874

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,12 @@ def collate_fn(examples):
494494
pixel_values = torch.stack(pixel_values)
495495
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
496496

497-
input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
497+
input_ids = tokenizer.pad(
498+
{"input_ids": input_ids},
499+
padding="max_length",
500+
max_length=tokenizer.model_max_length,
501+
return_tensors="pt",
502+
).input_ids
498503

499504
batch = {
500505
"input_ids": input_ids,

0 commit comments

Comments
 (0)