diff --git a/examples/conftest.py b/examples/conftest.py index a72bc85310d2..d2f9600313a1 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -32,13 +32,13 @@ def pytest_addoption(parser): - from diffusers.testing_utils import pytest_addoption_shared + from diffusers.utils.testing_utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.testing_utils import pytest_terminal_summary_main + from diffusers.utils.testing_utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: diff --git a/examples/test_examples.py b/examples/test_examples.py index 0099d17e638d..8838713cb7d0 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -24,7 +24,7 @@ from typing import List from accelerate.utils import write_basic_config -from diffusers.testing_utils import slow +from diffusers.utils import slow logging.basicConfig(level=logging.DEBUG) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 45939e410672..d190acb1fa1c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import Callable, List, Optional, Union import torch @@ -10,7 +9,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -59,15 +58,15 @@ def __init__( super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 6317877233ed..c8f02b5896d6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -12,7 +11,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -71,15 +70,15 @@ def __init__( super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 37f03dfbd9c4..21490d975730 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -1,5 +1,4 @@ import inspect -import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -13,7 +12,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import logging +from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -86,15 +85,15 @@ def __init__( logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.") if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - warnings.warn( + deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file", - DeprecationWarning, + " file" ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 9079ba906f4d..a728ab29d7bb 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,7 +16,6 @@ # and https://github.com/hojonathanho/diffusion import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -24,7 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -122,12 +121,12 @@ def __init__( steps_offset: int = 0, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -175,17 +174,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs): num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - - offset = self.config.steps_offset - - if "offset" in kwargs: - warnings.warn( - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead.", - DeprecationWarning, - ) - - offset = kwargs["offset"] + deprecated_offset = deprecate( + "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs + ) + offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc17cee4c810..4d4e986a76ea 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -23,7 +22,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -115,12 +114,12 @@ def __init__( clip_sample: bool = True, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index e6e5300e73e7..63e1400262d8 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -13,7 +13,6 @@ # limitations under the License. -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -21,7 +20,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -89,12 +88,12 @@ def __init__( s_max: float = 50, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) # setable values self.num_inference_steps: int = None diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 8fd8c2b844a8..33e9558d9c38 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -22,7 +21,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -77,12 +76,12 @@ def __init__( trained_betas: Optional[np.ndarray] = None, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 3015d153af90..3974335a2f1b 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,13 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import warnings from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -102,12 +102,12 @@ def __init__( steps_offset: int = 0, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -155,16 +155,10 @@ def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ - - offset = self.config.steps_offset - - if "offset" in kwargs: - warnings.warn( - "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0." - " Please pass `steps_offset` to `__init__` instead." - ) - - offset = kwargs["offset"] + deprecated_offset = deprecate( + "offset", "0.5.0", "Please pass `steps_offset` to `__init__` instead.", take_from=kwargs + ) + offset = deprecated_offset or self.config.steps_offset self.num_inference_steps = num_inference_steps step_ratio = self.config.num_train_timesteps // self.num_inference_steps diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index a549654c3b6f..12ed1a1b656e 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -15,14 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import math -import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -78,12 +77,12 @@ def __init__( correct_steps: int = 1, **kwargs, ): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) # setable values self.timesteps = None @@ -139,11 +138,7 @@ def get_adjacent_sigma(self, timesteps, t): ) def set_seed(self, seed): - warnings.warn( - "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a" - " generator instead.", - DeprecationWarning, - ) + deprecate("set_seed", "0.5.0", "Please consider passing a generator instead.") torch.manual_seed(seed) def step_pred( diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index daea743873f1..7cf1da44272a 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -17,11 +17,11 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit import math -import warnings import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import deprecate from .scheduling_utils import SchedulerMixin @@ -42,12 +42,12 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs): - if "tensor_format" in kwargs: - warnings.warn( - "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this argument.", - DeprecationWarning, - ) + deprecate( + "tensor_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this argument.", + take_from=kwargs, + ) self.sigmas = None self.discrete_sigmas = None self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 1cc1d94414a6..aba295bc8039 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from dataclasses import dataclass import torch -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -44,10 +43,10 @@ class SchedulerMixin: config_name = SCHEDULER_CONFIG_NAME def set_format(self, tensor_format="pt"): - warnings.warn( - "The method `set_format` is deprecated and will be removed in version `0.5.0`." - "If you're running your code in PyTorch, you can safely remove this function as the schedulers" - "are always in Pytorch", - DeprecationWarning, + deprecate( + "set_format", + "0.5.0", + "If you're running your code in PyTorch, you can safely remove this function as the schedulers are always" + " in Pytorch", ) return self diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b63dbd2b285c..c1285bb8c23d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -15,6 +15,7 @@ import os +from .deprecation_utils import deprecate from .import_utils import ( ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, @@ -35,6 +36,7 @@ ) from .logging import get_logger from .outputs import BaseOutput +from .testing_utils import floats_tensor, load_image, parse_flag_from_env, slow, torch_device logger = get_logger(__name__) diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py new file mode 100644 index 000000000000..eac43031574f --- /dev/null +++ b/src/diffusers/utils/deprecation_utils.py @@ -0,0 +1,49 @@ +import inspect +import warnings +from typing import Any, Dict, Optional, Union + +from packaging import version + + +def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True): + from .. import __version__ + + deprecated_kwargs = take_from + values = () + if not isinstance(args[0], tuple): + args = (args,) + + for attribute, version_name, message in args: + if version.parse(version.parse(__version__).base_version) >= version.parse(version_name): + raise ValueError( + f"The deprecation tuple {(attribute, version_name, message)} should be removed since diffusers'" + f" version {__version__} is >= {version_name}" + ) + + warning = None + if isinstance(deprecated_kwargs, dict) and attribute in deprecated_kwargs: + values += (deprecated_kwargs.pop(attribute),) + warning = f"The `{attribute}` argument is deprecated and will be removed in version {version_name}." + elif hasattr(deprecated_kwargs, attribute): + values += (getattr(deprecated_kwargs, attribute),) + warning = f"The `{attribute}` attribute is deprecated and will be removed in version {version_name}." + elif deprecated_kwargs is None: + warning = f"`{attribute}` is deprecated and will be removed in version {version_name}." + + if warning is not None: + warning = warning + " " if standard_warn else "" + warnings.warn(warning + message, DeprecationWarning) + + if isinstance(deprecated_kwargs, dict) and len(deprecated_kwargs) > 0: + call_frame = inspect.getouterframes(inspect.currentframe())[1] + filename = call_frame.filename + line_number = call_frame.lineno + function = call_frame.function + key, value = next(iter(deprecated_kwargs.items())) + raise TypeError(f"{function} in {filename} line {line_number-1} got an unexpected keyword argument `{key}`") + + if len(values) == 0: + return + elif len(values) == 1: + return values[0] + return values diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py index 45d483ce7b1d..10cffeeb0d41 100644 --- a/src/diffusers/utils/outputs.py +++ b/src/diffusers/utils/outputs.py @@ -15,13 +15,13 @@ Generic utilities """ -import warnings from collections import OrderedDict from dataclasses import fields from typing import Any, Tuple import numpy as np +from .deprecation_utils import deprecate from .import_utils import is_torch_available @@ -87,11 +87,7 @@ def __getitem__(self, k): if isinstance(k, str): inner_dict = {k: v for (k, v) in self.items()} if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": - warnings.warn( - "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" - " `'images'` instead.", - DeprecationWarning, - ) + deprecate("samples", "0.6.0", "Please use `.images` or `'images'` instead.") return inner_dict["images"] return inner_dict[k] else: diff --git a/src/diffusers/testing_utils.py b/src/diffusers/utils/testing_utils.py similarity index 100% rename from src/diffusers/testing_utils.py rename to src/diffusers/utils/testing_utils.py diff --git a/tests/conftest.py b/tests/conftest.py index e116f40e6461..3cfab533e43c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,13 +31,13 @@ def pytest_addoption(parser): - from diffusers.testing_utils import pytest_addoption_shared + from diffusers.utils.testing_utils import pytest_addoption_shared pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from diffusers.testing_utils import pytest_terminal_summary_main + from diffusers.utils.testing_utils import pytest_terminal_summary_main make_reports = terminalreporter.config.getoption("--make-reports") if make_reports: diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index 4c9b17caa74c..f6cb184651ef 100755 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -22,7 +22,7 @@ from diffusers.models.attention import AttentionBlock, SpatialTransformer from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, Upsample2D -from diffusers.testing_utils import torch_device +from diffusers.utils import torch_device torch.backends.cuda.matmul.allow_tf32 = False diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b0d00b863a78..e4e546e55ac3 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -22,8 +22,8 @@ import torch from diffusers.modeling_utils import ModelMixin -from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel +from diffusers.utils import torch_device class ModelTesterMixin: diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 94a186d1c06a..734fb5924d84 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -19,7 +19,7 @@ import torch from diffusers import UNet2DConditionModel, UNet2DModel -from diffusers.testing_utils import floats_tensor, slow, torch_device +from diffusers.utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 361eb618ab22..9fb7e8ea3bb7 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -19,7 +19,7 @@ from diffusers import AutoencoderKL from diffusers.modeling_utils import ModelMixin -from diffusers.testing_utils import floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index 7cce0ed13e01..9a2094d46cb4 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -18,7 +18,7 @@ import torch from diffusers import VQModel -from diffusers.testing_utils import floats_tensor, torch_device +from diffusers.utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index cde652f3b828..78a22ec3138b 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -48,8 +48,7 @@ ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.testing_utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer diff --git a/tests/test_training.py b/tests/test_training.py index 41aae07e33c6..fd0828329ebd 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -18,8 +18,8 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel -from diffusers.testing_utils import slow from diffusers.training_utils import set_seed +from diffusers.utils.testing_utils import slow torch.backends.cuda.matmul.allow_tf32 = False diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100755 index 000000000000..35cf57421014 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,164 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from diffusers import __version__ +from diffusers.utils import deprecate + + +class DeprecateTester(unittest.TestCase): + higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:]) + lower_version = "0.0.1" + + def test_deprecate_function_arg(self): + kwargs = {"deprecated_arg": 4} + + with self.assertWarns(DeprecationWarning) as warning: + output = deprecate("deprecated_arg", self.higher_version, "message", take_from=kwargs) + + assert output == 4 + assert ( + str(warning.warning) + == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}." + " message" + ) + + def test_deprecate_function_arg_tuple(self): + kwargs = {"deprecated_arg": 4} + + with self.assertWarns(DeprecationWarning) as warning: + output = deprecate(("deprecated_arg", self.higher_version, "message"), take_from=kwargs) + + assert output == 4 + assert ( + str(warning.warning) + == f"The `deprecated_arg` argument is deprecated and will be removed in version {self.higher_version}." + " message" + ) + + def test_deprecate_function_args(self): + kwargs = {"deprecated_arg_1": 4, "deprecated_arg_2": 8} + with self.assertWarns(DeprecationWarning) as warning: + output_1, output_2 = deprecate( + ("deprecated_arg_1", self.higher_version, "Hey"), + ("deprecated_arg_2", self.higher_version, "Hey"), + take_from=kwargs, + ) + assert output_1 == 4 + assert output_2 == 8 + assert ( + str(warning.warnings[0].message) + == "The `deprecated_arg_1` argument is deprecated and will be removed in version" + f" {self.higher_version}. Hey" + ) + assert ( + str(warning.warnings[1].message) + == "The `deprecated_arg_2` argument is deprecated and will be removed in version" + f" {self.higher_version}. Hey" + ) + + def test_deprecate_function_incorrect_arg(self): + kwargs = {"deprecated_arg": 4} + + with self.assertRaises(TypeError) as error: + deprecate(("wrong_arg", self.higher_version, "message"), take_from=kwargs) + + assert "test_deprecate_function_incorrect_arg in" in str(error.exception) + assert "line" in str(error.exception) + assert "got an unexpected keyword argument `deprecated_arg`" in str(error.exception) + + def test_deprecate_arg_no_kwarg(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate(("deprecated_arg", self.higher_version, "message")) + + assert ( + str(warning.warning) + == f"`deprecated_arg` is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_args_no_kwarg(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate( + ("deprecated_arg_1", self.higher_version, "Hey"), + ("deprecated_arg_2", self.higher_version, "Hey"), + ) + assert ( + str(warning.warnings[0].message) + == f"`deprecated_arg_1` is deprecated and will be removed in version {self.higher_version}. Hey" + ) + assert ( + str(warning.warnings[1].message) + == f"`deprecated_arg_2` is deprecated and will be removed in version {self.higher_version}. Hey" + ) + + def test_deprecate_class_obj(self): + class Args: + arg = 5 + + with self.assertWarns(DeprecationWarning) as warning: + arg = deprecate(("arg", self.higher_version, "message"), take_from=Args()) + + assert arg == 5 + assert ( + str(warning.warning) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_class_objs(self): + class Args: + arg = 5 + foo = 7 + + with self.assertWarns(DeprecationWarning) as warning: + arg_1, arg_2 = deprecate( + ("arg", self.higher_version, "message"), + ("foo", self.higher_version, "message"), + ("does not exist", self.higher_version, "message"), + take_from=Args(), + ) + + assert arg_1 == 5 + assert arg_2 == 7 + assert ( + str(warning.warning) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + assert ( + str(warning.warnings[0].message) + == f"The `arg` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + assert ( + str(warning.warnings[1].message) + == f"The `foo` attribute is deprecated and will be removed in version {self.higher_version}. message" + ) + + def test_deprecate_incorrect_version(self): + kwargs = {"deprecated_arg": 4} + + with self.assertRaises(ValueError) as error: + deprecate(("wrong_arg", self.lower_version, "message"), take_from=kwargs) + + assert ( + str(error.exception) + == "The deprecation tuple ('wrong_arg', '0.0.1', 'message') should be removed since diffusers' version" + f" {__version__} is >= {self.lower_version}" + ) + + def test_deprecate_incorrect_no_standard_warn(self): + with self.assertWarns(DeprecationWarning) as warning: + deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False) + + assert str(warning.warning) == "This message is better!!!"