We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5cd29d6 commit 33c4874Copy full SHA for 33c4874
examples/dreambooth/train_dreambooth.py
@@ -494,7 +494,12 @@ def collate_fn(examples):
494
pixel_values = torch.stack(pixel_values)
495
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
496
497
- input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
+ input_ids = tokenizer.pad(
498
+ {"input_ids": input_ids},
499
+ padding="max_length",
500
+ max_length=tokenizer.model_max_length,
501
+ return_tensors="pt",
502
+ ).input_ids
503
504
batch = {
505
"input_ids": input_ids,
0 commit comments