Skip to content

Commit e30e1b8

Browse files
authored
Support one-string prompts and custom image size in LDM (#212)
* Support one-string prompts in LDM * Add other features from SD too
1 parent df90f0c commit e30e1b8

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import inspect
2-
from typing import Optional, Tuple, Union
2+
from typing import List, Optional, Tuple, Union
33

44
import torch
55
import torch.nn as nn
@@ -24,20 +24,30 @@ def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
2424
@torch.no_grad()
2525
def __call__(
2626
self,
27-
prompt,
28-
batch_size=1,
29-
generator=None,
30-
torch_device=None,
31-
eta=0.0,
32-
guidance_scale=1.0,
33-
num_inference_steps=50,
34-
output_type="pil",
27+
prompt: Union[str, List[str]],
28+
height: Optional[int] = 256,
29+
width: Optional[int] = 256,
30+
num_inference_steps: Optional[int] = 50,
31+
guidance_scale: Optional[float] = 1.0,
32+
eta: Optional[float] = 0.0,
33+
generator: Optional[torch.Generator] = None,
34+
torch_device: Optional[Union[str, torch.device]] = None,
35+
output_type: Optional[str] = "pil",
3536
):
3637
# eta corresponds to η in paper and should be between [0, 1]
3738

3839
if torch_device is None:
3940
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
40-
batch_size = len(prompt)
41+
42+
if isinstance(prompt, str):
43+
batch_size = 1
44+
elif isinstance(prompt, list):
45+
batch_size = len(prompt)
46+
else:
47+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
48+
49+
if height % 8 != 0 or width % 8 != 0:
50+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
4151

4252
self.unet.to(torch_device)
4353
self.vqvae.to(torch_device)
@@ -53,7 +63,7 @@ def __call__(
5363
text_embeddings = self.bert(text_input.input_ids.to(torch_device))[0]
5464

5565
latents = torch.randn(
56-
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
66+
(batch_size, self.unet.in_channels, height // 8, width // 8),
5767
generator=generator,
5868
)
5969
latents = latents.to(torch_device)

tests/test_modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def test_ldm_text2img_fast(self):
854854

855855
prompt = "A painting of a squirrel eating a burger"
856856
generator = torch.manual_seed(0)
857-
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
857+
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
858858

859859
image_slice = image[0, -3:, -3:, -1]
860860

0 commit comments

Comments
 (0)