diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ce17f2e0ee41..9f1211b43013 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -6,6 +6,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -53,6 +54,21 @@ def __init__( ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -217,12 +233,7 @@ def __call__( latents = latents.to(self.device) # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 817288f499ec..e7adb4d1a33b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -7,6 +8,7 @@ import PIL from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler @@ -64,6 +66,21 @@ def __init__( ): super().__init__() scheduler = scheduler.set_format("pt") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -169,14 +186,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) @@ -190,6 +200,7 @@ def __call__( init_latents = torch.cat([init_latents] * batch_size) # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 6076b854089e..b9ad36f1a2bf 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,4 +1,5 @@ import inspect +import warnings from typing import List, Optional, Union import numpy as np @@ -8,6 +9,7 @@ from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, PNDMScheduler @@ -83,6 +85,21 @@ def __init__( super().__init__() scheduler = scheduler.set_format("pt") logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + warnings.warn( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file", + DeprecationWarning, + ) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -193,19 +210,12 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - offset = 0 - if accepts_offset: - offset = 1 - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # preprocess image if not isinstance(init_image, torch.FloatTensor): init_image = preprocess_image(init_image) - init_image.to(self.device) + init_image = init_image.to(self.device) # encode the init image into latents and scale the latents init_latent_dist = self.vae.encode(init_image).latent_dist @@ -220,7 +230,7 @@ def __call__( # preprocess mask if not isinstance(mask_image, torch.FloatTensor): mask_image = preprocess_mask(mask_image) - mask_image.to(self.device) + mask_image = mask_image.to(self.device) mask = torch.cat([mask_image] * batch_size) # check sizes @@ -228,6 +238,7 @@ def __call__( raise ValueError("The mask and init_image should be the same size!") # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index b114145d6f04..ccba29ade5d3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -100,12 +100,7 @@ def __call__( raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps - accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) - extra_set_kwargs = {} - if accepts_offset: - extra_set_kwargs["offset"] = 1 - - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 85ba67523172..32be871f0bc9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,6 +16,7 @@ # and https://github.com/hojonathanho/diffusion import math +import warnings from typing import Optional, Tuple, Union import numpy as np @@ -78,7 +79,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. set_alpha_to_one (`bool`, default `True`): - if alpha for final step is 1 or the final alpha of the "non-previous" one. + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. """ @@ -93,6 +100,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + steps_offset: int = 0, tensor_format: str = "pt", ): if trained_betas is not None: @@ -134,16 +142,26 @@ def _get_variance(self, timestep, prev_timestep): return variance - def set_timesteps(self, num_inference_steps: int, offset: int = 0): + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead.", + DeprecationWarning, + ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 6eeb15b93923..09e8a7e240c2 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +import warnings from typing import Optional, Tuple, Union import numpy as np @@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. + set_alpha_to_one (`bool`, default `False`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays """ @@ -89,8 +98,10 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", skip_prk_steps: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 0, + tensor_format: str = "pt", ): if trained_betas is not None: self.betas = np.asarray(trained_betas) @@ -108,6 +119,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf # mainly at formula (9), (12), (13) and the Algorithm 2. @@ -122,7 +135,6 @@ def __init__( # setable values self.num_inference_steps = None self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() - self._offset = 0 self.prk_timesteps = None self.plms_timesteps = None self.timesteps = None @@ -130,23 +142,31 @@ def __init__( self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor: + def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): - optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ + + offset = self.config.steps_offset + + if "offset" in kwargs: + warnings.warn( + "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." + " Please pass `steps_offset` to `__init__` instead." + ) + + offset = kwargs["offset"] + self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist() - self._offset = offset - self._timesteps = np.array([t + self._offset for t in self._timesteps]) + self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + self._timesteps += offset if self.config.skip_prk_steps: # for some models like stable diffusion the prk steps can/should be skipped to @@ -231,7 +251,7 @@ def step_prk( ) diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + prev_timestep = timestep - diff_to_prev timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: @@ -293,7 +313,7 @@ def step_plms( "for more information." ) - prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps if self.counter != 1: self.ets.append(model_output) @@ -323,7 +343,7 @@ def step_plms( return SchedulerOutput(prev_sample=prev_sample) - def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): + def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -336,8 +356,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e78f100409f8..7377797bebfa 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -357,10 +357,38 @@ def get_scheduler_config(self, **kwargs): config.update(**kwargs) return config + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps, eta = 10, 0.0 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + scheduler.set_timesteps(num_inference_steps) + + for t in scheduler.timesteps: + residual = model(sample, t) + sample = scheduler.step(residual, t, sample, eta).prev_sample + + return sample + def test_timesteps(self): for timesteps in [100, 500, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(5) + assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1])) + def test_betas(self): for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) @@ -398,26 +426,31 @@ def test_variance(self): assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5 def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) + sample = self.full_loop() - num_inference_steps, eta = 10, 0.0 + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - model = self.dummy_model() - sample = self.dummy_sample_deter + assert abs(result_sum.item() - 172.0067) < 1e-2 + assert abs(result_mean.item() - 0.223967) < 1e-3 - scheduler.set_timesteps(num_inference_steps) - for t in scheduler.timesteps: - residual = model(sample, t) + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - sample = scheduler.step(residual, t, sample, eta).prev_sample + assert abs(result_sum.item() - 149.8295) < 1e-2 + assert abs(result_mean.item() - 0.1951) < 1e-3 + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 172.0067) < 1e-2 - assert abs(result_mean.item() - 0.223967) < 1e-3 + assert abs(result_sum.item() - 149.0784) < 1e-2 + assert abs(result_mean.item() - 0.1941) < 1e-3 class PNDMSchedulerTest(SchedulerCommonTest): @@ -503,6 +536,26 @@ def check_over_forward(self, time_step=0, **forward_kwargs): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.prk_timesteps): + residual = model(sample, t) + sample = scheduler.step_prk(residual, t, sample).prev_sample + + for i, t in enumerate(scheduler.plms_timesteps): + residual = model(sample, t) + sample = scheduler.step_plms(residual, t, sample).prev_sample + + return sample + def test_pytorch_equal_numpy(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -606,8 +659,23 @@ def test_timesteps(self): for timesteps in [100, 1000]: self.check_over_configs(num_train_timesteps=timesteps) + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(10) + assert torch.equal( + scheduler.timesteps, + torch.tensor( + [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1] + ), + ) + def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): self.check_over_configs(beta_start=beta_start, beta_end=beta_end) def test_schedules(self): @@ -620,7 +688,7 @@ def test_time_indices(self): def test_inference_steps(self): for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + self.check_over_forward(num_inference_steps=num_inference_steps) def test_pow_of_3_inference_steps(self): # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 @@ -648,28 +716,30 @@ def test_inference_plms_no_past_residuals(self): scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) + sample = self.full_loop() + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - scheduler.set_timesteps(num_inference_steps) + assert abs(result_sum.item() - 198.1318) < 1e-2 + assert abs(result_mean.item() - 0.2580) < 1e-3 - for i, t in enumerate(scheduler.prk_timesteps): - residual = model(sample, t) - sample = scheduler.step_prk(residual, i, sample).prev_sample + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) - for i, t in enumerate(scheduler.plms_timesteps): - residual = model(sample, t) - sample = scheduler.step_plms(residual, i, sample).prev_sample + assert abs(result_sum.item() - 230.0399) < 1e-2 + assert abs(result_mean.item() - 0.2995) < 1e-3 + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) - assert abs(result_sum.item() - 428.8788) < 1e-2 - assert abs(result_mean.item() - 0.5584) < 1e-3 + assert abs(result_sum.item() - 186.9482) < 1e-2 + assert abs(result_mean.item() - 0.2434) < 1e-3 class ScoreSdeVeSchedulerTest(unittest.TestCase):