Skip to content

Commit ea01a4c

Browse files
fix
2 parents cbbb293 + d37f08d commit ea01a4c

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .testing_utils import (
4444
floats_tensor,
4545
load_image,
46+
load_numpy,
4647
parse_flag_from_env,
4748
require_torch_gpu,
4849
slow,

src/diffusers/utils/testing_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
import random
55
import re
66
import unittest
7+
import urllib.parse
78
from distutils.util import strtobool
8-
from io import StringIO
9+
from io import BytesIO, StringIO
910
from pathlib import Path
1011
from typing import Union
1112

13+
import numpy as np
14+
1215
import PIL.Image
1316
import PIL.ImageOps
1417
import requests
@@ -165,6 +168,19 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
165168
return image
166169

167170

171+
def load_numpy(path) -> np.ndarray:
172+
if not path.startswith("http://") or path.startswith("https://"):
173+
path = os.path.join(
174+
"https://huggingface.co/datasets/fusing/diffusers-testing/resolve/main", urllib.parse.quote(path)
175+
)
176+
177+
response = requests.get(path)
178+
response.raise_for_status()
179+
array = np.load(BytesIO(response.content))
180+
181+
return array
182+
183+
168184
# --- pytest conf functions --- #
169185

170186
# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once

tests/models/test_models_unet_2d.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch
2222

2323
from diffusers import UNet2DConditionModel, UNet2DModel
24-
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
24+
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
2525
from parameterized import parameterized
2626

2727
from ..test_modeling_common import ModelTesterMixin
@@ -411,18 +411,18 @@ def test_forward_with_norm_groups(self):
411411

412412
@slow
413413
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
414+
def get_file_format(self, seed, shape):
415+
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
416+
414417
def tearDown(self):
415418
# clean up the VRAM after each test
416419
super().tearDown()
417420
gc.collect()
418421
torch.cuda.empty_cache()
419422

420423
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
421-
batch_size, channels, height, width = shape
422-
generator = torch.Generator(device=torch_device).manual_seed(seed)
423424
dtype = torch.float16 if fp16 else torch.float32
424-
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)
425-
425+
image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
426426
return image
427427

428428
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
@@ -437,9 +437,9 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
437437
return model
438438

439439
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
440-
generator = torch.Generator(device=torch_device).manual_seed(seed)
441440
dtype = torch.float16 if fp16 else torch.float32
442-
return torch.randn(shape, device=torch_device, generator=generator, dtype=dtype)
441+
hidden_states = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
442+
return hidden_states
443443

444444
@parameterized.expand(
445445
[

0 commit comments

Comments
 (0)