
Commit 3219e39

kashif, patrickvonplaten, and anton-l authored
Music Spectrogram diffusion pipeline (huggingface#1044)
* initial TokenEncoder and ContinuousEncoder
* initial modules
* added ContinuousContextTransformer
* fix copy-paste error
* use numpy for get_sequence_length
* initial terminal relative positional encodings
* fix weights keys
* fix assert
* cross attend style: concat encodings
* make style
* concat once
* fix formatting
* Initial SpectrogramPipeline
* fix input_tokens
* make style
* added mel output
* ignore weights for config
* move mel to numpy
* import pipeline
* fix class names and import
* moved models to models folder
* import ContinuousContextTransformer and SpectrogramDiffusionPipeline
* initial spec diffusion conversion script
* renamed config to t5config
* added weight loading
* use arguments instead of t5config
* broadcast noise time to batch dim
* fix call
* added scale_to_features
* fix weights
* transpose layernorm weight
* scale is a vector
* scale the query outputs
* added comment
* undo scaling
* undo depth_scaling
* initial get_extended_attention_mask
* attention_mask is none in self-attention
* cleanup
* manually invert attention
* nn.Linear needs bias=False
* added T5LayerFFCond
* remove to fix conflict
* make style and dummy
* remove unused variables
* remove predict_epsilon
* Move accelerate to a soft-dependency (huggingface#1134)
* finish
* finish
* Update src/diffusers/modeling_utils.py
* Update src/diffusers/pipeline_utils.py (Co-authored-by: Anton Lozhkov <[email protected]>)
* more fixes
* fix (Co-authored-by: Anton Lozhkov <[email protected]>)
* fix order
* added initial midi-to-note-token data pipeline
* added int-to-int tokenizer
* remove duplicate
* added logic for segments
* add melgan to pipeline
* move autoregressive gen into pipeline
* added note_representation_processor_chain
* fix dtypes
* remove immutabledict req
* initial doc
* use np.where
* require note_seq
* fix typo
* update dependency
* added note-seq to test
* added is_note_seq_available
* fix import
* added toc
* added example usage
* undo for now
* moved docs
* fix merge
* fix imports
* predict first segment
* avoid unneeded copy to and from cpu
* make style
* Copyright
* fix style
* add test and fix inference steps
* remove bogus files
* reorder models
* up
* remove transformers dependency
* make work with diffusers cross attention
* clean more
* remove @
* improve further
* up
* up
* Apply suggestions from code review
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
* loop over all tokens
* make style
* Added a section on the model
* fix formatting
* grammar
* formatting
* make fix-copies
* Update src/diffusers/pipelines/__init__.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* added callback and optional onnx
* do not squeeze batch dim
* clean up more
* upload
* convert jax to numpy
* make style
* fix warning
* make fix-copies
* fix warning
* add initial fast tests
* add initial pipeline_params
* eval mode due to dropout
* skip batch tests as pipeline runs on a single file
* make style
* fix relative path
* fix doc tests
* Update src/diffusers/models/t5_film_transformer.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/models/t5_film_transformer.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update docs/source/en/api/pipelines/spectrogram_diffusion.mdx (Co-authored-by: Patrick von Platen <[email protected]>)
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py (Co-authored-by: Patrick von Platen <[email protected]>)
* add MidiProcessor
* format
* fix org
* Apply suggestions from code review
* Update tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py
* make style
* pin protobuf to <4
* fix formatting
* white space
* tensorboard needs protobuf

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Anton Lozhkov <[email protected]>
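For context, end-to-end usage of the new pipeline looks along these lines (a sketch based on the pipeline's documented example; the checkpoint name google/music-spectrogram-diffusion and the local MIDI path are assumptions, not part of this diff):

    import torch

    from diffusers import MidiProcessor, SpectrogramDiffusionPipeline

    # load the pretrained pipeline (assumed checkpoint name) and the MIDI tokenizer
    pipe = SpectrogramDiffusionPipeline.from_pretrained("google/music-spectrogram-diffusion")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    processor = MidiProcessor()

    # tokenize a local MIDI file into note-token segments, then synthesize audio
    # e.g. wget http://www.piano-midi.de/midis/beethoven/beethoven_hammerklavier_2.mid
    output = pipe(processor("beethoven_hammerklavier_2.mid"))
    audio = output.audios[0]  # waveform rendered by the MelGAN vocoder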
1 parent 9448241 commit 3219e39

16 files changed (+1495 / -0 lines)

__init__.py

Lines changed: 18 additions & 0 deletions
@@ -8,6 +8,7 @@
     is_k_diffusion_available,
     is_k_diffusion_version,
     is_librosa_available,
+    is_note_seq_available,
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
@@ -37,6 +38,7 @@
     ControlNetModel,
     ModelMixin,
     PriorTransformer,
+    T5FilmDecoder,
     Transformer2DModel,
     UNet1DModel,
     UNet2DConditionModel,
@@ -172,6 +174,14 @@
 else:
     from .pipelines import AudioDiffusionPipeline, Mel

+try:
+    if not (is_torch_available() and is_note_seq_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_note_seq_objects import *  # noqa F403
+else:
+    from .pipelines import SpectrogramDiffusionPipeline
+
 try:
     if not is_flax_available():
         raise OptionalDependencyNotAvailable()
@@ -205,3 +215,11 @@
         FlaxStableDiffusionInpaintPipeline,
         FlaxStableDiffusionPipeline,
     )
+
+try:
+    if not (is_note_seq_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_note_seq_objects import *  # noqa F403
+else:
+    from .pipelines import MidiProcessor
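The guarded imports above follow diffusers' soft-dependency pattern: when note_seq is missing, the star import pulls in placeholder classes that defer the failure to first use with a helpful install message. A simplified sketch of what such a dummy module contains (the real file is generated by make fix-copies; the details here are illustrative):

    from ..utils import DummyObject, requires_backends

    class MidiProcessor(metaclass=DummyObject):
        _backends = ["note_seq"]

        def __init__(self, *args, **kwargs):
            # raises an informative error telling the user to `pip install note-seq`
            requires_backends(self, ["note_seq"])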

dependency_versions_table.py

Lines changed: 2 additions & 0 deletions
@@ -19,8 +19,10 @@
     "Jinja2": "Jinja2",
     "k-diffusion": "k-diffusion>=0.0.12",
     "librosa": "librosa",
+    "note-seq": "note-seq",
     "numpy": "numpy",
     "parameterized": "parameterized",
+    "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",

models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@
 from .dual_transformer_2d import DualTransformer2DModel
 from .modeling_utils import ModelMixin
 from .prior_transformer import PriorTransformer
+from .t5_film_transformer import T5FilmDecoder
 from .transformer_2d import Transformer2DModel
 from .unet_1d import UNet1DModel
 from .unet_2d import UNet2DModel

models/t5_film_transformer.py

Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@ (new file)

# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
from torch import nn

from ..configuration_utils import ConfigMixin, register_to_config
from .attention_processor import Attention
from .embeddings import get_timestep_embedding
from .modeling_utils import ModelMixin


class T5FilmDecoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        input_dims: int = 128,
        targets_length: int = 256,
        max_decoder_noise_time: float = 2000.0,
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        d_kv: int = 64,
        d_ff: int = 2048,
        dropout_rate: float = 0.1,
    ):
        super().__init__()

        self.conditioning_emb = nn.Sequential(
            nn.Linear(d_model, d_model * 4, bias=False),
            nn.SiLU(),
            nn.Linear(d_model * 4, d_model * 4, bias=False),
            nn.SiLU(),
        )

        self.position_encoding = nn.Embedding(targets_length, d_model)
        self.position_encoding.weight.requires_grad = False

        self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_rate)

        self.decoders = nn.ModuleList()
        for lyr_num in range(num_layers):
            # FiLM conditional T5 decoder
            lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
            self.decoders.append(lyr)

        self.decoder_norm = T5LayerNorm(d_model)

        self.post_dropout = nn.Dropout(p=dropout_rate)
        self.spec_out = nn.Linear(d_model, input_dims, bias=False)

    def encoder_decoder_mask(self, query_input, key_input):
        mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
        return mask.unsqueeze(-3)

    def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
        batch, _, _ = decoder_input_tokens.shape
        assert decoder_noise_time.shape == (batch,)

        # decoder_noise_time is in [0, 1), so rescale to expected timing range.
        time_steps = get_timestep_embedding(
            decoder_noise_time * self.config.max_decoder_noise_time,
            embedding_dim=self.config.d_model,
            max_period=self.config.max_decoder_noise_time,
        ).to(dtype=self.dtype)

        conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)

        assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)

        seq_length = decoder_input_tokens.shape[1]

        # If we want to use relative positions for audio context, we can just offset
        # this sequence by the length of encodings_and_masks.
        decoder_positions = torch.broadcast_to(
            torch.arange(seq_length, device=decoder_input_tokens.device),
            (batch, seq_length),
        )

        position_encodings = self.position_encoding(decoder_positions)

        inputs = self.continuous_inputs_projection(decoder_input_tokens)
        inputs += position_encodings
        y = self.dropout(inputs)

        # decoder: No padding present.
        decoder_mask = torch.ones(
            decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
        )

        # Translate encoding masks to encoder-decoder masks.
        encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]

        # cross attend style: concat encodings
        encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
        encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)

        for lyr in self.decoders:
            y = lyr(
                y,
                conditioning_emb=conditioning_emb,
                encoder_hidden_states=encoded,
                encoder_attention_mask=encoder_decoder_mask,
            )[0]

        y = self.decoder_norm(y)
        y = self.post_dropout(y)

        spec_out = self.spec_out(y)
        return spec_out

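As a quick orientation to the interface of the class above, a shape check along these lines should hold (a sketch, not part of the diff; tensor sizes are illustrative and num_layers=1 keeps it light):

    import torch

    from diffusers import T5FilmDecoder

    decoder = T5FilmDecoder(input_dims=128, targets_length=256, d_model=768, num_layers=1)
    decoder.eval()  # dropout off for a deterministic pass

    batch, enc_len, dec_len = 2, 100, 256
    encodings = torch.randn(batch, enc_len, 768)   # encoder hidden states
    enc_mask = torch.ones(batch, enc_len)          # 1 = attend, 0 = padding
    noisy_spec = torch.randn(batch, dec_len, 128)  # noisy mel frames
    noise_time = torch.rand(batch)                 # diffusion time in [0, 1)

    with torch.no_grad():
        out = decoder([(encodings, enc_mask)], noisy_spec, noise_time)
    assert out.shape == (batch, dec_len, 128)  # predicted spectrogram frames
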
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, d_ff, dropout_rate, layer_norm_epsilon=1e-6):
        super().__init__()
        self.layer = nn.ModuleList()

        # cond self attention: layer 0
        self.layer.append(
            T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
        )

        # cross attention: layer 1
        self.layer.append(
            T5LayerCrossAttention(
                d_model=d_model,
                d_kv=d_kv,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                layer_norm_epsilon=layer_norm_epsilon,
            )
        )

        # FiLM cond MLP + dropout: last layer
        self.layer.append(
            T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
        )

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
    ):
        hidden_states = self.layer[0](
            hidden_states,
            conditioning_emb=conditioning_emb,
            attention_mask=attention_mask,
        )

        if encoder_hidden_states is not None:
            encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
                encoder_hidden_states.dtype
            )

            hidden_states = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_extended_attention_mask,
            )

        # Apply FiLM conditional feed-forward layer
        hidden_states = self.layer[-1](hidden_states, conditioning_emb)

        return (hidden_states,)

class T5LayerSelfAttentionCond(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, dropout_rate):
        super().__init__()
        self.layer_norm = T5LayerNorm(d_model)
        self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        conditioning_emb=None,
        attention_mask=None,
    ):
        # pre_self_attention_layer_norm
        normed_hidden_states = self.layer_norm(hidden_states)

        if conditioning_emb is not None:
            normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)

        # Self-attention block
        attention_output = self.attention(normed_hidden_states)

        hidden_states = hidden_states + self.dropout(attention_output)

        return hidden_states

class T5LayerCrossAttention(nn.Module):
    def __init__(self, d_model, d_kv, num_heads, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        attention_mask=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.attention(
            normed_hidden_states,
            encoder_hidden_states=key_value_states,
            attention_mask=attention_mask.squeeze(1),
        )
        layer_output = hidden_states + self.dropout(attention_output)
        return layer_output

class T5LayerFFCond(nn.Module):
    def __init__(self, d_model, d_ff, dropout_rate, layer_norm_epsilon):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
        self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
        self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, hidden_states, conditioning_emb=None):
        forwarded_states = self.layer_norm(hidden_states)
        if conditioning_emb is not None:
            forwarded_states = self.film(forwarded_states, conditioning_emb)

        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states

class T5DenseGatedActDense(nn.Module):
    def __init__(self, d_model, d_ff, dropout_rate):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = NewGELUActivation()

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)
        return hidden_states

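This is the gated-GELU feed-forward from T5 v1.1: two parallel input projections, one passed through GELU as a gate, multiplied elementwise before the output projection, i.e. wo(gelu(wi_0(x)) * wi_1(x)). A small sanity sketch (not part of the diff; sizes are illustrative, and dropout_rate=0.0 makes the identity exact):

    import torch

    from diffusers.models.t5_film_transformer import T5DenseGatedActDense

    ff = T5DenseGatedActDense(d_model=8, d_ff=16, dropout_rate=0.0)
    x = torch.randn(2, 4, 8)
    manual = ff.wo(ff.act(ff.wi_0(x)) * ff.wi_1(x))  # spell out the gated product
    assert torch.allclose(ff(x), manual)
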
class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32.

        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states

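In other words, this is RMSNorm: y = w * x / sqrt(mean(x^2) + eps). A quick numeric cross-check against that formula (a sketch; sizes are illustrative):

    import torch

    from diffusers.models.t5_film_transformer import T5LayerNorm

    norm = T5LayerNorm(hidden_size=8)
    x = torch.randn(2, 4, 8)
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
    assert torch.allclose(norm(x), norm.weight * x / rms, atol=1e-6)
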
class NewGELUActivation(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

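On recent PyTorch (1.12+), this formula matches the built-in tanh approximation of GELU, which can serve as a cross-check (a sketch, not part of the diff):

    import torch
    import torch.nn.functional as F

    from diffusers.models.t5_film_transformer import NewGELUActivation

    x = torch.randn(1000)
    assert torch.allclose(NewGELUActivation()(x), F.gelu(x, approximate="tanh"), atol=1e-6)
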
class T5FiLMLayer(nn.Module):
    """
    FiLM Layer
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)

    def forward(self, x, conditioning_emb):
        emb = self.scale_bias(conditioning_emb)
        scale, shift = torch.chunk(emb, 2, -1)
        x = x * (1 + scale) + shift
        return x
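
FiLM (feature-wise linear modulation) predicts a per-channel scale and shift from the conditioning embedding and applies x * (1 + scale) + shift; writing it as 1 + scale means a zero-valued projection leaves x unchanged. A minimal sketch (not part of the diff; sizes follow the d_model * 4 convention used above):

    import torch

    from diffusers.models.t5_film_transformer import T5FiLMLayer

    film = T5FiLMLayer(in_features=32, out_features=8)  # in_features = d_model * 4 for d_model = 8
    x = torch.randn(2, 4, 8)      # (batch, seq, d_model) hidden states
    cond = torch.randn(2, 1, 32)  # (batch, 1, d_model * 4) conditioning embedding
    out = film(x, cond)           # scale/shift broadcast over the sequence dimension
    assert out.shape == x.shape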
