|
| 1 | +# Copyright 2023 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import math |
| 15 | + |
| 16 | +import torch |
| 17 | +from torch import nn |
| 18 | + |
| 19 | +from ..configuration_utils import ConfigMixin, register_to_config |
| 20 | +from .attention_processor import Attention |
| 21 | +from .embeddings import get_timestep_embedding |
| 22 | +from .modeling_utils import ModelMixin |
| 23 | + |
| 24 | + |
| 25 | +class T5FilmDecoder(ModelMixin, ConfigMixin): |
| 26 | + @register_to_config |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + input_dims: int = 128, |
| 30 | + targets_length: int = 256, |
| 31 | + max_decoder_noise_time: float = 2000.0, |
| 32 | + d_model: int = 768, |
| 33 | + num_layers: int = 12, |
| 34 | + num_heads: int = 12, |
| 35 | + d_kv: int = 64, |
| 36 | + d_ff: int = 2048, |
| 37 | + dropout_rate: float = 0.1, |
| 38 | + ): |
| 39 | + super().__init__() |
| 40 | + |
| 41 | + self.conditioning_emb = nn.Sequential( |
| 42 | + nn.Linear(d_model, d_model * 4, bias=False), |
| 43 | + nn.SiLU(), |
| 44 | + nn.Linear(d_model * 4, d_model * 4, bias=False), |
| 45 | + nn.SiLU(), |
| 46 | + ) |
| 47 | + |
| 48 | + self.position_encoding = nn.Embedding(targets_length, d_model) |
| 49 | + self.position_encoding.weight.requires_grad = False |
| 50 | + |
| 51 | + self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False) |
| 52 | + |
| 53 | + self.dropout = nn.Dropout(p=dropout_rate) |
| 54 | + |
| 55 | + self.decoders = nn.ModuleList() |
| 56 | + for lyr_num in range(num_layers): |
| 57 | + # FiLM conditional T5 decoder |
| 58 | + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) |
| 59 | + self.decoders.append(lyr) |
| 60 | + |
| 61 | + self.decoder_norm = T5LayerNorm(d_model) |
| 62 | + |
| 63 | + self.post_dropout = nn.Dropout(p=dropout_rate) |
| 64 | + self.spec_out = nn.Linear(d_model, input_dims, bias=False) |
| 65 | + |
| 66 | + def encoder_decoder_mask(self, query_input, key_input): |
| 67 | + mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) |
| 68 | + return mask.unsqueeze(-3) |
| 69 | + |
| 70 | + def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): |
| 71 | + batch, _, _ = decoder_input_tokens.shape |
| 72 | + assert decoder_noise_time.shape == (batch,) |
| 73 | + |
| 74 | + # decoder_noise_time is in [0, 1), so rescale to expected timing range. |
| 75 | + time_steps = get_timestep_embedding( |
| 76 | + decoder_noise_time * self.config.max_decoder_noise_time, |
| 77 | + embedding_dim=self.config.d_model, |
| 78 | + max_period=self.config.max_decoder_noise_time, |
| 79 | + ).to(dtype=self.dtype) |
| 80 | + |
| 81 | + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) |
| 82 | + |
| 83 | + assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4) |
| 84 | + |
| 85 | + seq_length = decoder_input_tokens.shape[1] |
| 86 | + |
| 87 | + # If we want to use relative positions for audio context, we can just offset |
| 88 | + # this sequence by the length of encodings_and_masks. |
| 89 | + decoder_positions = torch.broadcast_to( |
| 90 | + torch.arange(seq_length, device=decoder_input_tokens.device), |
| 91 | + (batch, seq_length), |
| 92 | + ) |
| 93 | + |
| 94 | + position_encodings = self.position_encoding(decoder_positions) |
| 95 | + |
| 96 | + inputs = self.continuous_inputs_projection(decoder_input_tokens) |
| 97 | + inputs += position_encodings |
| 98 | + y = self.dropout(inputs) |
| 99 | + |
| 100 | + # decoder: No padding present. |
| 101 | + decoder_mask = torch.ones( |
| 102 | + decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype |
| 103 | + ) |
| 104 | + |
| 105 | + # Translate encoding masks to encoder-decoder masks. |
| 106 | + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] |
| 107 | + |
| 108 | + # cross attend style: concat encodings |
| 109 | + encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1) |
| 110 | + encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1) |
| 111 | + |
| 112 | + for lyr in self.decoders: |
| 113 | + y = lyr( |
| 114 | + y, |
| 115 | + conditioning_emb=conditioning_emb, |
| 116 | + encoder_hidden_states=encoded, |
| 117 | + encoder_attention_mask=encoder_decoder_mask, |
| 118 | + )[0] |
| 119 | + |
| 120 | + y = self.decoder_norm(y) |
| 121 | + y = self.post_dropout(y) |
| 122 | + |
| 123 | + spec_out = self.spec_out(y) |
| 124 | + return spec_out |
| 125 | + |
| 126 | + |
| 127 | +class DecoderLayer(nn.Module): |
| 128 | + def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6): |
| 129 | + super().__init__() |
| 130 | + self.layer = nn.ModuleList() |
| 131 | + |
| 132 | + # cond self attention: layer 0 |
| 133 | + self.layer.append( |
| 134 | + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) |
| 135 | + ) |
| 136 | + |
| 137 | + # cross attention: layer 1 |
| 138 | + self.layer.append( |
| 139 | + T5LayerCrossAttention( |
| 140 | + d_model=d_model, |
| 141 | + d_kv=d_kv, |
| 142 | + num_heads=num_heads, |
| 143 | + dropout_rate=dropout_rate, |
| 144 | + layer_norm_epsilon=layer_norm_epsilon, |
| 145 | + ) |
| 146 | + ) |
| 147 | + |
| 148 | + # Film Cond MLP + dropout: last layer |
| 149 | + self.layer.append( |
| 150 | + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) |
| 151 | + ) |
| 152 | + |
| 153 | + def forward( |
| 154 | + self, |
| 155 | + hidden_states, |
| 156 | + conditioning_emb=None, |
| 157 | + attention_mask=None, |
| 158 | + encoder_hidden_states=None, |
| 159 | + encoder_attention_mask=None, |
| 160 | + encoder_decoder_position_bias=None, |
| 161 | + ): |
| 162 | + hidden_states = self.layer[0]( |
| 163 | + hidden_states, |
| 164 | + conditioning_emb=conditioning_emb, |
| 165 | + attention_mask=attention_mask, |
| 166 | + ) |
| 167 | + |
| 168 | + if encoder_hidden_states is not None: |
| 169 | + encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to( |
| 170 | + encoder_hidden_states.dtype |
| 171 | + ) |
| 172 | + |
| 173 | + hidden_states = self.layer[1]( |
| 174 | + hidden_states, |
| 175 | + key_value_states=encoder_hidden_states, |
| 176 | + attention_mask=encoder_extended_attention_mask, |
| 177 | + ) |
| 178 | + |
| 179 | + # Apply Film Conditional Feed Forward layer |
| 180 | + hidden_states = self.layer[-1](hidden_states, conditioning_emb) |
| 181 | + |
| 182 | + return (hidden_states,) |
| 183 | + |
| 184 | + |
| 185 | +class T5LayerSelfAttentionCond(nn.Module): |
| 186 | + def __init__(self, d_model, d_kv, num_heads, dropout_rate): |
| 187 | + super().__init__() |
| 188 | + self.layer_norm = T5LayerNorm(d_model) |
| 189 | + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) |
| 190 | + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) |
| 191 | + self.dropout = nn.Dropout(dropout_rate) |
| 192 | + |
| 193 | + def forward( |
| 194 | + self, |
| 195 | + hidden_states, |
| 196 | + conditioning_emb=None, |
| 197 | + attention_mask=None, |
| 198 | + ): |
| 199 | + # pre_self_attention_layer_norm |
| 200 | + normed_hidden_states = self.layer_norm(hidden_states) |
| 201 | + |
| 202 | + if conditioning_emb is not None: |
| 203 | + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) |
| 204 | + |
| 205 | + # Self-attention block |
| 206 | + attention_output = self.attention(normed_hidden_states) |
| 207 | + |
| 208 | + hidden_states = hidden_states + self.dropout(attention_output) |
| 209 | + |
| 210 | + return hidden_states |
| 211 | + |
| 212 | + |
| 213 | +class T5LayerCrossAttention(nn.Module): |
| 214 | + def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon): |
| 215 | + super().__init__() |
| 216 | + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) |
| 217 | + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) |
| 218 | + self.dropout = nn.Dropout(dropout_rate) |
| 219 | + |
| 220 | + def forward( |
| 221 | + self, |
| 222 | + hidden_states, |
| 223 | + key_value_states=None, |
| 224 | + attention_mask=None, |
| 225 | + ): |
| 226 | + normed_hidden_states = self.layer_norm(hidden_states) |
| 227 | + attention_output = self.attention( |
| 228 | + normed_hidden_states, |
| 229 | + encoder_hidden_states=key_value_states, |
| 230 | + attention_mask=attention_mask.squeeze(1), |
| 231 | + ) |
| 232 | + layer_output = hidden_states + self.dropout(attention_output) |
| 233 | + return layer_output |
| 234 | + |
| 235 | + |
| 236 | +class T5LayerFFCond(nn.Module): |
| 237 | + def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon): |
| 238 | + super().__init__() |
| 239 | + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) |
| 240 | + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) |
| 241 | + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) |
| 242 | + self.dropout = nn.Dropout(dropout_rate) |
| 243 | + |
| 244 | + def forward(self, hidden_states, conditioning_emb=None): |
| 245 | + forwarded_states = self.layer_norm(hidden_states) |
| 246 | + if conditioning_emb is not None: |
| 247 | + forwarded_states = self.film(forwarded_states, conditioning_emb) |
| 248 | + |
| 249 | + forwarded_states = self.DenseReluDense(forwarded_states) |
| 250 | + hidden_states = hidden_states + self.dropout(forwarded_states) |
| 251 | + return hidden_states |
| 252 | + |
| 253 | + |
| 254 | +class T5DenseGatedActDense(nn.Module): |
| 255 | + def __init__(self, d_model, d_ff, dropout_rate): |
| 256 | + super().__init__() |
| 257 | + self.wi_0 = nn.Linear(d_model, d_ff, bias=False) |
| 258 | + self.wi_1 = nn.Linear(d_model, d_ff, bias=False) |
| 259 | + self.wo = nn.Linear(d_ff, d_model, bias=False) |
| 260 | + self.dropout = nn.Dropout(dropout_rate) |
| 261 | + self.act = NewGELUActivation() |
| 262 | + |
| 263 | + def forward(self, hidden_states): |
| 264 | + hidden_gelu = self.act(self.wi_0(hidden_states)) |
| 265 | + hidden_linear = self.wi_1(hidden_states) |
| 266 | + hidden_states = hidden_gelu * hidden_linear |
| 267 | + hidden_states = self.dropout(hidden_states) |
| 268 | + |
| 269 | + hidden_states = self.wo(hidden_states) |
| 270 | + return hidden_states |
| 271 | + |
| 272 | + |
| 273 | +class T5LayerNorm(nn.Module): |
| 274 | + def __init__(self, hidden_size, eps=1e-6): |
| 275 | + """ |
| 276 | + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. |
| 277 | + """ |
| 278 | + super().__init__() |
| 279 | + self.weight = nn.Parameter(torch.ones(hidden_size)) |
| 280 | + self.variance_epsilon = eps |
| 281 | + |
| 282 | + def forward(self, hidden_states): |
| 283 | + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean |
| 284 | + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated |
| 285 | + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for |
| 286 | + # half-precision inputs is done in fp32 |
| 287 | + |
| 288 | + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
| 289 | + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
| 290 | + |
| 291 | + # convert into half-precision if necessary |
| 292 | + if self.weight.dtype in [torch.float16, torch.bfloat16]: |
| 293 | + hidden_states = hidden_states.to(self.weight.dtype) |
| 294 | + |
| 295 | + return self.weight * hidden_states |
| 296 | + |
| 297 | + |
| 298 | +class NewGELUActivation(nn.Module): |
| 299 | + """ |
| 300 | + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see |
| 301 | + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 |
| 302 | + """ |
| 303 | + |
| 304 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 305 | + return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0)))) |
| 306 | + |
| 307 | + |
| 308 | +class T5FiLMLayer(nn.Module): |
| 309 | + """ |
| 310 | + FiLM Layer |
| 311 | + """ |
| 312 | + |
| 313 | + def __init__(self, in_features, out_features): |
| 314 | + super().__init__() |
| 315 | + self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False) |
| 316 | + |
| 317 | + def forward(self, x, conditioning_emb): |
| 318 | + emb = self.scale_bias(conditioning_emb) |
| 319 | + scale, shift = torch.chunk(emb, 2, -1) |
| 320 | + x = x * (1 + scale) + shift |
| 321 | + return x |
0 commit comments