@@ -104,17 +104,20 @@ class FlaxBasicTransformerBlock(nn.Module):
             Hidden states dimension inside each head
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        only_cross_attention (`bool`, defaults to `False`):
+            Whether to only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
     dim: int
     n_heads: int
     d_head: int
     dropout: float = 0.0
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        # self attention
+        # self attention (or cross_attention if only_cross_attention is True)
         self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
         # cross attention
         self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
@@ -126,7 +129,10 @@ def setup(self):
     def __call__(self, hidden_states, context, deterministic=True):
         # self attention
         residual = hidden_states
-        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
+        if self.only_cross_attention:
+            hidden_states = self.attn1(self.norm1(hidden_states), context, deterministic=deterministic)
+        else:
+            hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
         hidden_states = hidden_states + residual
 
         # cross attention
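Not part of the diff: a minimal, self-contained sketch of the routing that `only_cross_attention` toggles in `__call__` above. When the flag is set, `attn1` receives the encoder `context`; otherwise it attends to the hidden states themselves. `TinyAttention` is a hypothetical stand-in for `FlaxAttentionBlock`, assuming only that the real block accepts an optional `context` as its second positional argument.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyAttention(nn.Module):
    # Hypothetical stand-in for FlaxAttentionBlock: single-head attention
    # that falls back to self attention when no context is given.
    dim: int

    @nn.compact
    def __call__(self, hidden_states, context=None, deterministic=True):
        kv = hidden_states if context is None else context
        q = nn.Dense(self.dim)(hidden_states)
        k = nn.Dense(self.dim)(kv)
        v = nn.Dense(self.dim)(kv)
        weights = jax.nn.softmax(q @ k.swapaxes(-1, -2) / jnp.sqrt(self.dim), axis=-1)
        return weights @ v

hidden_states = jnp.ones((1, 16, 32))   # (batch, image tokens, inner_dim)
context = jnp.ones((1, 77, 32))         # (batch, text tokens, inner_dim)
attn1 = TinyAttention(dim=32)
params = attn1.init(jax.random.PRNGKey(0), hidden_states, context)

# only_cross_attention=True: attn1 attends to the text context.
out_cross = attn1.apply(params, hidden_states, context)
# only_cross_attention=False: attn1 attends to the image tokens themselves.
out_self = attn1.apply(params, hidden_states)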
@@ -159,6 +165,8 @@ class FlaxTransformer2DModel(nn.Module):
             Number of transformer blocks
         dropout (:obj:`float`, *optional*, defaults to 0.0):
             Dropout rate
+        use_linear_projection (`bool`, defaults to `False`): Whether to use an `nn.Dense` projection instead of a 1x1 `nn.Conv` for `proj_in`/`proj_out`.
+        only_cross_attention (`bool`, defaults to `False`): Whether the transformer blocks should only apply cross attention.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             Parameters `dtype`
     """
@@ -167,49 +175,70 @@ class FlaxTransformer2DModel(nn.Module):
     d_head: int
     depth: int = 1
     dropout: float = 0.0
+    use_linear_projection: bool = False
+    only_cross_attention: bool = False
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
         self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
 
         inner_dim = self.n_heads * self.d_head
-        self.proj_in = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_in = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
         self.transformer_blocks = [
-            FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
+            FlaxBasicTransformerBlock(
+                inner_dim,
+                self.n_heads,
+                self.d_head,
+                dropout=self.dropout,
+                only_cross_attention=self.only_cross_attention,
+                dtype=self.dtype,
+            )
             for _ in range(self.depth)
         ]
 
-        self.proj_out = nn.Conv(
-            inner_dim,
-            kernel_size=(1, 1),
-            strides=(1, 1),
-            padding="VALID",
-            dtype=self.dtype,
-        )
+        if self.use_linear_projection:
+            self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)
+        else:
+            self.proj_out = nn.Conv(
+                inner_dim,
+                kernel_size=(1, 1),
+                strides=(1, 1),
+                padding="VALID",
+                dtype=self.dtype,
+            )
 
     def __call__(self, hidden_states, context, deterministic=True):
         batch, height, width, channels = hidden_states.shape
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
-        hidden_states = self.proj_in(hidden_states)
-
-        hidden_states = hidden_states.reshape(batch, height * width, channels)
+        if self.use_linear_projection:
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
+            hidden_states = self.proj_in(hidden_states)
+        else:
+            hidden_states = self.proj_in(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height * width, channels)
 
         for transformer_block in self.transformer_blocks:
             hidden_states = transformer_block(hidden_states, context, deterministic=deterministic)
 
-        hidden_states = hidden_states.reshape(batch, height, width, channels)
+        if self.use_linear_projection:
+            hidden_states = self.proj_out(hidden_states)
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+        else:
+            hidden_states = hidden_states.reshape(batch, height, width, channels)
+            hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = self.proj_out(hidden_states)
         hidden_states = hidden_states + residual
-
         return hidden_states
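Not part of the diff: a minimal standalone Flax sketch (arbitrary shapes, hypothetical variable names) of why the reshape moves when `use_linear_projection` is enabled. The `nn.Dense` projection is applied to the flattened `(batch, height * width, channels)` token sequence, while the 1x1 `nn.Conv` is applied to the NHWC feature map before flattening; either way the transformer blocks receive tokens of the same shape.

import jax
import jax.numpy as jnp
import flax.linen as nn

batch, height, width, channels = 1, 8, 8, 64
hidden_states = jnp.ones((batch, height, width, channels))

# use_linear_projection=True: flatten the spatial dims first, then project with Dense.
proj_in_dense = nn.Dense(channels)
tokens = hidden_states.reshape(batch, height * width, channels)
params = proj_in_dense.init(jax.random.PRNGKey(0), tokens)
print(proj_in_dense.apply(params, tokens).shape)  # (1, 64, 64)

# use_linear_projection=False: project with a 1x1 Conv on the NHWC map, then flatten.
proj_in_conv = nn.Conv(channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID")
params = proj_in_conv.init(jax.random.PRNGKey(0), hidden_states)
projected = proj_in_conv.apply(params, hidden_states)
print(projected.reshape(batch, height * width, channels).shape)  # (1, 64, 64)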