Skip to content

Commit 2db92f8

Browse files
[Dance Diffusion] FP16 (huggingface#980)
* add in fp16 * up
1 parent 92dd118 commit 2db92f8

File tree

2 files changed: +7 additions, -3 deletions

models/unet_1d.py

Lines changed: 1 addition & 1 deletion

@@ -149,7 +149,7 @@ def forward(
             timestep = timestep[None]

         timestep_embed = self.time_proj(timestep)[..., None]
-        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]])
+        timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)

         # 2. down
         down_block_res_samples = ()

pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 6 additions & 2 deletions

@@ -91,10 +91,14 @@ def __call__(
         )
         sample_size = int(sample_size)

-        audio = torch.randn((batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device)
+        dtype = next(iter(self.unet.parameters())).dtype
+        audio = torch.randn(
+            (batch_size, self.unet.in_channels, sample_size), generator=generator, device=self.device, dtype=dtype
+        )

         # set step values
         self.scheduler.set_timesteps(num_inference_steps, device=audio.device)
+        self.scheduler.timesteps = self.scheduler.timesteps.to(dtype)

         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
@@ -103,7 +107,7 @@
             # 2. compute previous image: x_t -> t_t-1
             audio = self.scheduler.step(model_output, t, audio).prev_sample

-        audio = audio.clamp(-1, 1).cpu().numpy()
+        audio = audio.clamp(-1, 1).float().cpu().numpy()

         audio = audio[:, :, :original_sample_size]

Commit comments (0)