|
| 1 | +# Copyright 2022 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import flax.linen as nn |
| 16 | +import jax.numpy as jnp |
| 17 | + |
| 18 | + |
| 19 | +class FlaxAttentionBlock(nn.Module): |
| 20 | + query_dim: int |
| 21 | + heads: int = 8 |
| 22 | + dim_head: int = 64 |
| 23 | + dropout: float = 0.0 |
| 24 | + dtype: jnp.dtype = jnp.float32 |
| 25 | + |
| 26 | + def setup(self): |
| 27 | + inner_dim = self.dim_head * self.heads |
| 28 | + self.scale = self.dim_head**-0.5 |
| 29 | + |
| 30 | + # Weights were exported with old names {to_q, to_k, to_v, to_out} |
| 31 | + self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q") |
| 32 | + self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k") |
| 33 | + self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v") |
| 34 | + |
| 35 | + self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out") |
| 36 | + |
| 37 | + def reshape_heads_to_batch_dim(self, tensor): |
| 38 | + batch_size, seq_len, dim = tensor.shape |
| 39 | + head_size = self.heads |
| 40 | + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) |
| 41 | + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) |
| 42 | + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) |
| 43 | + return tensor |
| 44 | + |
| 45 | + def reshape_batch_dim_to_heads(self, tensor): |
| 46 | + batch_size, seq_len, dim = tensor.shape |
| 47 | + head_size = self.heads |
| 48 | + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) |
| 49 | + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) |
| 50 | + tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) |
| 51 | + return tensor |
| 52 | + |
| 53 | + def __call__(self, hidden_states, context=None, deterministic=True): |
| 54 | + context = hidden_states if context is None else context |
| 55 | + |
| 56 | + query_proj = self.query(hidden_states) |
| 57 | + key_proj = self.key(context) |
| 58 | + value_proj = self.value(context) |
| 59 | + |
| 60 | + query_states = self.reshape_heads_to_batch_dim(query_proj) |
| 61 | + key_states = self.reshape_heads_to_batch_dim(key_proj) |
| 62 | + value_states = self.reshape_heads_to_batch_dim(value_proj) |
| 63 | + |
| 64 | + # compute attentions |
| 65 | + attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states) |
| 66 | + attention_scores = attention_scores * self.scale |
| 67 | + attention_probs = nn.softmax(attention_scores, axis=2) |
| 68 | + |
| 69 | + # attend to values |
| 70 | + hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states) |
| 71 | + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) |
| 72 | + hidden_states = self.proj_attn(hidden_states) |
| 73 | + return hidden_states |
| 74 | + |
| 75 | + |
| 76 | +class FlaxBasicTransformerBlock(nn.Module): |
| 77 | + dim: int |
| 78 | + n_heads: int |
| 79 | + d_head: int |
| 80 | + dropout: float = 0.0 |
| 81 | + dtype: jnp.dtype = jnp.float32 |
| 82 | + |
| 83 | + def setup(self): |
| 84 | + # self attention |
| 85 | + self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) |
| 86 | + # cross attention |
| 87 | + self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype) |
| 88 | + self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype) |
| 89 | + self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) |
| 90 | + self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) |
| 91 | + self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype) |
| 92 | + |
| 93 | + def __call__(self, hidden_states, context, deterministic=True): |
| 94 | + # self attention |
| 95 | + residual = hidden_states |
| 96 | + hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic) |
| 97 | + hidden_states = hidden_states + residual |
| 98 | + |
| 99 | + # cross attention |
| 100 | + residual = hidden_states |
| 101 | + hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic) |
| 102 | + hidden_states = hidden_states + residual |
| 103 | + |
| 104 | + # feed forward |
| 105 | + residual = hidden_states |
| 106 | + hidden_states = self.ff(self.norm3(hidden_states), deterministic=deterministic) |
| 107 | + hidden_states = hidden_states + residual |
| 108 | + |
| 109 | + return hidden_states |
| 110 | + |
| 111 | + |
| 112 | +class FlaxSpatialTransformer(nn.Module): |
| 113 | + in_channels: int |
| 114 | + n_heads: int |
| 115 | + d_head: int |
| 116 | + depth: int = 1 |
| 117 | + dropout: float = 0.0 |
| 118 | + dtype: jnp.dtype = jnp.float32 |
| 119 | + |
| 120 | + def setup(self): |
| 121 | + self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5) |
| 122 | + |
| 123 | + inner_dim = self.n_heads * self.d_head |
| 124 | + self.proj_in = nn.Conv( |
| 125 | + inner_dim, |
| 126 | + kernel_size=(1, 1), |
| 127 | + strides=(1, 1), |
| 128 | + padding="VALID", |
| 129 | + dtype=self.dtype, |
| 130 | + ) |
| 131 | + |
| 132 | + self.transformer_blocks = [ |
| 133 | + FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype) |
| 134 | + for _ in range(self.depth) |
| 135 | + ] |
| 136 | + |
| 137 | + self.proj_out = nn.Conv( |
| 138 | + inner_dim, |
| 139 | + kernel_size=(1, 1), |
| 140 | + strides=(1, 1), |
| 141 | + padding="VALID", |
| 142 | + dtype=self.dtype, |
| 143 | + ) |
| 144 | + |
| 145 | + def __call__(self, hidden_states, context, deterministic=True): |
| 146 | + batch, height, width, channels = hidden_states.shape |
| 147 | + # import ipdb; ipdb.set_trace() |
| 148 | + residual = hidden_states |
| 149 | + hidden_states = self.norm(hidden_states) |
| 150 | + hidden_states = self.proj_in(hidden_states) |
| 151 | + |
| 152 | + hidden_states = hidden_states.reshape(batch, height * width, channels) |
| 153 | + |
| 154 | + for transformer_block in self.transformer_blocks: |
| 155 | + hidden_states = transformer_block(hidden_states, context, deterministic=deterministic) |
| 156 | + |
| 157 | + hidden_states = hidden_states.reshape(batch, height, width, channels) |
| 158 | + |
| 159 | + hidden_states = self.proj_out(hidden_states) |
| 160 | + hidden_states = hidden_states + residual |
| 161 | + |
| 162 | + return hidden_states |
| 163 | + |
| 164 | + |
| 165 | +class FlaxGluFeedForward(nn.Module): |
| 166 | + dim: int |
| 167 | + dropout: float = 0.0 |
| 168 | + dtype: jnp.dtype = jnp.float32 |
| 169 | + |
| 170 | + def setup(self): |
| 171 | + inner_dim = self.dim * 4 |
| 172 | + self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype) |
| 173 | + self.dense2 = nn.Dense(self.dim, dtype=self.dtype) |
| 174 | + |
| 175 | + def __call__(self, hidden_states, deterministic=True): |
| 176 | + hidden_states = self.dense1(hidden_states) |
| 177 | + hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2) |
| 178 | + hidden_states = hidden_linear * nn.gelu(hidden_gelu) |
| 179 | + hidden_states = self.dense2(hidden_states) |
| 180 | + return hidden_states |
0 commit comments