Commit 0c70c0e

attn refactoring
1 parent d30f968 commit 0c70c0e

File tree

1 file changed: +151 -62 lines changed


src/diffusers/models/attention.py

Lines changed: 151 additions & 62 deletions
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention
+    Uses three q, k, v linear layers to compute attention.
+
+    Parameters:
+        channels (:obj:`int`): The number of channels in the input and output.
+        num_head_channels (:obj:`int`, *optional*):
+            The number of channels in each head. If None, then `num_heads` = 1.
+        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """
 
     def __init__(
         self,
-        channels,
-        num_head_channels=None,
-        num_groups=32,
-        rescale_output_factor=1.0,
-        eps=1e-5,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
     ):
         super().__init__()
         self.channels = channels
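
Aside: a minimal usage sketch of the annotated AttentionBlock signature above, not part of this commit. It assumes the class remains importable from src/diffusers/models/attention.py and that its forward pass maps a (batch, channels, height, width) tensor to a tensor of the same shape; the concrete sizes are illustrative.

import torch
from diffusers.models.attention import AttentionBlock

# channels=64 with num_head_channels=32 gives two heads; values chosen only for illustration
attn = AttentionBlock(channels=64, num_head_channels=32, num_groups=32)
hidden_states = torch.randn(2, 64, 16, 16)  # (batch, channels, height, width)
out = attn(hidden_states)
assert out.shape == hidden_states.shape  # spatial self-attention preserves the input shape
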
@@ -86,10 +95,26 @@ def forward(self, hidden_states):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
+    standard transformer action. Finally, reshape to image.
+
+    Parameters:
+        in_channels (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
     """
 
-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+    def __init__(
+        self,
+        in_channels: int,
+        n_heads: int,
+        d_head: int,
+        depth: int = 1,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
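
The docstring above says the input is projected and reshaped to (b, t, d) before the transformer blocks run. A plain-tensor sketch of that flattening and its inverse, using only torch; the shapes are illustrative.

import torch

b, c, h, w = 2, 320, 16, 16
x = torch.randn(b, c, h, w)
# image -> sequence of h*w tokens with c features each
tokens = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
# sequence -> image, the inverse used at the end of forward()
restored = tokens.reshape(b, h, w, c).permute(0, 3, 1, 2)
assert torch.equal(x, restored)
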
@@ -112,22 +137,44 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual
 
 
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        d_head: int,
+        dropout=0.0,
+        context_dim: Optional[int] = None,
+        gated_ff: bool = True,
+        checkpoint: bool = True,
+    ):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
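
A hypothetical call sketch for the renamed SpatialTransformer.forward above: hidden states come from a UNet feature map, and an optional context tensor (for example text-encoder states) switches the blocks from self- to cross-attention. The constructor arguments, shapes, and sizes here are assumptions for illustration only, not taken from this commit.

import torch
from diffusers.models.attention import SpatialTransformer

# inner dim = n_heads * d_head = 320 matches in_channels here; values are illustrative
transformer = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
hidden_states = torch.randn(2, 320, 16, 16)         # (batch, channel, height, width)
context = torch.randn(2, 77, 768)                   # e.g. encoded text tokens
out = transformer(hidden_states, context=context)   # residual output, same shape as the input
assert out.shape == hidden_states.shape
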
@@ -145,15 +192,30 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
-    def forward(self, x, context=None):
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (:obj:`int`): The number of channels in the query.
+        context_dim (:obj:`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
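
The docstring added above describes queries drawn from the hidden states and keys/values from the context. A single-head, plain-tensor sketch of that computation follows; it is an illustration only, with made-up dimensions, and the real class additionally splits heads and projects back with to_out.

import torch
import torch.nn as nn

batch, seq_q, seq_kv, query_dim, context_dim, dim_head = 2, 256, 77, 320, 768, 64
hidden_states = torch.randn(batch, seq_q, query_dim)
context = torch.randn(batch, seq_kv, context_dim)

to_q = nn.Linear(query_dim, dim_head, bias=False)    # queries from the hidden states
to_k = nn.Linear(context_dim, dim_head, bias=False)  # keys and values from the context
to_v = nn.Linear(context_dim, dim_head, bias=False)

query, key, value = to_q(hidden_states), to_k(context), to_v(context)
scale = dim_head ** -0.5
attn = torch.softmax(torch.matmul(query, key.transpose(1, 2)) * scale, dim=-1)
out = torch.matmul(attn, value)                      # (batch, seq_q, dim_head)
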
@@ -174,77 +236,104 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
-        return tensor
+        tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor3
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
-        return tensor
+        tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor3
 
-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape
 
-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)
 
-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)
 
         return self.to_out(hidden_states)
 
     def _attention(self, query, key, value, sequence_length, dim):
         batch_size_attention = query.shape[0]
-        hidden_states = torch.zeros(
-            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
-        )
-        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
-        for i in range(hidden_states.shape[0] // slice_size):
-            start_idx = i * slice_size
-            end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-            )
-            attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
-
-            hidden_states[start_idx:end_idx] = attn_slice
+        # hidden_states = torch.zeros(
+        #     (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+        # )
+        slice_size = self._slice_size if self._slice_size is not None else batch_size_attention
+        # for i in range(hidden_states.shape[0] // slice_size):
+        #     start_idx = i * slice_size
+        #     end_idx = (i + 1) * slice_size
+        #     qslice = query[start_idx:end_idx]
+        qslice = query
+        # kslice = key[start_idx:end_idx].transpose(1, 2)
+        kslice = key.transpose(1, 2)
+        attn_slice = torch.matmul(qslice, kslice) * self.scale
+        attn_slice = attn_slice.softmax(dim=-1)
+        # vslice = value[start_idx:end_idx]
+        vslice = value
+        hidden_states = torch.matmul(attn_slice, vslice)
+
+
+        # hidden_states = torch.cat(attn_slices, dim=0)
+
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
         project_in = GEGLU(dim, inner_dim)
 
         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
 
-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)
 
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
 
-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
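
The rewritten _attention above replaces the einsum formulation with torch.matmul while leaving the sliced loop commented out. A quick equivalence check with plain torch (illustrative shapes only) showing that the two formulations produce the same attention scores and outputs:

import torch

# (batch * heads, seq_len, dim_head); shapes are illustrative
query, key, value = (torch.randn(4, 64, 40) for _ in range(3))
scale = 40 ** -0.5

scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale
scores_matmul = torch.matmul(query, key.transpose(1, 2)) * scale
assert torch.allclose(scores_einsum, scores_matmul, atol=1e-6)

probs = scores_matmul.softmax(dim=-1)
out_einsum = torch.einsum("b i j, b j d -> b i d", probs, value)
out_matmul = torch.matmul(probs, value)
assert torch.allclose(out_einsum, out_matmul, atol=1e-6)
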
