@@ -195,7 +195,7 @@ def __call__(
             truncation=True,
             return_tensors="pt",
         )
-        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+        text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0]

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
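Editorial note on the `guidance_scale` comment in the hunk above: classifier-free guidance mixes the unconditional and text-conditioned noise predictions, with `guidance_scale` playing the role of the weight `w` from equation (2) of the Imagen paper. Below is a minimal, self-contained sketch of that mixing step; `apply_cfg` and the dummy tensor are illustrative names and do not appear in the diff.

```python
import torch


def apply_cfg(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # The UNet was run on a doubled batch: first half unconditional,
    # second half text-conditioned (see the concatenation further down).
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # guidance_scale == 1 disables guidance; larger values follow the prompt
    # more closely at the cost of sample diversity.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


# Dummy doubled batch of latent-shaped noise predictions (2 prompts -> 4 rows).
dummy = torch.randn(4, 4, 64, 64)
print(apply_cfg(dummy, guidance_scale=7.5).shape)  # torch.Size([2, 4, 64, 64])
```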
@@ -207,7 +207,7 @@ def __call__(
             uncond_input = self.tokenizer(
                 [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0]

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
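The concatenation described in the comments above is what lets a single UNet forward pass serve both branches of classifier-free guidance. A small sketch, assuming both embedding tensors come out of the CLIP text encoder with shape `(batch_size, 77, 768)`; the random tensors here are stand-ins, not the real encoder outputs.

```python
import torch

batch_size, seq_len, hidden_dim = 2, 77, 768

# Stand-ins for the text encoder outputs computed in the diff above.
uncond_embeddings = torch.randn(batch_size, seq_len, hidden_dim)
text_embeddings = torch.randn(batch_size, seq_len, hidden_dim)

# Unconditional first, conditional second, so a later .chunk(2) on the UNet
# output recovers the two branches in the same order.
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
print(text_embeddings.shape)  # torch.Size([4, 77, 768])
```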
@@ -219,8 +219,8 @@ def __call__(
         # Unlike in other pipelines, latents need to be generated in the target device
         # for 1-to-1 results reproducibility with the CompVis implementation.
         # However this currently doesn't work in `mps`.
-        latents_device = "cpu" if self.device.type == "mps" else self.device
-        latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        latents_device = "cuda"
+        latents_shape = (batch_size, 4, height // 8, width // 8)
         if latents is None:
             latents = torch.randn(
                 latents_shape,
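On the `latents_shape` change above: the hardcoded `4` is the latent channel count of the Stable Diffusion v1 UNet, which the removed line previously read from `self.unet.in_channels`. A sketch of the latent initialization this section performs, with illustrative sizes and a CPU fallback when no GPU is present.

```python
import torch

batch_size, height, width = 1, 512, 512
latents_device = "cuda" if torch.cuda.is_available() else "cpu"

# The VAE downsamples by a factor of 8; SD v1 latents have 4 channels.
latents_shape = (batch_size, 4, height // 8, width // 8)

# Seeding the generator on the target device is what the comment about
# 1-to-1 reproducibility with the CompVis implementation refers to.
generator = torch.Generator(device=latents_device).manual_seed(0)
latents = torch.randn(latents_shape, generator=generator, device=latents_device)
print(latents.shape)  # torch.Size([1, 4, 64, 64])
```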
@@ -259,7 +259,7 @@ def __call__(
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+            noise_pred = self.unet(latent_model_input, t, text_embeddings)[0]  # TODO: fix for return_dict case

             # perform guidance
             if do_classifier_free_guidance:
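The `-`/`+` pair above swaps the `.sample` attribute access for tuple-style indexing, which is what the `TODO` about `return_dict` refers to: when model outputs come back as plain tuples, `[0]` is the predicted noise, and output objects that expose `.sample` can support the same indexing. A toy illustration of the two access patterns; `FakeUNetOutput` is invented for this note and is not part of diffusers.

```python
from dataclasses import dataclass

import torch


@dataclass
class FakeUNetOutput:
    """Toy stand-in for an output object that exposes a `.sample` field."""

    sample: torch.Tensor

    def __getitem__(self, idx):
        # Mirrors how tuple-style indexing can coexist with attribute access.
        return (self.sample,)[idx]


out = FakeUNetOutput(sample=torch.zeros(2, 4, 64, 64))
assert out.sample is out[0]  # both access patterns reach the same tensor
```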
@@ -280,9 +280,9 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()

         # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda")
         image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype))
-
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)
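One last note on the output path in the hunk above: the decoded image tensor is reordered from NCHW to NHWC before the numpy/PIL conversion and before the safety checker's feature extractor sees it. A minimal sketch with a dummy tensor standing in for the real VAE decode.

```python
import torch

# Dummy decoded image batch: (batch, channels, height, width), values in [0, 1].
image = torch.rand(1, 3, 512, 512)

# Same layout change as in the diff: move to CPU, reorder to
# (batch, height, width, channels), then hand off as numpy.
image_np = image.cpu().permute(0, 2, 3, 1).numpy()
print(image_np.shape, image_np.dtype)  # (1, 512, 512, 3) float32
```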