From 5ed984cc47b44d0e6354411a90fc082acb65bbf3 Mon Sep 17 00:00:00 2001
From: DN6
Date: Fri, 3 Oct 2025 14:42:58 +0530
Subject: [PATCH 1/3] update

---
 src/diffusers/pipelines/wan/pipeline_wan_vace.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 2b1890afec97..48e9d7aa1274 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -336,7 +336,15 @@ def check_inputs(
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")

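The guard above only changes where `base` comes from; the divisibility rule itself is untouched. A standalone sketch of that rule, with assumed values (`vae_scale_factor_spatial = 8` and `patch_size = (1, 2, 2)` are typical for Wan models, but the pipeline reads them from the loaded VAE and transformer rather than hard-coding them):

    # Illustration only: constants below are assumptions, not values read from a checkpoint.
    vae_scale_factor_spatial = 8
    patch_size = (1, 2, 2)  # (temporal, height, width) patching of the transformer

    base = vae_scale_factor_spatial * patch_size[1]  # 16

    height, width = 480, 832
    if height % base != 0 or width % base != 0:
        raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
    # 480 and 832 are both multiples of 16, so this check passes.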
From 5015ce4fc78e13b2a83b83e9dbcc1bca49469bf0 Mon Sep 17 00:00:00 2001
From: DN6
Date: Fri, 3 Oct 2025 16:44:23 +0530
Subject: [PATCH 2/3] update

---
 .../pipelines/wan/pipeline_wan_vace.py | 37 +++++++++++-----
 tests/pipelines/wan/test_wan_vace.py   | 44 +++++++++++++++++--
 2 files changed, 68 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 48e9d7aa1274..3dd80cb1c6a0 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -422,7 +422,11 @@ def preprocess_conditions(
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])

             if video_height * video_width > height * width:
@@ -597,7 +601,11 @@ def prepare_masks(
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )

-        transformer_patch_size = self.transformer.config.patch_size[1]
+        transformer_patch_size = (
+            self.transformer.config.patch_size[1]
+            if self.transformer is not None
+            else self.transformer_2.config.patch_size[1]
+        )

         mask_list = []
         for mask_, reference_images_batch in zip(mask, reference_images):
@@ -852,20 +860,25 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype

+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )
         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)
         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)
         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)

@@ -908,7 +921,11 @@ def __call__(
             conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
             conditioning_latents = conditioning_latents.to(transformer_dtype)

-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -976,7 +993,7 @@ def __call__(
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py
index f99863c88092..4a7226af0f78 100644
--- a/tests/pipelines/wan/test_wan_vace.py
+++ b/tests/pipelines/wan/test_wan_vace.py
@@ -19,9 +19,15 @@
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel

-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
-
-from ...testing_utils import enable_full_determinism
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin

@@ -212,3 +218,35 @@ def test_float16_inference(self):
     )
     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support running the low-noise-only path
+        # because its starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+
+        components["boundary_ratio"] = 1.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)

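Outside the test suite, the low-noise-only configuration that `test_inference_with_only_transformer_2` exercises would be assembled roughly as in the sketch below; this is not part of the PR. The repo id `my-org/wan-vace-checkpoint` is a placeholder, the per-component `from_pretrained` calls are just one way to obtain the sub-models, and `boundary_ratio=1.0` keeps every timestep below the boundary so only `transformer_2` ever runs.

    import torch
    from transformers import AutoTokenizer, UMT5EncoderModel

    from diffusers import (
        AutoencoderKLWan,
        UniPCMultistepScheduler,
        WanVACEPipeline,
        WanVACETransformer3DModel,
    )

    repo = "my-org/wan-vace-checkpoint"  # placeholder repo id, not a real checkpoint

    pipe = WanVACEPipeline(
        tokenizer=AutoTokenizer.from_pretrained(repo, subfolder="tokenizer"),
        text_encoder=UMT5EncoderModel.from_pretrained(repo, subfolder="text_encoder", torch_dtype=torch.bfloat16),
        vae=AutoencoderKLWan.from_pretrained(repo, subfolder="vae", torch_dtype=torch.float32),
        # FlowMatchEulerDiscreteScheduler cannot start a low-noise-only run
        # (t == 1000 == boundary_timestep), hence the UniPC flow-matching scheduler, as in the new test.
        scheduler=UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0),
        transformer=None,  # high-noise transformer deliberately omitted
        transformer_2=WanVACETransformer3DModel.from_pretrained(repo, subfolder="transformer_2", torch_dtype=torch.bfloat16),
        boundary_ratio=1.0,  # every timestep falls below the boundary, so only transformer_2 is used
    )
    pipe.to("cuda")

Inference then proceeds exactly as with the full two-transformer pipeline.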
From 99308efb551316502152c08fdb631daa16a3770a Mon Sep 17 00:00:00 2001
From: DN6
Date: Fri, 3 Oct 2025 16:48:43 +0530
Subject: [PATCH 3/3] update

---
 .../pipelines/wan/pipeline_wan_vace.py | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index 3dd80cb1c6a0..e0c9cc2575fa 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
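For completeness, the existing single-transformer workflow is unchanged by this series. A rough end-to-end sketch, not taken from the PR: it assumes `Wan-AI/Wan2.1-VACE-1.3B-diffusers` as the checkpoint id (believed to be the published single-transformer Wan 2.1 VACE release; substitute whichever VACE checkpoint you actually use) and feeds a blank control video with an all-white mask so the model generates every frame from the prompt alone.

    import torch
    from PIL import Image

    from diffusers import WanVACEPipeline

    # Assumed checkpoint id; such a repo ships no transformer_2 and no boundary_ratio,
    # so `transformer` alone handles every denoising step, exactly as before this PR.
    pipe = WanVACEPipeline.from_pretrained("Wan-AI/Wan2.1-VACE-1.3B-diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    height, width, num_frames = 480, 832, 49
    # Blank control video plus an all-white mask -> generate every pixel of every frame.
    video = [Image.new("RGB", (width, height), (128, 128, 128))] * num_frames
    mask = [Image.new("L", (width, height), 255)] * num_frames

    frames = pipe(
        prompt="a red panda rafting down a river",
        video=video,
        mask=mask,
        height=height,
        width=width,
        num_frames=num_frames,
    ).frames[0]

Height and width stay multiples of 16 and `num_frames` satisfies `(num_frames - 1) % 4 == 0`, matching the checks that Patch 1 generalizes.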