diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 32ddbd8c1711..632e31777a53 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -15,20 +15,33 @@ import warnings +from typing import Optional import torch +from ...models import UNet2DModel from ...pipeline_utils import DiffusionPipeline +from ...schedulers import PNDMScheduler class PNDMPipeline(DiffusionPipeline): - def __init__(self, unet, scheduler): + unet: UNet2DModel + scheduler: PNDMScheduler + + def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler): super().__init__() scheduler = scheduler.set_format("pt") self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, generator=None, num_inference_steps=50, output_type="pil", **kwargs): + def __call__( + self, + batch_size: int = 1, + num_inference_steps: int = 50, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + **kwargs, + ): # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf