# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import PIL.Image
import torch
from torchvision import transforms

from diffusers import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import randn_tensor


# Deterministic preprocessing shared by all inputs: resize to 256x256, convert
# to a tensor, and rescale from [0, 1] to [-1, 1], the range the UNet expects.
trans = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def preprocess(image):
    """Convert a PIL image or a list of PIL images into a normalized batch tensor.

    Tensors are assumed to be preprocessed already and are returned unchanged.
    """
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    image = [trans(img.convert("RGB")) for img in image]
    image = torch.stack(image)
    return image

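# A minimal sketch of `preprocess` usage (the image path is hypothetical):
#
#   from PIL import Image
#
#   img = Image.open("face.png")
#   batch = preprocess(img)         # tensor of shape (1, 3, 256, 256), values in [-1, 1]
#   batch = preprocess([img, img])  # a list of PIL images is stacked into a batch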

class DDIMNoiseComparativeAnalysisPipeline(DiffusionPipeline):
    r"""
    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
            [`DDPMScheduler`], or [`DDIMScheduler`].
    """

    def __init__(self, unet, scheduler):
        super().__init__()

        # make sure the scheduler can always be converted to DDIM
        scheduler = DDIMScheduler.from_config(scheduler.config)

        self.register_modules(unet=unet, scheduler=scheduler)

    def check_inputs(self, strength):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]

        return timesteps, num_inference_steps - t_start

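    # A worked example of the slicing above (illustrative numbers, not from the
    # source): with num_inference_steps=50 and strength=0.8,
    #   init_timestep = min(int(50 * 0.8), 50) = 40
    #   t_start       = max(50 - 40, 0)        = 10
    # so denoising runs over the last 40 of the 50 scheduler timesteps, i.e. the
    # image is noised to (and recovered from) 80% of the full diffusion depth.
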
    def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None):
        # `image` has already been through `preprocess`, so only tensors are valid here
        if not isinstance(image, torch.Tensor):
            raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")

        init_latents = image.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)

        # forward-diffuse the clean image to `timestep`:
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        latents = self.scheduler.add_noise(init_latents, noise, timestep)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
        strength: float = 0.8,
        batch_size: int = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        num_inference_steps: int = 50,
        use_clipped_model_output: Optional[bool] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            image (`torch.FloatTensor` or `PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to be used as the starting point for the denoising
                process.
            strength (`float`, *optional*, defaults to 0.8):
                Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image`
                is used as a starting point, and more noise is added the higher the `strength`. The number of
                denoising steps depends on the amount of noise initially added. When `strength` is 1, the added noise
                is maximal and the denoising process runs for the full number of iterations specified in
                `num_inference_steps`. A value of 1 therefore essentially ignores `image`.
            batch_size (`int`, *optional*, defaults to 1):
                The number of images to generate.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            eta (`float`, *optional*, defaults to 0.0):
                The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            use_clipped_model_output (`bool`, *optional*, defaults to `None`):
                If `True` or `False`, see the documentation for `DDIMScheduler.step`. If `None`, nothing is passed
                downstream to the scheduler, so use `None` for schedulers which don't support this argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images and the second is the timestep at which noise was added.
        """
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(strength)

        # 2. Preprocess image
        image = preprocess(image)

        # 3. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
        latent_timestep = timesteps[:1].repeat(batch_size)

        # 4. Prepare latent variables: noise the input image to `latent_timestep`
        latents = self.prepare_latents(image, latent_timestep, batch_size, self.unet.dtype, self.device, generator)
        image = latents

        # 5. Denoising loop
        for t in self.progress_bar(timesteps):
            # 1. predict noise model_output
            model_output = self.unet(image, t).sample

            # 2. predict previous mean of image x_t-1 and add variance depending on eta
            # eta corresponds to η in the paper and should be between [0, 1]
            # do x_t -> x_t-1
            image = self.scheduler.step(
                model_output,
                t,
                image,
                eta=eta,
                use_clipped_model_output=use_clipped_model_output,
                generator=generator,
            ).prev_sample

        # rescale from [-1, 1] back to [0, 1] for output
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            # index before .item() so this also works when batch_size > 1
            return (image, latent_timestep[0].item())

        return ImagePipelineOutput(images=image)
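

# A minimal usage sketch. The checkpoint id and file names below are
# illustrative assumptions, not part of this file; any 256x256 unconditional
# UNet2DModel checkpoint with a compatible scheduler config should work:
#
#   from PIL import Image
#
#   pipe = DDIMNoiseComparativeAnalysisPipeline.from_pretrained("google/ddpm-ema-celebahq-256")
#   pipe = pipe.to("cuda")
#
#   init_image = Image.open("input.png")
#   # noise the image to 70% of the diffusion depth, then denoise it back
#   out = pipe(init_image, strength=0.7, num_inference_steps=50)
#   out.images[0].save("denoised.png")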