From 159e15cd924a8ac5951a4abc6734b1174136347d Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 4 Oct 2022 15:28:56 +0200 Subject: [PATCH 1/7] init --- .../pipeline_stable_diffusion.py | 15 ++++-------- .../pipeline_stable_diffusion_img2img.py | 13 +++-------- .../pipeline_stable_diffusion_inpaint.py | 10 +++----- .../pipeline_stable_diffusion_onnx.py | 15 +++--------- src/diffusers/schedulers/scheduling_ddim.py | 6 +++++ .../schedulers/scheduling_lms_discrete.py | 23 +++++++++++++++---- src/diffusers/schedulers/scheduling_pndm.py | 6 +++++ 7 files changed, 43 insertions(+), 45 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d190acb1fa1c..2804e61b1fa1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -257,9 +257,8 @@ def __call__( else: timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
@@ -273,10 +272,7 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -287,10 +283,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c8f02b5896d6..9df0953f2050 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -222,6 +222,7 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) + # FIXME if isinstance(self.scheduler, LMSDiscreteScheduler): timesteps = torch.tensor( [num_inference_steps - init_timestep] * 
batch_size, dtype=torch.long, device=self.device @@ -291,12 +292,7 @@ def __call__( # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -307,10 +303,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if callback is not None and i % callback_steps == 0: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 21490d975730..cea77c230940 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -321,13 +321,9 @@ def __call__( timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) for i, t in tqdm(enumerate(timesteps_tensor)): - t_index = t_start + i # expand the latents if we are doing classifier free 
guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -338,12 +334,12 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample # masking + t_index = t_start + i init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index])) else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t])) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 48e0e147c48d..90dc3b0023db 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -123,9 +123,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = latents * self.scheduler.sigmas[0] + latents = latents * 
self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -139,10 +137,7 @@ def __call__( for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( @@ -156,11 +151,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = np.array(latents) # call the callback, if provided diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a728ab29d7bb..56d6216ff46f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -152,10 +152,16 @@ def __init__( # whether we use the final alpha of the "non-previous" one. 
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] + def scale_model_input(self, sample, timestep): + return sample + def _get_variance(self, timestep, prev_timestep): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 33e9558d9c38..00bc75a3bfef 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -102,12 +102,24 @@ def __init__( sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + # setable values self.num_inference_steps = None timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + def scale_model_input(self, sample, timestep): + """ + Scale the model input to match the ODE solver in K-LMS + """ + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + def get_lms_coefficient(self, order, t, current_order): """ Compute a linear multistep coefficient. 
@@ -154,7 +166,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def step( self, model_output: torch.FloatTensor, - timestep: int, + timestep: float, sample: torch.FloatTensor, order: int = 4, return_dict: bool = True, @@ -165,7 +177,7 @@ def step( Args: model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. + timestep (`float`): current timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. order: coefficient for multi-step inference. @@ -177,7 +189,8 @@ def step( When returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise pred_original_sample = sample - sigma * model_output @@ -189,8 +202,8 @@ def step( self.derivatives.pop(0) # 3. Compute linear multistep coefficients - order = min(timestep + 1, order) - lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)] + order = min(step_index + 1, order) + lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)] # 4. Compute previous sample based on the derivatives path prev_sample = sample + sum( diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 3974335a2f1b..81b514ccdbb3 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -129,6 +129,9 @@ def __init__( self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # For now we only support F-PNDM, i.e. 
the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. @@ -341,6 +344,9 @@ def step_plms( return SchedulerOutput(prev_sample=prev_sample) + def scale_model_input(self, sample, timestep): + return sample + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) From 74ae71737b8617096ea1fa5f73535f8eb9235050 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 4 Oct 2022 17:30:10 +0200 Subject: [PATCH 2/7] improve add_noise --- .../pipeline_stable_diffusion_img2img.py | 13 +++---------- .../pipeline_stable_diffusion_inpaint.py | 19 +++++-------------- src/diffusers/schedulers/scheduling_ddim.py | 2 +- .../schedulers/scheduling_lms_discrete.py | 16 ++++++++++++---- src/diffusers/schedulers/scheduling_pndm.py | 2 +- tests/test_scheduler.py | 10 +++++----- 6 files changed, 27 insertions(+), 35 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 9df0953f2050..27cd5b459f9a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -222,14 +222,9 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - # FIXME - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, 
device=self.device) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -288,8 +283,6 @@ def __call__( timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) for i, t in enumerate(self.progress_bar(timesteps_tensor)): - t_index = t_start + i - # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index cea77c230940..9b877bb1fc75 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -256,13 +256,9 @@ def __call__( offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if isinstance(self.scheduler, LMSDiscreteScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) @@ -335,13 +331,8 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = 
self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - if isinstance(self.scheduler, LMSDiscreteScheduler): - # masking - t_index = t_start + i - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index])) - else: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t])) + # masking + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) latents = (init_latents_proper * mask) + (latents * (1 - mask)) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 56d6216ff46f..3472043fe065 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -159,7 +159,7 @@ def __init__( self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] - def scale_model_input(self, sample, timestep): + def scale_model_input(self, sample: torch.FloatTensor, timestep): return sample def _get_variance(self, timestep, prev_timestep): diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 00bc75a3bfef..9af8daa744d3 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -111,10 +111,14 @@ def __init__( self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] - def scale_model_input(self, sample, timestep): + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: """ Scale the model input to match the ODE solver in K-LMS """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) @@ -166,7 +170,7 @@ def set_timesteps(self, 
num_inference_steps: int, device: Union[str, torch.devic def step( self, model_output: torch.FloatTensor, - timestep: float, + timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, order: int = 4, return_dict: bool = True, @@ -189,6 +193,8 @@ def step( When returning a tuple, the first element is the sample tensor. """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] @@ -219,12 +225,14 @@ def add_noise( self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor, + timesteps: torch.FloatTensor, ) -> torch.FloatTensor: sigmas = self.sigmas.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = sigmas[timesteps].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 81b514ccdbb3..d5dd2600a557 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -344,7 +344,7 @@ def step_plms( return SchedulerOutput(prev_sample=prev_sample) - def scale_model_input(self, sample, timestep): + def scale_model_input(self, sample: torch.FloatTensor, timestep): return sample def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index bee36c39acdb..04ebd1a7a0be 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -201,7 +201,7 @@ def recursive_check(tuple_object, dict_object): ) kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + 
num_inference_steps = kwargs.pop("num_inference_steps", 50) for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() @@ -845,7 +845,7 @@ def test_timesteps(self): def test_betas(self): for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + self.check_over_configs(beta_start=beta_start, beta_end=beta_end, time_step=0.0) def test_schedules(self): for schedule in ["linear", "scaled_linear"]: @@ -863,14 +863,14 @@ def test_full_loop_no_noise(self): scheduler.set_timesteps(self.num_inference_steps) model = self.dummy_model() - sample = self.dummy_sample_deter * scheduler.sigmas[0] + sample = self.dummy_sample_deter * scheduler.init_noise_sigma for i, t in enumerate(scheduler.timesteps): - sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5) + sample = scheduler.scale_model_input(sample, t) model_output = model(sample, t) - output = scheduler.step(model_output, i, sample) + output = scheduler.step(model_output, t, sample) sample = output.prev_sample result_sum = torch.sum(torch.abs(sample)) From fa9667f4f7291286ac7f8cee96a24df64dcf8a79 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 4 Oct 2022 17:37:05 +0200 Subject: [PATCH 3/7] [debug start] run slow test --- .github/workflows/push_tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 3e4a81c91c01..20997da45ed7 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -4,6 +4,9 @@ on: push: branches: - main + pull_request: + branches: + - main env: HF_HOME: /mnt/cache From c06af2b7d7545b1e1fb1236508df07da851135be Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 4 Oct 2022 19:23:19 +0200 Subject: [PATCH 4/7] [debug end] --- .github/workflows/push_tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/push_tests.yml 
b/.github/workflows/push_tests.yml index 20997da45ed7..3e4a81c91c01 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -4,9 +4,6 @@ on: push: branches: - main - pull_request: - branches: - - main env: HF_HOME: /mnt/cache From 9325ca459898421bc2e62ec2ee25a556251f7838 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 4 Oct 2022 19:32:59 +0200 Subject: [PATCH 5/7] quick revert --- tests/test_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 04ebd1a7a0be..cb0d554cfebc 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -845,7 +845,7 @@ def test_timesteps(self): def test_betas(self): for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end, time_step=0.0) + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) def test_schedules(self): for schedule in ["linear", "scaled_linear"]: From 46ceb1084d4b9c9c6c21beeb6ff4c4ca59b8b7e0 Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 5 Oct 2022 14:00:17 +0200 Subject: [PATCH 6/7] Add docstrings and warnings + API tests --- .../score_sde_ve/pipeline_score_sde_ve.py | 2 +- .../pipeline_stochastic_karras_ve.py | 2 +- src/diffusers/schedulers/scheduling_ddim.py | 13 +++++++++- src/diffusers/schedulers/scheduling_ddpm.py | 17 +++++++++++++ .../schedulers/scheduling_karras_ve.py | 17 +++++++++++++ .../schedulers/scheduling_lms_discrete.py | 25 +++++++++++++++++-- src/diffusers/schedulers/scheduling_pndm.py | 12 ++++++++- src/diffusers/schedulers/scheduling_sde_ve.py | 17 +++++++++++++ tests/test_scheduler.py | 21 ++++++++++++++++ 9 files changed, 120 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 1907fc7d50d8..7f63820eec28 100644 --- 
a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -57,7 +57,7 @@ def __call__( model = self.unet - sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max + sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py index 35d06106869e..9e8864b4ca76 100644 --- a/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py @@ -69,7 +69,7 @@ def __call__( model = self.unet # sample x_0 ~ N(0, sigma_0^2 * I) - sample = torch.randn(*shape) * self.scheduler.config.sigma_max + sample = torch.randn(*shape) * self.scheduler.init_noise_sigma sample = sample.to(self.device) self.scheduler.set_timesteps(num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 3472043fe065..cd5bc8cdfbc4 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -159,7 +159,18 @@ def __init__( self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] - def scale_model_input(self, sample: torch.FloatTensor, timestep): + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. 
+ + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ return sample def _get_variance(self, timestep, prev_timestep): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 4d4e986a76ea..fc7df653afc1 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -140,12 +140,29 @@ def __init__( self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) self.one = torch.tensor(1.0) + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + # setable values self.num_inference_steps = None self.timesteps = np.arange(0, num_train_timesteps)[::-1] self.variance_type = variance_type + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + def set_timesteps(self, num_inference_steps: int): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 
diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 63e1400262d8..8bebb160ab7f 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -95,11 +95,28 @@ def __init__( take_from=kwargs, ) + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + # setable values self.num_inference_steps: int = None self.timesteps: np.ndarray = None self.schedule: torch.FloatTensor = None # sigma(t_i) + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + def set_timesteps(self, num_inference_steps: int): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9af8daa744d3..56b285e778cf 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -110,18 +110,27 @@ def __init__( timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + self.is_scale_input_called = False def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: """ - Scale the model input to match the ODE solver in K-LMS + Scales the denoising model input by `1 / ((sigma**2 + 1) ** 0.5)` to match the K-LMS algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample """ if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() sigma = self.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True return sample def get_lms_coefficient(self, order, t, current_order): @@ -193,6 +202,18 @@ def step( When returning a tuple, the first element is the sample tensor. """ + if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor): + warnings.warn( + f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. " + "Make sure to pass one of the `scheduler.timesteps`" + ) + if not self.is_scale_input_called: + warnings.warn( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." 
+ ) + self.is_scale_input_called = False + if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) step_index = (self.timesteps == timestep).nonzero().item() diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d5dd2600a557..bfbd543a313b 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -344,7 +344,17 @@ def step_plms( return SchedulerOutput(prev_sample=prev_sample) - def scale_model_input(self, sample: torch.FloatTensor, timestep): + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ return sample def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 12ed1a1b656e..a2f7fb36fba9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -84,11 +84,28 @@ def __init__( take_from=kwargs, ) + # standard deviation of the initial noise distribution + self.init_noise_sigma = sigma_max + # setable values self.timesteps = None self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. 
+ + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index cb0d554cfebc..051ba05a20f1 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -226,6 +226,27 @@ def recursive_check(tuple_object, dict_object): recursive_check(outputs_tuple, outputs_dict) + def test_scheduler_public_api(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + self.assertTrue( + hasattr(scheduler, "init_noise_sigma"), + f"{scheduler_class} does not implement a required attribute `init_noise_sigma`", + ) + self.assertTrue( + hasattr(scheduler, "scale_model_input"), + f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`", + ) + self.assertTrue( + hasattr(scheduler, "step"), + f"{scheduler_class} does not implement a required class method `step(...)`", + ) + + sample = self.dummy_sample + scaled_sample = scheduler.scale_model_input(sample, 0.0) + self.assertEqual(sample.shape, scaled_sample.shape) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) From 7a4abb01a0a765b6483399703c422268b64fd2b8 Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 5 Oct 2022 14:06:53 +0200 Subject: [PATCH 7/7] Make the warning less spammy --- src/diffusers/schedulers/scheduling_lms_discrete.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 56b285e778cf..a55811a0629f 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py 
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -212,7 +212,6 @@ def step( "The `scale_model_input` function should be called before `step` to ensure correct denoising. " "See `StableDiffusionPipeline` for a usage example." ) - self.is_scale_input_called = False if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device)