2121import torch
2222
2323from diffusers import UNet2DConditionModel , UNet2DModel
24- from diffusers .utils import floats_tensor , require_torch_gpu , slow , torch_all_close , torch_device
24+ from diffusers .utils import floats_tensor , load_numpy , require_torch_gpu , slow , torch_all_close , torch_device
2525from parameterized import parameterized
2626
2727from ..test_modeling_common import ModelTesterMixin
@@ -411,18 +411,18 @@ def test_forward_with_norm_groups(self):
411411
412412@slow
413413class UNet2DConditionModelIntegrationTests (unittest .TestCase ):
414+ def get_file_format (self , seed , shape ):
415+ return f"gaussian_noise_s={ seed } _shape={ '_' .join ([str (s ) for s in shape ])} .npy"
416+
414417 def tearDown (self ):
415418 # clean up the VRAM after each test
416419 super ().tearDown ()
417420 gc .collect ()
418421 torch .cuda .empty_cache ()
419422
420423 def get_latents (self , seed = 0 , shape = (4 , 4 , 64 , 64 ), fp16 = False ):
421- batch_size , channels , height , width = shape
422- generator = torch .Generator (device = torch_device ).manual_seed (seed )
423424 dtype = torch .float16 if fp16 else torch .float32
424- image = torch .randn (batch_size , channels , height , width , device = torch_device , generator = generator , dtype = dtype )
425-
425+ image = torch .from_numpy (load_numpy (self .get_file_format (seed , shape ))).to (torch_device ).to (dtype )
426426 return image
427427
428428 def get_unet_model (self , fp16 = False , model_id = "CompVis/stable-diffusion-v1-4" ):
@@ -437,9 +437,9 @@ def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
437437 return model
438438
439439 def get_encoder_hidden_states (self , seed = 0 , shape = (4 , 77 , 768 ), fp16 = False ):
440- generator = torch .Generator (device = torch_device ).manual_seed (seed )
441440 dtype = torch .float16 if fp16 else torch .float32
442- return torch .randn (shape , device = torch_device , generator = generator , dtype = dtype )
441+ hidden_states = torch .from_numpy (load_numpy (self .get_file_format (seed , shape ))).to (torch_device ).to (dtype )
442+ return hidden_states
443443
444444 @parameterized .expand (
445445 [
0 commit comments