1- from typing import Dict , Union
1+ from typing import Dict , Optional , Tuple , Union
22
33import torch
44import torch .nn as nn
@@ -13,23 +13,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
1313 @register_to_config
1414 def __init__ (
1515 self ,
16- sample_size = None ,
17- in_channels = 4 ,
18- out_channels = 4 ,
19- center_input_sample = False ,
20- flip_sin_to_cos = True ,
21- freq_shift = 0 ,
22- down_block_types = ("CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "CrossAttnDownBlock2D" , "DownBlock2D" ),
23- up_block_types = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
24- block_out_channels = (320 , 640 , 1280 , 1280 ),
25- layers_per_block = 2 ,
26- downsample_padding = 1 ,
27- mid_block_scale_factor = 1 ,
28- act_fn = "silu" ,
29- norm_num_groups = 32 ,
30- norm_eps = 1e-5 ,
31- cross_attention_dim = 1280 ,
32- attention_head_dim = 8 ,
16+ sample_size : Optional [int ] = None ,
17+ in_channels : int = 4 ,
18+ out_channels : int = 4 ,
19+ center_input_sample : bool = False ,
20+ flip_sin_to_cos : bool = True ,
21+ freq_shift : int = 0 ,
22+ down_block_types : Tuple [str ] = (
23+ "CrossAttnDownBlock2D" ,
24+ "CrossAttnDownBlock2D" ,
25+ "CrossAttnDownBlock2D" ,
26+ "DownBlock2D" ,
27+ ),
28+ up_block_types : Tuple [str ] = ("UpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" , "CrossAttnUpBlock2D" ),
29+ block_out_channels : Tuple [int ] = (320 , 640 , 1280 , 1280 ),
30+ layers_per_block : int = 2 ,
31+ downsample_padding : int = 1 ,
32+ mid_block_scale_factor : float = 1 ,
33+ act_fn : str = "silu" ,
34+ norm_num_groups : int = 32 ,
35+ norm_eps : float = 1e-5 ,
36+ cross_attention_dim : int = 1280 ,
37+ attention_head_dim : int = 8 ,
3338 ):
3439 super ().__init__ ()
3540
0 commit comments