
Commit 9b63854

Improve reproducibility 2/3 (#1906)
* [Repro] Correct reproducability * up * up * uP * up * need better image * allow conversion from no state dict checkpoints * up * up * up * up * check tensors * check tensors * check tensors * check tensors * next try * up * up * better name * up * up * Apply suggestions from code review * correct more * up * replace all torch randn * fix * correct * correct * finish * fix more * up
1 parent 67e2f95 commit 9b63854

File tree: 49 files changed, +171 −391 lines. Large commits have some content hidden by default, so only part of the diff is rendered below.
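Every hunk below follows the same pattern: the per-pipeline noise-sampling boilerplate (drawing on CPU when the target device is mps, looping over a list of per-image generators) is deleted and replaced by a call to a shared randn_tensor utility imported from ...utils. The utility's own implementation is not part of the hunks shown here, so the following is only a rough sketch of what it plausibly does, inferred from the call sites and from the logic each pipeline removes; the exact signature and behavior in src/diffusers/utils may differ.

# Hypothetical sketch of a randn_tensor-style helper, inferred from the call sites
# in this commit and from the per-pipeline logic it replaces. Not the actual
# implementation from src/diffusers/utils.
from typing import List, Optional, Tuple, Union

import torch


def randn_tensor(
    shape: Union[Tuple[int, ...], List[int]],
    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # torch.randn is not reproducible on mps, so draw on CPU and move the result.
    rand_device = "cpu" if device is not None and device.type == "mps" else device

    if isinstance(generator, list):
        # One generator per batch element: each sample is drawn separately, so a
        # fixed list of seeds reproduces the same batch regardless of batch size.
        per_sample_shape = (1,) + tuple(shape[1:])
        samples = [
            torch.randn(per_sample_shape, generator=generator[i], device=rand_device, dtype=dtype)
            for i in range(shape[0])
        ]
        return torch.cat(samples, dim=0).to(device)

    return torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)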

src/diffusers/experimental/rl/value_guided_sampling.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 
 from ...models.unet_1d import UNet1DModel
 from ...pipelines import DiffusionPipeline
+from ...utils import randn_tensor
 from ...utils.dummy_pt_objects import DDPMScheduler
 
 
@@ -127,7 +128,7 @@ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, sca
         shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
 
         # generate initial noise and apply our conditions (to make the trajectories start at current state)
-        x1 = torch.randn(shape, device=self.unet.device)
+        x1 = randn_tensor(shape, device=self.unet.device)
         x = self.reset_x0(x1, conditions, self.action_dim)
         x = self.to_torch(x)

src/diffusers/models/prior_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def __init__(
         self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
 
         causal_attention_mask = torch.full(
-            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
+            [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
         )
         causal_attention_mask.triu_(1)
         causal_attention_mask = causal_attention_mask[None, ...]
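The only change here is the fill value of the additive causal mask: a large finite negative constant instead of float("-inf"). For any realistic score scale the two are equivalent after softmax, because exp(score - 10000) underflows to zero, while the finite value keeps the mask itself free of infinities. A small standalone check, not taken from the repository:

# Standalone check (not from the repository): a -10000.0 additive mask yields the
# same softmax weights as a float("-inf") mask for typical attention score scales.
import torch

scores = torch.randn(1, 4, 4)
mask_finite = torch.full((4, 4), -10000.0).triu_(1)
mask_inf = torch.full((4, 4), float("-inf")).triu_(1)

probs_finite = torch.softmax(scores + mask_finite, dim=-1)
probs_inf = torch.softmax(scores + mask_inf, dim=-1)

print(torch.allclose(probs_finite, probs_inf))  # True: exp(-10000 + score) underflows to 0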

src/diffusers/models/vae.py

Lines changed: 4 additions & 5 deletions
@@ -18,7 +18,7 @@
 import torch
 import torch.nn as nn
 
-from ..utils import BaseOutput
+from ..utils import BaseOutput, randn_tensor
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
@@ -323,11 +323,10 @@ def __init__(self, parameters, deterministic=False):
             )
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
-        device = self.parameters.device
-        sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
         # make sure sample is on the same device as the parameters and has same dtype
-        sample = sample.to(device=device, dtype=self.parameters.dtype)
+        sample = randn_tensor(
+            self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
+        )
         x = self.mean + self.std * sample
         return x
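With sample() delegating to randn_tensor, the posterior noise is drawn directly with the parameters' device and dtype, and a seeded generator makes the draw repeatable. A usage sketch; the checkpoint id and the surrounding AutoencoderKL calls are assumptions for illustration, not part of this diff:

# Usage sketch (assumes the runwayml/stable-diffusion-v1-5 VAE weights; adjust
# the checkpoint id for your environment).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
image = torch.zeros(1, 3, 64, 64)  # dummy input in [-1, 1]

posterior = vae.encode(image).latent_dist
g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)
# Same seed -> identical posterior draws, created in the parameters' dtype/device.
assert torch.equal(posterior.sample(generator=g1), posterior.sample(generator=g2))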

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 2 additions & 14 deletions
@@ -31,7 +31,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -401,20 +401,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
             )
 
         if latents is None:
-            rand_device = "cpu" if device.type == "mps" else device
-
-            if isinstance(generator, list):
-                shape = (1,) + shape[1:]
-                latents = [
-                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
-            if latents.shape != shape:
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
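Because prepare_latents now routes the list-of-generators case through randn_tensor, each image in a batch can be pinned to its own seed and reproduced independently of how the batch is composed. A usage sketch; the BAAI/AltDiffusion checkpoint id and the prompt are assumptions for illustration:

# Usage sketch (assumes the BAAI/AltDiffusion checkpoint; adjust for your environment).
import torch
from diffusers import AltDiffusionPipeline

pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion")

# One seeded generator per image: the i-th latent is drawn from generator i,
# so each image is reproducible regardless of batch size or ordering.
generators = [torch.Generator("cpu").manual_seed(seed) for seed in (0, 1, 2)]
images = pipe(["a photo of an astronaut riding a horse"] * 3, generator=generators).images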

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 2 additions & 10 deletions
@@ -33,7 +33,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
+from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
@@ -461,16 +461,8 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         else:
             init_latents = torch.cat([init_latents], dim=0)
 
-        rand_device = "cpu" if device.type == "mps" else device
         shape = init_latents.shape
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            noise = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
-            ]
-            noise = torch.cat(noise, dim=0).to(device)
-        else:
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
         # get latents
         init_latents = self.scheduler.add_noise(init_latents, noise, timestep)

src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, DDPMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
 from .mel import Mel
 
@@ -126,7 +127,7 @@ def __call__(
         input_dims = self.get_input_dims()
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
-            noise = torch.randn(
+            noise = randn_tensor(
                 (
                     batch_size,
                     self.unet.in_channels,

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 2 additions & 11 deletions
@@ -17,7 +17,7 @@
 
 import torch
 
-from ...utils import logging
+from ...utils import logging, randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 
 
@@ -100,16 +100,7 @@ def __call__(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        rand_device = "cpu" if self.device.type == "mps" else self.device
-        if isinstance(generator, list):
-            shape = (1,) + shape[1:]
-            audio = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
-                for i in range(batch_size)
-            ]
-            audio = torch.cat(audio, dim=0).to(self.device)
-        else:
-            audio = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(self.device)
+        audio = randn_tensor(shape, generator=generator, device=self.device, dtype=dtype)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps, device=audio.device)

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 2 additions & 12 deletions
@@ -16,7 +16,7 @@
 
 import torch
 
-from ...utils import deprecate
+from ...utils import deprecate, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
@@ -103,17 +103,7 @@ def __call__(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        rand_device = "cpu" if self.device.type == "mps" else self.device
-        if isinstance(generator, list):
-            shape = (1,) + image_shape[1:]
-            image = [
-                torch.randn(shape, generator=generator[i], device=rand_device, dtype=self.unet.dtype)
-                for i in range(batch_size)
-            ]
-            image = torch.cat(image, dim=0).to(self.device)
-        else:
-            image = torch.randn(image_shape, generator=generator, device=rand_device, dtype=self.unet.dtype)
-            image = image.to(self.device)
+        image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 import torch
 
 from ...configuration_utils import FrozenDict
-from ...utils import deprecate
+from ...utils import deprecate, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
@@ -100,10 +100,10 @@ def __call__(
 
         if self.device.type == "mps":
             # randn does not work reproducibly on mps
-            image = torch.randn(image_shape, generator=generator)
+            image = randn_tensor(image_shape, generator=generator)
             image = image.to(self.device)
         else:
-            image = torch.randn(image_shape, generator=generator, device=self.device)
+            image = randn_tensor(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
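The DDPM pipeline keeps its explicit mps branch (draw on CPU, then move), but both branches now go through randn_tensor, so a fixed seed should give identical output from run to run on the same device. A quick check sketch; the google/ddpm-cifar10-32 checkpoint and the reduced step count are assumptions for illustration:

# Reproducibility check sketch (assumes the google/ddpm-cifar10-32 checkpoint).
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

def run(seed: int):
    generator = torch.Generator("cpu").manual_seed(seed)
    # Few steps just to keep the check fast; output_type="numpy" returns arrays.
    return pipe(generator=generator, num_inference_steps=2, output_type="numpy").images

# Two runs with the same seed produce bit-identical images.
assert (run(0) == run(0)).all()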

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 2 additions & 14 deletions
@@ -26,6 +26,7 @@
 
 from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
@@ -143,20 +144,7 @@ def __call__(
         )
 
         if latents is None:
-            rand_device = "cpu" if self.device.type == "mps" else self.device
-
-            if isinstance(generator, list):
-                latents_shape = (1,) + latents_shape[1:]
-                latents = [
-                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
-                    for i in range(batch_size)
-                ]
-                latents = torch.cat(latents, dim=0)
-            else:
-                latents = torch.randn(
-                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
-                )
-            latents = latents.to(self.device)
+            latents = randn_tensor(latents_shape, generator=generator, device=self.device, dtype=text_embeddings.dtype)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
