11import inspect
2- from typing import Optional , Tuple , Union
2+ from typing import List , Optional , Tuple , Union
33
44import torch
55import torch .nn as nn
@@ -24,20 +24,30 @@ def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
2424 @torch .no_grad ()
2525 def __call__ (
2626 self ,
27- prompt ,
28- batch_size = 1 ,
29- generator = None ,
30- torch_device = None ,
31- eta = 0.0 ,
32- guidance_scale = 1.0 ,
33- num_inference_steps = 50 ,
34- output_type = "pil" ,
27+ prompt : Union [str , List [str ]],
28+ height : Optional [int ] = 256 ,
29+ width : Optional [int ] = 256 ,
30+ num_inference_steps : Optional [int ] = 50 ,
31+ guidance_scale : Optional [float ] = 1.0 ,
32+ eta : Optional [float ] = 0.0 ,
33+ generator : Optional [torch .Generator ] = None ,
34+ torch_device : Optional [Union [str , torch .device ]] = None ,
35+ output_type : Optional [str ] = "pil" ,
3536 ):
3637 # eta corresponds to η in the paper and should be in the range [0, 1]
3738
3839 if torch_device is None :
3940 torch_device = "cuda" if torch .cuda .is_available () else "cpu"
40- batch_size = len (prompt )
41+
42+ if isinstance (prompt , str ):
43+ batch_size = 1
44+ elif isinstance (prompt , list ):
45+ batch_size = len (prompt )
46+ else :
47+ raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
48+
49+ if height % 8 != 0 or width % 8 != 0 :
50+ raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
4151
4252 self .unet .to (torch_device )
4353 self .vqvae .to (torch_device )
@@ -53,7 +63,7 @@ def __call__(
5363 text_embeddings = self .bert (text_input .input_ids .to (torch_device ))[0 ]
5464
5565 latents = torch .randn (
56- (batch_size , self .unet .in_channels , self . unet . sample_size , self . unet . sample_size ),
66+ (batch_size , self .unet .in_channels , height // 8 , width // 8 ),
5767 generator = generator ,
5868 )
5969 latents = latents .to (torch_device )
0 commit comments