# Copyright 2022 ETH Zurich Computer Vision Lab and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Tuple, Union

import numpy as np
import torch
import PIL.Image
from tqdm.auto import tqdm

from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import RePaintScheduler


def _preprocess_image(image: PIL.Image.Image):
    # Convert a PIL image to a (1, 3, H, W) float tensor scaled from [0, 255] to [-1, 1].
    image = np.array(image.convert("RGB"))
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
    return image


def _preprocess_mask(mask: PIL.Image.Image):
    # Convert a PIL mask to a (1, 1, H, W) float tensor and binarize it:
    # 1.0 keeps the original pixel, 0.0 marks it for inpainting.
    mask = np.array(mask.convert("L"))
    mask = mask.astype(np.float32) / 255.0
    mask = mask[None, None]
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)
    return mask
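
# Illustrative shapes (an assumption about typical inputs, not enforced here): a 256x256 RGB
# image becomes a (1, 3, 256, 256) tensor in [-1, 1], and a 256x256 "L"-mode mask becomes a
# (1, 1, 256, 256) tensor containing only 0.0 and 1.0.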


class RePaintPipeline(DiffusionPipeline):
    unet: UNet2DModel
    scheduler: RePaintScheduler

    def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        original_image: Union[torch.FloatTensor, PIL.Image.Image],
        mask_image: Union[torch.FloatTensor, PIL.Image.Image],
        num_inference_steps: int = 250,
        eta: float = 0.0,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Args:
            original_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The original image to inpaint on.
            mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
                The mask_image where 0.0 values define which part of the original image to inpaint (change).
            num_inference_steps (`int`, *optional*, defaults to 250):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            eta (`float`):
                The weight of the added noise in a diffusion step. Its value is between 0.0 and 1.0; 0.0 corresponds
                to the DDIM scheduler and 1.0 to the DDPM scheduler.
            jump_length (`int`, *optional*, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump ("j" in
                the RePaint paper). Take a look at Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            jump_n_sample (`int`, *optional*, defaults to 10):
                The number of times we will make a forward time jump for a given chosen time sample. Take a look at
                Figures 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

        Returns:
            [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
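
        Example (a minimal usage sketch; the checkpoint id and image URLs are assumptions, not
        pinned by this pipeline):

        ```py
        >>> from io import BytesIO

        >>> import requests
        >>> import torch
        >>> from PIL import Image

        >>> from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel

        >>> # Any 256x256 unconditional DDPM checkpoint should work; this id is an assumption.
        >>> unet = UNet2DModel.from_pretrained("google/ddpm-ema-celebahq-256", subfolder="unet")
        >>> scheduler = RePaintScheduler()
        >>> pipe = RePaintPipeline(unet=unet, scheduler=scheduler).to("cuda")

        >>> # IMG_URL and MASK_URL are hypothetical placeholders: a 256x256 image and a same-size
        >>> # mask where white (255) keeps a pixel and black (0) marks it for inpainting.
        >>> original_image = Image.open(BytesIO(requests.get(IMG_URL).content)).resize((256, 256))
        >>> mask_image = Image.open(BytesIO(requests.get(MASK_URL).content)).resize((256, 256))

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> output = pipe(
        ...     original_image=original_image,
        ...     mask_image=mask_image,
        ...     num_inference_steps=250,
        ...     eta=0.0,
        ...     jump_length=10,
        ...     jump_n_sample=10,
        ...     generator=generator,
        ... )
        >>> output.images[0].save("inpainted.png")
        ```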
        """

        if not isinstance(original_image, torch.FloatTensor):
            original_image = _preprocess_image(original_image)
        original_image = original_image.to(self.device)
        if not isinstance(mask_image, torch.FloatTensor):
            mask_image = _preprocess_mask(mask_image)
        mask_image = mask_image.to(self.device)

        # sample gaussian noise to begin the loop (already created on self.device, so no extra move is needed)
        image = torch.randn(
            original_image.shape,
            generator=generator,
            device=self.device,
        )

        # set step values
        self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, self.device)
        self.scheduler.eta = eta

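        # The timestep schedule is not monotonically decreasing: RePaint interleaves denoising
        # steps with jumps back in time that re-add noise. Comparing t with t_last tells the two
        # apart: t < t_last is a regular denoising step, t >= t_last is a re-noising ("undo") step.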
        t_last = self.scheduler.timesteps[0] + 1
        for t in tqdm(self.scheduler.timesteps):
            if t < t_last:
                # predict the noise residual
                model_output = self.unet(image, t).sample
                # compute previous image: x_t -> x_t-1
                image = self.scheduler.step(model_output, t, image, original_image, mask_image, generator).prev_sample
            else:
                # compute the reverse: x_t-1 -> x_t
                image = self.scheduler.undo_step(image, t_last, generator)
            t_last = t

        # scale images back from [-1, 1] to [0, 1] and convert to channel-last numpy
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)