Skip to content

Commit 7b628a2

Browse files
authored
[Type hint] PNDM pipeline (#327)
* [Type hint] PNDM pipeline * ran make style * Revert "ran make style" wrong black version
1 parent 033b77e commit 7b628a2

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/diffusers/pipelines/pndm/pipeline_pndm.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,33 @@
1515

1616

1717
import warnings
18+
from typing import Optional
1819

1920
import torch
2021

22+
from ...models import UNet2DModel
2123
from ...pipeline_utils import DiffusionPipeline
24+
from ...schedulers import PNDMScheduler
2225

2326

2427
class PNDMPipeline(DiffusionPipeline):
25-
def __init__(self, unet, scheduler):
28+
unet: UNet2DModel
29+
scheduler: PNDMScheduler
30+
31+
def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
2632
super().__init__()
2733
scheduler = scheduler.set_format("pt")
2834
self.register_modules(unet=unet, scheduler=scheduler)
2935

3036
@torch.no_grad()
31-
def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs):
37+
def __call__(
38+
self,
39+
batch_size: int = 1,
40+
num_inference_steps: int = 50,
41+
generator: Optional[torch.Generator] = None,
42+
output_type: Optional[str] = "pil",
43+
**kwargs,
44+
):
3245
# For more information on the sampling method you can take a look at Algorithm 2 of
3346
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
3447

0 commit comments

Comments
 (0)