From 4313003fa318fdc60ea8ae0b4822157b55a962a8 Mon Sep 17 00:00:00 2001
From: Barque-S <111188221+Barque-S@users.noreply.github.com>
Date: Tue, 30 Sep 2025 22:33:05 +0800
Subject: [PATCH 1/3] fix: train_dreambooth_sd3.py line 1605: 1. pass prompts
 to the prompt param instead of None; 2. tokenizers should be a list; 3.
 max_sequence_length should be provided; 4. _encode_prompt_with_t5(...) now
 behaves the same way train_dreambooth_lora_sd3.py does when calling
 encode_prompt(...).

---
 examples/dreambooth/train_dreambooth_sd3.py | 28 +++++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index d345ebb391e3..be595188a72a 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -875,15 +875,20 @@ def _encode_prompt_with_t5(
     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt)
 
-    text_inputs = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=max_sequence_length,
-        truncation=True,
-        add_special_tokens=True,
-        return_tensors="pt",
-    )
-    text_input_ids = text_inputs.input_ids
+    if tokenizer is not None:
+        text_inputs = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+    else:
+        if text_input_ids is None:
+            raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+    
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
     dtype = text_encoder.dtype
@@ -1604,8 +1609,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             else:
                 prompt_embeds, pooled_prompt_embeds = encode_prompt(
                     text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                    tokenizers=None,
-                    prompt=None,
+                    tokenizers=[None, None, None],
+                    prompt=prompts,
+                    max_sequence_length=args.max_sequence_length,
                     text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                 )
             model_pred = transformer(

From 5505cdb5d9dec706823874728647fc9c95e4091a Mon Sep 17 00:00:00 2001
From: Barque-S <111188221+Barque-S@users.noreply.github.com>
Date: Tue, 30 Sep 2025 22:34:25 +0800
Subject: [PATCH 2/3] fix: train_dreambooth_lora_sd3.py line 1751: prompts
 instead of args.instance_prompt should be passed to the prompt param when
 calling encode_prompt(...) here.

---
 examples/dreambooth/train_dreambooth_lora_sd3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index eef732c531d3..7ffcc1711e1d 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -1750,8 +1750,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             if args.train_text_encoder:
                 prompt_embeds, pooled_prompt_embeds = encode_prompt(
                     text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
-                    tokenizers=[None, None, tokenizer_three],
-                    prompt=args.instance_prompt,
+                    tokenizers=[None, None, None],
+                    prompt=prompts,
                     max_sequence_length=args.max_sequence_length,
                     text_input_ids_list=[tokens_one, tokens_two, tokens_three],
                 )

From f0d9ee446677acc58fde77b45adc96d30ad177c2 Mon Sep 17 00:00:00 2001
From: Barque-S <111188221+Barque-S@users.noreply.github.com>
Date: Sat, 11 Oct 2025 01:02:24 +0800
Subject: [PATCH 3/3] fix code quality check: whitespace removed from blank
 line

---
 examples/dreambooth/train_dreambooth_sd3.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index be595188a72a..b30142bd0361 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -888,7 +888,7 @@ def _encode_prompt_with_t5(
         else:
             if text_input_ids is None:
                 raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
-    
+
     prompt_embeds = text_encoder(text_input_ids.to(device))[0]
 
     dtype = text_encoder.dtype