diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py index 4f57f3349bc4..4c4c528a59ad 100644 --- a/src/diffusers/models/unets/unet_1d.py +++ b/src/diffusers/models/unets/unet_1d.py @@ -82,6 +82,7 @@ def __init__( out_channels: int = 2, extra_in_channels: int = 0, time_embedding_type: str = "fourier", + time_embedding_dim: Optional[int] = None, flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, freq_shift: float = 0.0, @@ -100,15 +101,23 @@ def __init__( # time if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( - embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) - timestep_input_dim = 2 * block_out_channels[0] + timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps( block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift ) timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) if use_timestep_embedding: time_embed_dim = block_out_channels[0] * 4