Skip to content

Commit 55e1790

Browse files
authored
Add dropout parameter to UNet2DModel/UNet2DConditionModel (huggingface#4882)
* Add dropout param to get_down_block/get_up_block and UNet2DModel/UNet2DConditionModel.
* Add dropout param to Versatile Diffusion modeling, which has a copy of UNet2DConditionModel and its own get_down_block/get_up_block functions.
1 parent c81a88b commit 55e1790

File tree

4 files changed

+47
-0
lines changed

4 files changed

+47
-0
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
7070
The downsample type for downsampling layers. Choose between "conv" and "resnet"
7171
upsample_type (`str`, *optional*, defaults to `conv`):
7272
The upsample type for upsampling layers. Choose between "conv" and "resnet"
73+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
7374
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
7475
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
7576
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
@@ -102,6 +103,7 @@ def __init__(
102103
downsample_padding: int = 1,
103104
downsample_type: str = "conv",
104105
upsample_type: str = "conv",
106+
dropout: float = 0.0,
105107
act_fn: str = "silu",
106108
attention_head_dim: Optional[int] = 8,
107109
norm_num_groups: int = 32,
@@ -175,13 +177,15 @@ def __init__(
175177
downsample_padding=downsample_padding,
176178
resnet_time_scale_shift=resnet_time_scale_shift,
177179
downsample_type=downsample_type,
180+
dropout=dropout,
178181
)
179182
self.down_blocks.append(down_block)
180183

181184
# mid
182185
self.mid_block = UNetMidBlock2D(
183186
in_channels=block_out_channels[-1],
184187
temb_channels=time_embed_dim,
188+
dropout=dropout,
185189
resnet_eps=norm_eps,
186190
resnet_act_fn=act_fn,
187191
output_scale_factor=mid_block_scale_factor,
@@ -215,6 +219,7 @@ def __init__(
215219
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
216220
resnet_time_scale_shift=resnet_time_scale_shift,
217221
upsample_type=upsample_type,
222+
dropout=dropout,
218223
)
219224
self.up_blocks.append(up_block)
220225
prev_output_channel = output_channel

src/diffusers/models/unet_2d_blocks.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def get_down_block(
5555
cross_attention_norm=None,
5656
attention_head_dim=None,
5757
downsample_type=None,
58+
dropout=0.0,
5859
):
5960
# If attn head dim is not defined, we default it to the number of heads
6061
if attention_head_dim is None:
@@ -70,6 +71,7 @@ def get_down_block(
7071
in_channels=in_channels,
7172
out_channels=out_channels,
7273
temb_channels=temb_channels,
74+
dropout=dropout,
7375
add_downsample=add_downsample,
7476
resnet_eps=resnet_eps,
7577
resnet_act_fn=resnet_act_fn,
@@ -83,6 +85,7 @@ def get_down_block(
8385
in_channels=in_channels,
8486
out_channels=out_channels,
8587
temb_channels=temb_channels,
88+
dropout=dropout,
8689
add_downsample=add_downsample,
8790
resnet_eps=resnet_eps,
8891
resnet_act_fn=resnet_act_fn,
@@ -101,6 +104,7 @@ def get_down_block(
101104
in_channels=in_channels,
102105
out_channels=out_channels,
103106
temb_channels=temb_channels,
107+
dropout=dropout,
104108
resnet_eps=resnet_eps,
105109
resnet_act_fn=resnet_act_fn,
106110
resnet_groups=resnet_groups,
@@ -118,6 +122,7 @@ def get_down_block(
118122
in_channels=in_channels,
119123
out_channels=out_channels,
120124
temb_channels=temb_channels,
125+
dropout=dropout,
121126
add_downsample=add_downsample,
122127
resnet_eps=resnet_eps,
123128
resnet_act_fn=resnet_act_fn,
@@ -140,6 +145,7 @@ def get_down_block(
140145
in_channels=in_channels,
141146
out_channels=out_channels,
142147
temb_channels=temb_channels,
148+
dropout=dropout,
143149
add_downsample=add_downsample,
144150
resnet_eps=resnet_eps,
145151
resnet_act_fn=resnet_act_fn,
@@ -158,6 +164,7 @@ def get_down_block(
158164
in_channels=in_channels,
159165
out_channels=out_channels,
160166
temb_channels=temb_channels,
167+
dropout=dropout,
161168
add_downsample=add_downsample,
162169
resnet_eps=resnet_eps,
163170
resnet_act_fn=resnet_act_fn,
@@ -170,6 +177,7 @@ def get_down_block(
170177
in_channels=in_channels,
171178
out_channels=out_channels,
172179
temb_channels=temb_channels,
180+
dropout=dropout,
173181
add_downsample=add_downsample,
174182
resnet_eps=resnet_eps,
175183
resnet_act_fn=resnet_act_fn,
@@ -181,6 +189,7 @@ def get_down_block(
181189
num_layers=num_layers,
182190
in_channels=in_channels,
183191
out_channels=out_channels,
192+
dropout=dropout,
184193
add_downsample=add_downsample,
185194
resnet_eps=resnet_eps,
186195
resnet_act_fn=resnet_act_fn,
@@ -193,6 +202,7 @@ def get_down_block(
193202
num_layers=num_layers,
194203
in_channels=in_channels,
195204
out_channels=out_channels,
205+
dropout=dropout,
196206
add_downsample=add_downsample,
197207
resnet_eps=resnet_eps,
198208
resnet_act_fn=resnet_act_fn,
@@ -207,6 +217,7 @@ def get_down_block(
207217
in_channels=in_channels,
208218
out_channels=out_channels,
209219
temb_channels=temb_channels,
220+
dropout=dropout,
210221
add_downsample=add_downsample,
211222
resnet_eps=resnet_eps,
212223
resnet_act_fn=resnet_act_fn,
@@ -217,6 +228,7 @@ def get_down_block(
217228
in_channels=in_channels,
218229
out_channels=out_channels,
219230
temb_channels=temb_channels,
231+
dropout=dropout,
220232
add_downsample=add_downsample,
221233
resnet_eps=resnet_eps,
222234
resnet_act_fn=resnet_act_fn,
@@ -252,6 +264,7 @@ def get_up_block(
252264
cross_attention_norm=None,
253265
attention_head_dim=None,
254266
upsample_type=None,
267+
dropout=0.0,
255268
):
256269
# If attn head dim is not defined, we default it to the number of heads
257270
if attention_head_dim is None:
@@ -268,6 +281,7 @@ def get_up_block(
268281
out_channels=out_channels,
269282
prev_output_channel=prev_output_channel,
270283
temb_channels=temb_channels,
284+
dropout=dropout,
271285
add_upsample=add_upsample,
272286
resnet_eps=resnet_eps,
273287
resnet_act_fn=resnet_act_fn,
@@ -281,6 +295,7 @@ def get_up_block(
281295
out_channels=out_channels,
282296
prev_output_channel=prev_output_channel,
283297
temb_channels=temb_channels,
298+
dropout=dropout,
284299
add_upsample=add_upsample,
285300
resnet_eps=resnet_eps,
286301
resnet_act_fn=resnet_act_fn,
@@ -299,6 +314,7 @@ def get_up_block(
299314
out_channels=out_channels,
300315
prev_output_channel=prev_output_channel,
301316
temb_channels=temb_channels,
317+
dropout=dropout,
302318
add_upsample=add_upsample,
303319
resnet_eps=resnet_eps,
304320
resnet_act_fn=resnet_act_fn,
@@ -321,6 +337,7 @@ def get_up_block(
321337
out_channels=out_channels,
322338
prev_output_channel=prev_output_channel,
323339
temb_channels=temb_channels,
340+
dropout=dropout,
324341
add_upsample=add_upsample,
325342
resnet_eps=resnet_eps,
326343
resnet_act_fn=resnet_act_fn,
@@ -345,6 +362,7 @@ def get_up_block(
345362
out_channels=out_channels,
346363
prev_output_channel=prev_output_channel,
347364
temb_channels=temb_channels,
365+
dropout=dropout,
348366
resnet_eps=resnet_eps,
349367
resnet_act_fn=resnet_act_fn,
350368
resnet_groups=resnet_groups,
@@ -359,6 +377,7 @@ def get_up_block(
359377
out_channels=out_channels,
360378
prev_output_channel=prev_output_channel,
361379
temb_channels=temb_channels,
380+
dropout=dropout,
362381
add_upsample=add_upsample,
363382
resnet_eps=resnet_eps,
364383
resnet_act_fn=resnet_act_fn,
@@ -371,6 +390,7 @@ def get_up_block(
371390
out_channels=out_channels,
372391
prev_output_channel=prev_output_channel,
373392
temb_channels=temb_channels,
393+
dropout=dropout,
374394
add_upsample=add_upsample,
375395
resnet_eps=resnet_eps,
376396
resnet_act_fn=resnet_act_fn,
@@ -382,6 +402,7 @@ def get_up_block(
382402
num_layers=num_layers,
383403
in_channels=in_channels,
384404
out_channels=out_channels,
405+
dropout=dropout,
385406
add_upsample=add_upsample,
386407
resnet_eps=resnet_eps,
387408
resnet_act_fn=resnet_act_fn,
@@ -394,6 +415,7 @@ def get_up_block(
394415
num_layers=num_layers,
395416
in_channels=in_channels,
396417
out_channels=out_channels,
418+
dropout=dropout,
397419
add_upsample=add_upsample,
398420
resnet_eps=resnet_eps,
399421
resnet_act_fn=resnet_act_fn,
@@ -408,6 +430,7 @@ def get_up_block(
408430
in_channels=in_channels,
409431
out_channels=out_channels,
410432
temb_channels=temb_channels,
433+
dropout=dropout,
411434
add_upsample=add_upsample,
412435
resnet_eps=resnet_eps,
413436
resnet_act_fn=resnet_act_fn,
@@ -418,6 +441,7 @@ def get_up_block(
418441
in_channels=in_channels,
419442
out_channels=out_channels,
420443
temb_channels=temb_channels,
444+
dropout=dropout,
421445
add_upsample=add_upsample,
422446
resnet_eps=resnet_eps,
423447
resnet_act_fn=resnet_act_fn,

src/diffusers/models/unet_2d_condition.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
9898
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
9999
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100100
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
101102
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
102103
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
103104
If `None`, normalization and activation layers is skipped in post-processing.
@@ -178,6 +179,7 @@ def __init__(
178179
layers_per_block: Union[int, Tuple[int]] = 2,
179180
downsample_padding: int = 1,
180181
mid_block_scale_factor: float = 1,
182+
dropout: float = 0.0,
181183
act_fn: str = "silu",
182184
norm_num_groups: Optional[int] = 32,
183185
norm_eps: float = 1e-5,
@@ -459,6 +461,7 @@ def __init__(
459461
resnet_out_scale_factor=resnet_out_scale_factor,
460462
cross_attention_norm=cross_attention_norm,
461463
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
464+
dropout=dropout,
462465
)
463466
self.down_blocks.append(down_block)
464467

@@ -468,6 +471,7 @@ def __init__(
468471
transformer_layers_per_block=transformer_layers_per_block[-1],
469472
in_channels=block_out_channels[-1],
470473
temb_channels=blocks_time_embed_dim,
474+
dropout=dropout,
471475
resnet_eps=norm_eps,
472476
resnet_act_fn=act_fn,
473477
output_scale_factor=mid_block_scale_factor,
@@ -484,6 +488,7 @@ def __init__(
484488
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
485489
in_channels=block_out_channels[-1],
486490
temb_channels=blocks_time_embed_dim,
491+
dropout=dropout,
487492
resnet_eps=norm_eps,
488493
resnet_act_fn=act_fn,
489494
output_scale_factor=mid_block_scale_factor,
@@ -550,6 +555,7 @@ def __init__(
550555
resnet_out_scale_factor=resnet_out_scale_factor,
551556
cross_attention_norm=cross_attention_norm,
552557
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
558+
dropout=dropout,
553559
)
554560
self.up_blocks.append(up_block)
555561
prev_output_channel = output_channel

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def get_down_block(
5858
resnet_skip_time_act=False,
5959
resnet_out_scale_factor=1.0,
6060
cross_attention_norm=None,
61+
dropout=0.0,
6162
):
6263
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
6364
if down_block_type == "DownBlockFlat":
@@ -66,6 +67,7 @@ def get_down_block(
6667
in_channels=in_channels,
6768
out_channels=out_channels,
6869
temb_channels=temb_channels,
70+
dropout=dropout,
6971
add_downsample=add_downsample,
7072
resnet_eps=resnet_eps,
7173
resnet_act_fn=resnet_act_fn,
@@ -81,6 +83,7 @@ def get_down_block(
8183
in_channels=in_channels,
8284
out_channels=out_channels,
8385
temb_channels=temb_channels,
86+
dropout=dropout,
8487
add_downsample=add_downsample,
8588
resnet_eps=resnet_eps,
8689
resnet_act_fn=resnet_act_fn,
@@ -117,6 +120,7 @@ def get_up_block(
117120
resnet_skip_time_act=False,
118121
resnet_out_scale_factor=1.0,
119122
cross_attention_norm=None,
123+
dropout=0.0,
120124
):
121125
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
122126
if up_block_type == "UpBlockFlat":
@@ -126,6 +130,7 @@ def get_up_block(
126130
out_channels=out_channels,
127131
prev_output_channel=prev_output_channel,
128132
temb_channels=temb_channels,
133+
dropout=dropout,
129134
add_upsample=add_upsample,
130135
resnet_eps=resnet_eps,
131136
resnet_act_fn=resnet_act_fn,
@@ -141,6 +146,7 @@ def get_up_block(
141146
out_channels=out_channels,
142147
prev_output_channel=prev_output_channel,
143148
temb_channels=temb_channels,
149+
dropout=dropout,
144150
add_upsample=add_upsample,
145151
resnet_eps=resnet_eps,
146152
resnet_act_fn=resnet_act_fn,
@@ -284,6 +290,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
284290
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
285291
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
286292
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
293+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
287294
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
288295
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
289296
If `None`, normalization and activation layers is skipped in post-processing.
@@ -369,6 +376,7 @@ def __init__(
369376
layers_per_block: Union[int, Tuple[int]] = 2,
370377
downsample_padding: int = 1,
371378
mid_block_scale_factor: float = 1,
379+
dropout: float = 0.0,
372380
act_fn: str = "silu",
373381
norm_num_groups: Optional[int] = 32,
374382
norm_eps: float = 1e-5,
@@ -660,6 +668,7 @@ def __init__(
660668
resnet_out_scale_factor=resnet_out_scale_factor,
661669
cross_attention_norm=cross_attention_norm,
662670
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
671+
dropout=dropout,
663672
)
664673
self.down_blocks.append(down_block)
665674

@@ -669,6 +678,7 @@ def __init__(
669678
transformer_layers_per_block=transformer_layers_per_block[-1],
670679
in_channels=block_out_channels[-1],
671680
temb_channels=blocks_time_embed_dim,
681+
dropout=dropout,
672682
resnet_eps=norm_eps,
673683
resnet_act_fn=act_fn,
674684
output_scale_factor=mid_block_scale_factor,
@@ -685,6 +695,7 @@ def __init__(
685695
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
686696
in_channels=block_out_channels[-1],
687697
temb_channels=blocks_time_embed_dim,
698+
dropout=dropout,
688699
resnet_eps=norm_eps,
689700
resnet_act_fn=act_fn,
690701
output_scale_factor=mid_block_scale_factor,
@@ -751,6 +762,7 @@ def __init__(
751762
resnet_out_scale_factor=resnet_out_scale_factor,
752763
cross_attention_norm=cross_attention_norm,
753764
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
765+
dropout=dropout,
754766
)
755767
self.up_blocks.append(up_block)
756768
prev_output_channel = output_channel

0 commit comments

Comments (0)