Commit 365ff8f

[Dance Diffusion] FP16 (#980)
* add in fp16
* up
1 parent 88fa6b7 commit 365ff8f

3 files changed: +24 -3 lines changed

src/diffusers/models/unet_1d.py

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ def forward(
             timestep = timestep[None]
 
         timestep_embed = self.time_proj(timestep)[..., None]
-        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
+        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
 
         # 2. down
         down_block_res_samples = ()
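
The .to(sample.dtype) cast keeps the broadcast timestep embedding in the same precision as the audio sample. A minimal sketch of the mismatch this avoids, with made-up shapes and embedding width (this is not the UNet1DModel code):

import torch

# timestep_embed stands in for self.time_proj(timestep)[..., None], which is
# typically computed in float32 even when the model weights are fp16.
sample = torch.randn(1, 2, 8, dtype=torch.float16)
timestep_embed = torch.randn(1, 14, 1, dtype=torch.float32)

before = timestep_embed.repeat([1, 1, sample.shape[2]])                  # stays float32
after = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)  # cast to float16, as in the diff

print(before.dtype, sample.dtype)  # torch.float32 torch.float16 -> mismatch inside an fp16 UNet
print(after.dtype, sample.dtype)   # torch.float16 torch.float16 -> consistent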

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 6 additions & 2 deletions
@@ -91,10 +91,14 @@ def __call__(
             )
         sample_size = int(sample_size)
 
-        audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)
+        dtype = next(iter(self.unet.parameters())).dtype
+        audio = torch.randn(
+            (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
+        )
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+        self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)
 
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
@@ -103,7 +107,7 @@ def __call__(
             # 2. compute previous image: x_t -> t_t-1
             audio = self.scheduler.step(model_output, t, audio).prev_sample
 
-        audio = audio.clamp(-1, 1).cpu().numpy()
+        audio = audio.clamp(-1, 1).float().cpu().numpy()
 
         audio = audio[:, :, :original_sample_size]
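
Taken together, these changes draw the initial noise in the UNet's parameter dtype, keep the scheduler timesteps in that dtype, and only upcast to float32 right before the NumPy conversion. A usage sketch of what this enables (assumes a CUDA device; the checkpoint and call arguments mirror the new test below):

import torch
from diffusers import DanceDiffusionPipeline

# Load the weights in half precision; the pipeline now creates its noise in the same dtype.
pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)

audio = output.audios  # float32 NumPy array: the pipeline clamps, upcasts, then calls .numpy()
print(audio.shape)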

tests/pipelines/dance_diffusion/test_dance_diffusion.py

Lines changed: 17 additions & 0 deletions
@@ -99,3 +99,20 @@ def test_dance_diffusion(self):
         assert audio.shape == (1, 2, pipe.unet.sample_size)
         expected_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
         assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_dance_diffusion_fp16(self):
+        device = torch_device
+
+        pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
+        pipe = pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = pipe(generator=generator, num_inference_steps=100, sample_length_in_s=4.096)
+        audio = output.audios
+
+        audio_slice = audio[0, -3:, -3:]
+
+        assert audio.shape == (1, 2, pipe.unet.sample_size)
+        expected_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937])
+        assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
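
The fp16 test gets its own reference slice instead of reusing the fp32 one. A quick check on the two arrays above (a sketch run outside the test suite) shows why a shared 1e-2 tolerance would not cover both:

import numpy as np

# Reference slices copied from the two tests above.
fp32_slice = np.array([-0.1576, -0.1526, -0.127, -0.2699, -0.2762, -0.2487])
fp16_slice = np.array([-0.1693, -0.1698, -0.1447, -0.3044, -0.3203, -0.2937])

# fp16 output drifts from fp32 by up to roughly 0.045, well beyond the per-test 1e-2 tolerance.
print(np.abs(fp32_slice - fp16_slice).max())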
