Skip to content

Commit 033b77e

Browse files
authored
[Type hint] Latent Diffusion Uncond pipeline (#333)
1 parent e54206d commit 033b77e

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,35 @@
11
import inspect
22
import warnings
3+
from typing import Optional
34

45
import torch
56

7+
from ...models import UNet2DModel, VQModel
68
from ...pipeline_utils import DiffusionPipeline
9+
from ...schedulers import DDIMScheduler
710

811

912
class LDMPipeline(DiffusionPipeline):
10-
def __init__(self, vqvae, unet, scheduler):
13+
14+
vqvae: VQModel
15+
unet: UNet2DModel
16+
scheduler: DDIMScheduler
17+
18+
def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
1119
super().__init__()
1220
scheduler = scheduler.set_format("pt")
1321
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
1422

1523
@torch.no_grad()
16-
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
24+
def __call__(
25+
self,
26+
batch_size: int = 1,
27+
generator: Optional[torch.Generator] = None,
28+
eta: float = 0.0,
29+
num_inference_steps: int = 50,
30+
output_type: Optional[str] = "pil",
31+
**kwargs,
32+
):
1733
# eta corresponds to η in paper and should be between [0, 1]
1834

1935
if "torch_device" in kwargs:

0 commit comments

Comments
 (0)