From c0015fcfa79be3ab60a15f988853d29d1c9cb0be Mon Sep 17 00:00:00 2001
From: Zijian Zhou
Date: Tue, 16 Sep 2025 10:49:02 +0100
Subject: [PATCH 1/2] Update autoencoder_kl_wan.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When using the Wan2.2 VAE, the spatial compression ratio calculated here
is incorrect. It should be 16 instead of 8. Pass it in directly via the
config to ensure it’s correct here.
---
 src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index d84a0861e984..c542b7af0a24 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -1052,7 +1052,7 @@ def __init__(
             is_residual=is_residual,
         )
 
-        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+        self.spatial_compression_ratio = scale_factor_spatial
 
         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
         # to perform decoding of a single video latent at a time.

From e0ea4431120f5538ecb1dcd2e1e7722a94d1bca1 Mon Sep 17 00:00:00 2001
From: Zijian Zhou
Date: Tue, 16 Sep 2025 11:20:16 +0100
Subject: [PATCH 2/2] Update autoencoder_kl_wan.py

---
 src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index c542b7af0a24..e6e58c1cce85 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -1145,12 +1145,13 @@ def clear_cache(self):
     def _encode(self, x: torch.Tensor):
         _, _, num_frame, height, width = x.shape
 
-        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
-            return self.tiled_encode(x)
-
         self.clear_cache()
         if self.config.patch_size is not None:
             x = patchify(x, patch_size=self.config.patch_size)
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
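
A quick way to check the effect of the first patch is to load the VAE and compare the attribute against the config. This is a minimal sketch, not part of the patch; the repo id "Wan-AI/Wan2.2-TI2V-5B-Diffusers" and the "vae" subfolder are assumptions, so substitute whichever Wan2.2 checkpoint you actually use:

```python
import torch
from diffusers import AutoencoderKLWan

# Repo id and subfolder are assumptions; point these at your Wan2.2 VAE.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)

# Before the patch: the ratio was derived as 2 ** len(temperal_downsample),
# which yields 8 for the Wan2.2 config. After the patch it is read straight
# from the config, which stores the correct value of 16.
print(vae.config.scale_factor_spatial)  # expected: 16 for Wan2.2
print(vae.spatial_compression_ratio)    # now equal to the config value
```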
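
The second patch reorders `_encode` so that `patchify` runs before the tiling check, which means `tiled_encode` now receives the patchified tensor. The sketch below is a rough stand-in for `patchify`, not the library's implementation (the exact channel ordering may differ); it only illustrates why the ordering matters for Wan2.2, where `patch_size=2` halves the spatial dimensions before the encoder applies its own 8x downsampling (2 * 8 = 16 overall):

```python
import torch

def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Pixel-unshuffle along H and W:
    # (B, C, T, H, W) -> (B, C * p * p, T, H // p, W // p)
    b, c, t, h, w = x.shape
    x = x.reshape(b, c, t, h // patch_size, patch_size, w // patch_size, patch_size)
    x = x.permute(0, 1, 4, 6, 2, 3, 5)
    return x.reshape(b, c * patch_size**2, t, h // patch_size, w // patch_size)

x = torch.randn(1, 3, 1, 64, 64)
print(patchify(x, 2).shape)  # torch.Size([1, 12, 1, 32, 32])
```

With the new ordering, the spatial extent seen by `tiled_encode` is already halved, consistent with the overall compression ratio of 16 fixed in the first patch.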