
Commit cc59b05

Authored by patrickvonplaten, patil-suraj, and pcuenca
[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)
* add outputs for models
* add for pipelines
* finish schedulers
* better naming
* adapt tests as well
* replace dict access with . access
* make schedulers works
* finish
* correct readme
* make bcp compatible
* up
* small fix
* finish
* more fixes
* more fixes
* Apply suggestions from code review
  Co-authored-by: Suraj Patil <[email protected]>
  Co-authored-by: Pedro Cuenca <[email protected]>
* Update src/diffusers/models/vae.py
  Co-authored-by: Pedro Cuenca <[email protected]>
* Adapt model outputs
* Apply more suggestions
* finish examples
* correct

Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
1 parent daddd98 commit cc59b05

39 files changed: +893 −247 lines changed

README.md

Lines changed: 8 additions & 8 deletions
@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 If you are limited by GPU memory, you might want to consider using the model in `fp16`.
@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"
 
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -228,7 +228,7 @@ pipe = pipe.to(device)
 
 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)
 
 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
+images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
 
 # save images
 for idx, image in enumerate(images):
@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
 ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
 
 # run pipeline in inference (sample random noise and denoise)
-image = ddpm()["sample"]
+image = ddpm().images
 
 # save image
 image[0].save("ddpm_generated_image.png")
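
In practice, the README changes above swap dictionary indexing for attribute access on the pipeline output. A minimal sketch of the new pattern, assuming the model id and `use_auth_token` handling from the README (both are placeholders here):

```python
from torch import autocast
from diffusers import StableDiffusionPipeline

# Model id and auth handling follow the README above and are assumptions in this sketch.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
    output = pipe(prompt)      # pipelines now return an output object instead of a plain dict
    image = output.images[0]   # was: pipe(prompt)["sample"][0]

image.save("astronaut_rides_horse.png")
```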

examples/textual_inversion/README.md

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
 prompt = "A <cat-toy> backpack"
 
 with autocast("cuda"):
-    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 
 image.save("cat-backpack.png")
 ```

examples/textual_inversion/textual_inversion.py

Lines changed: 2 additions & 2 deletions
@@ -498,7 +498,7 @@ def main():
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).sample().detach()
+                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
@@ -515,7 +515,7 @@ def main():
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
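
The same pattern applies inside the training loop: `encode()` now returns an object whose `latent_dist` field holds the posterior, and the UNet forward returns an object whose `sample` field is the noise prediction. A rough, self-contained sketch, with random tensors standing in for a real batch; the checkpoint id, subfolders, and shapes are illustrative assumptions:

```python
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

# Loading mirrors the training script; the checkpoint id is an assumption for illustration.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True)

pixel_values = torch.randn(1, 3, 512, 512)        # stand-in for batch["pixel_values"]
encoder_hidden_states = torch.randn(1, 77, 768)   # stand-in for text-encoder hidden states
timesteps = torch.tensor([10])

# encode() returns an output object; the posterior distribution lives in .latent_dist
latents = vae.encode(pixel_values).latent_dist.sample().detach() * 0.18215

# the UNet forward returns an output object; the prediction is the .sample attribute
noise_pred = unet(latents, timesteps, encoder_hidden_states).sample
```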

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def transforms(examples):
 
            with accelerator.accumulate(model):
                # Predict the noise residual
-               noise_pred = model(noisy_images, timesteps)["sample"]
+               noise_pred = model(noisy_images, timesteps).sample
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)
 
@@ -174,7 +174,7 @@ def transforms(examples):
 
                generator = torch.manual_seed(0)
                # run pipeline in inference (sample random noise and denoise)
-               images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+               images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
 
                # denormalize the images and save to tensorboard
                images_processed = (images * 255).round().astype("uint8")
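
With `output_type="numpy"`, the evaluation call above now reads the array from the `.images` field before denormalizing it. A hedged sketch of the same access outside the training loop; the checkpoint id is borrowed from the README and the batch size is arbitrary:

```python
import torch
from diffusers import DDPMPipeline

pipeline = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")  # assumed checkpoint for illustration

generator = torch.manual_seed(0)
images = pipeline(generator=generator, batch_size=2, output_type="numpy").images  # was ["sample"]
images_uint8 = (images * 255).round().astype("uint8")  # same denormalization as the script above
```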

scripts/generate_logits.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@
     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     time_step = torch.tensor([10] * noise.shape[0])
     with torch.no_grad():
-        logits = model(noise, time_step)["sample"]
+        logits = model(noise, time_step).sample
 
     assert torch.allclose(
         logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3

src/diffusers/hub_utils.py

Lines changed: 1 addition & 1 deletion
@@ -19,9 +19,9 @@
 from pathlib import Path
 from typing import Optional
 
-from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami
 
+from .pipeline_utils import DiffusionPipeline
 from .utils import is_modelcards_available, logging
 
 
src/diffusers/models/unet_2d.py

Lines changed: 22 additions & 5 deletions
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
+@dataclass
+class UNet2DOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -118,8 +131,11 @@ def __init__(
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
     def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[UNet2DOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -181,6 +197,7 @@ def forward(
         timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
         sample = sample / timesteps
 
-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)
 
-        return output
+        return UNet2DOutput(sample=sample)
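
The diff above adds a `return_dict` flag to `UNet2DModel.forward`. A hedged sketch of both return paths, with input shapes following `scripts/generate_logits.py` above; the checkpoint id is an assumption for illustration:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256")  # assumed checkpoint

noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
timestep = torch.tensor([10])

with torch.no_grad():
    out = model(noise, timestep)    # UNet2DOutput dataclass
    sample = out.sample             # was: out["sample"]

    # opting out of the dataclass returns a plain tuple: (sample,)
    (sample_only,) = model(noise, timestep, return_dict=False)
```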

src/diffusers/models/unet_2d_condition.py

Lines changed: 19 additions & 4 deletions
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
 
 
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -125,7 +138,8 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-    ) -> Dict[str, torch.FloatTensor]:
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -183,6 +197,7 @@ def forward(
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
 
-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)
 
-        return output
+        return UNet2DConditionOutput(sample=sample)

src/diffusers/models/vae.py

Lines changed: 87 additions & 17 deletions
@@ -1,14 +1,56 @@
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
+@dataclass
+class DecoderOutput(BaseOutput):
+    """
+    Output of decoding method.
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Decoded output sample of the model. Output of the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"
+
+
 class Encoder(nn.Module):
     def __init__(
         self,
@@ -369,26 +411,40 @@ def __init__(
            act_fn=act_fn,
        )
 
-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
        h = self.encoder(x)
        h = self.quant_conv(h)
-       return h
 
-    def decode(self, h, force_not_quantize=False):
+       if not return_dict:
+           return (h,)
+
+       return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
-       return dec
 
-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+       if not return_dict:
+           return (dec,)
+
+       return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
-       h = self.encode(x)
-       dec = self.decode(h)
-       return dec
+       h = self.encode(x).latents
+       dec = self.decode(h).sample
+
+       if not return_dict:
+           return (dec,)
+
+       return DecoderOutput(sample=dec)
 
 
 class AutoencoderKL(ModelMixin, ConfigMixin):
@@ -431,23 +487,37 @@ def __init__(
        self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
 
-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
-       return posterior
 
-    def decode(self, z):
+       if not return_dict:
+           return (posterior,)
+
+       return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
-       return dec
 
-    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
+       if not return_dict:
+           return (dec,)
+
+       return DecoderOutput(sample=dec)
+
+    def forward(
+        self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
        x = sample
-       posterior = self.encode(x)
+       posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
-       dec = self.decode(z)
-       return dec
+       dec = self.decode(z).sample
+
+       if not return_dict:
+           return (dec,)
+
+       return DecoderOutput(sample=dec)
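
Putting the VAE changes together: `encode` and `decode` now return the dataclasses defined at the top of this file, and `return_dict=False` gives plain tuples instead. A rough sketch, with the checkpoint id, subfolder, and `use_auth_token` handling as assumptions for illustration:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_auth_token=True)

image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image batch

posterior = vae.encode(image).latent_dist    # AutoencoderKLOutput.latent_dist
latents = posterior.sample()                 # sample from the DiagonalGaussianDistribution
reconstruction = vae.decode(latents).sample  # DecoderOutput.sample

# Tuple paths, for callers that opt out of the output objects:
(posterior_only,) = vae.encode(image, return_dict=False)
(reconstruction_only,) = vae.decode(latents, return_dict=False)
```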
