 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention
+    Uses three q, k, v linear layers to compute attention.
+
+    Parameters:
+        channels (:obj:`int`): The number of channels in the input and output.
+        num_head_channels (:obj:`int`, *optional*):
+            The number of channels in each head. If None, then `num_heads` = 1.
+        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """
 
     def __init__(
         self,
-        channels,
-        num_head_channels=None,
-        num_groups=32,
-        rescale_output_factor=1.0,
-        eps=1e-5,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
     ):
         super().__init__()
        self.channels = channels
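A minimal usage sketch of the annotated block (sizes are hypothetical, and the forward pass relies on the projection layers defined further down in __init__):

    attn = AttentionBlock(channels=64, num_head_channels=32)
    feature_map = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
    out = attn(feature_map)  # spatial self-attention; same shape as the input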
@@ -86,10 +95,26 @@ def forward(self, hidden_states):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
+    standard transformer action. Finally, reshape to image.
+
+    Parameters:
+        in_channels (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
     """
 
-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+    def __init__(
+        self,
+        in_channels: int,
+        n_heads: int,
+        d_head: int,
+        depth: int = 1,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
@@ -112,22 +137,44 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual
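A rough usage sketch of the reshaping above (sizes are hypothetical, with n_heads * d_head equal to in_channels so the inner projections keep the channel count):

    transformer = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=32)
    feature_map = torch.randn(1, 64, 8, 8)  # (batch, channel, height, width)
    text_context = torch.randn(1, 10, 32)  # optional; omit for pure self-attention
    out = transformer(feature_map, context=text_context)  # same shape as feature_map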
 
 
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
166+ """
167+
168+ def __init__ (
169+ self ,
170+ dim : int ,
171+ n_heads : int ,
172+ d_head : int ,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+        gated_ff: bool = True,
+        checkpoint: bool = True,
+    ):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -145,15 +192,30 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states
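A minimal sketch of how the block is driven (hypothetical sizes; the cross-attention, feed-forward, and norm layers it composes with pre-norm residuals are set up in __init__ above):

    block = BasicTransformerBlock(dim=64, n_heads=4, d_head=16, context_dim=32)
    tokens = torch.randn(2, 77, 64)  # (batch, sequence, dim)
    text = torch.randn(2, 10, 32)  # optional cross-attention context
    out = block(tokens, context=text)  # same shape as tokens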
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (:obj:`int`): The number of channels in the query.
+        context_dim (:obj:`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
@@ -174,77 +236,104 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
+        tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor3
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
+        tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor3
 
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape
 
-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
 
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-        )
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-        for i in range(hidden_states.shape[0] // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
-            attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
+        # NOTE: the sliced attention path is disabled below; scores are computed for the full batch
+        # in a single matmul, and `slice_size` is currently unused.
+        # hidden_states = torch.zeros(
+        #     (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+        # )
+        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
+        # for i in range(hidden_states.shape[0] // slice_size):
+        #     start_idx = i * slice_size
+        #     end_idx = (i + 1) * slice_size
+        #     qslice = query[start_idx:end_idx]
+        qslice = query
+        # kslice = key[start_idx:end_idx].transpose(1, 2)
+        kslice = key.transpose(1, 2)
+        attn_slice = torch.matmul(qslice, kslice) * self.scale
+        attn_slice = attn_slice.softmax(dim=-1)
+        # vslice = value[start_idx:end_idx]
+        vslice = value
+        hidden_states = torch.matmul(attn_slice, vslice)
+
+        # hidden_states = torch.cat(attn_slices, dim=0)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
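A quick sanity sketch of the rewrite above (illustrative shapes only): the batched matmul form produces the same attention scores as the einsum it replaces.

    q = torch.randn(8, 16, 4)  # (batch * heads, query_len, head_dim)
    k = torch.randn(8, 12, 4)  # (batch * heads, key_len, head_dim)
    scores_matmul = torch.matmul(q, k.transpose(1, 2))
    scores_einsum = torch.einsum("b i d, b j d -> b i j", q, k)
    assert torch.allclose(scores_matmul, scores_einsum, atol=1e-6)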
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
         project_in = GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)
 
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
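A small shape sketch of the gated projection (hypothetical sizes): GEGLU doubles the projection width and gates one half with GELU of the other, while FeedForward wraps it between dropout and an output projection.

    geglu = GEGLU(dim_in=64, dim_out=256)
    ff = FeedForward(dim=64)  # inner_dim = 4 * 64, projected back to 64 channels
    tokens = torch.randn(2, 10, 64)
    assert geglu(tokens).shape == (2, 10, 256)
    assert ff(tokens).shape == tokens.shape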