diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 1eb564986239..3b82c5a6750c 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -3,9 +3,11 @@
 from .configuration_utils import ConfigMixin
 from .onnx_utils import OnnxRuntimeModel
 from .utils import (
+    OptionalDependencyNotAvailable,
     is_flax_available,
     is_inflect_available,
     is_k_diffusion_available,
+    is_librosa_available,
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
@@ -15,7 +17,12 @@
 )
 
 
-if is_torch_available():
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_pt_objects import *  # noqa F403
+else:
     from .modeling_utils import ModelMixin
     from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
     from .optimization import (
@@ -29,14 +36,12 @@
     )
     from .pipeline_utils import DiffusionPipeline
     from .pipelines import (
-        AudioDiffusionPipeline,
         DanceDiffusionPipeline,
         DDIMPipeline,
         DDPMPipeline,
         KarrasVePipeline,
         LDMPipeline,
         LDMSuperResolutionPipeline,
-        Mel,
         PNDMPipeline,
         RePaintPipeline,
         ScoreSdeVePipeline,
@@ -60,15 +65,22 @@
         VQDiffusionScheduler,
     )
     from .training_utils import EMAModel
-else:
-    from .utils.dummy_pt_objects import *  # noqa F403
 
-if is_torch_available() and is_scipy_available():
-    from .schedulers import LMSDiscreteScheduler
-else:
+try:
+    if not (is_torch_available() and is_scipy_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_scipy_objects import *  # noqa F403
+else:
+    from .schedulers import LMSDiscreteScheduler
 
-if is_torch_available() and is_transformers_available():
+
+try:
+    if not (is_torch_available() and is_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
+else:
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
@@ -88,15 +100,21 @@
         VersatileDiffusionTextToImagePipeline,
         VQDiffusionPipeline,
     )
-else:
-    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 
-if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
-    from .pipelines import StableDiffusionKDiffusionPipeline
-else:
+try:
+    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+else:
+    from .pipelines import StableDiffusionKDiffusionPipeline
 
-if is_torch_available() and is_transformers_available() and is_onnx_available():
+try:
+    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+else:
     from .pipelines import (
         OnnxStableDiffusionImg2ImgPipeline,
         OnnxStableDiffusionInpaintPipeline,
@@ -104,10 +122,21 @@
         OnnxStableDiffusionPipeline,
         StableDiffusionOnnxPipeline,
     )
+
+try:
+    if not (is_torch_available() and is_librosa_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_librosa_objects import *  # noqa F403
 else:
-    from .utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+    from .pipelines import AudioDiffusionPipeline, Mel
 
-if is_flax_available():
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_flax_objects import *  # noqa F403
+else:
     from .modeling_flax_utils import FlaxModelMixin
     from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
     from .models.vae_flax import FlaxAutoencoderKL
@@ -122,10 +151,11 @@
         FlaxSchedulerMixin,
         FlaxScoreSdeVeScheduler,
     )
-else:
-    from .utils.dummy_flax_objects import *  # noqa F403
 
-if is_flax_available() and is_transformers_available():
-    from .pipelines import FlaxStableDiffusionPipeline
-else:
+try:
+    if not (is_flax_available() and is_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from .utils.dummy_flax_and_transformers_objects import *  # noqa F403
+else:
+    from .pipelines import FlaxStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 49dd0c6a35bf..89605ccd30bd 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,4 +1,5 @@
 from ..utils import (
+    OptionalDependencyNotAvailable,
     is_flax_available,
     is_k_diffusion_available,
     is_librosa_available,
@@ -8,7 +9,12 @@
 )
 
 
-if is_torch_available():
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_pt_objects import *  # noqa F403
+else:
     from .dance_diffusion import DanceDiffusionPipeline
     from .ddim import DDIMPipeline
     from .ddpm import DDPMPipeline
@@ -18,15 +24,21 @@
     from .repaint import RePaintPipeline
     from .score_sde_ve import ScoreSdeVePipeline
     from .stochastic_karras_ve import KarrasVePipeline
-else:
-    from ..utils.dummy_pt_objects import *  # noqa F403
 
-if is_torch_available() and is_librosa_available():
-    from .audio_diffusion import AudioDiffusionPipeline, Mel
+try:
+    if not (is_torch_available() and is_librosa_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_torch_and_librosa_objects import *  # noqa F403
 else:
-    from ..utils.dummy_torch_and_librosa_objects import AudioDiffusionPipeline, Mel  # noqa F403
+    from .audio_diffusion import AudioDiffusionPipeline, Mel
 
-if is_torch_available() and is_transformers_available():
+try:
+    if not (is_torch_available() and is_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_torch_and_transformers_objects import *  # noqa F403
+else:
     from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
     from .latent_diffusion import LDMTextToImagePipeline
     from .paint_by_example import PaintByExamplePipeline
@@ -48,7 +60,12 @@
     )
     from .vq_diffusion import VQDiffusionPipeline
 
-if is_transformers_available() and is_onnx_available():
+try:
+    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_torch_and_transformers_and_onnx_objects import *  # noqa F403
+else:
     from .stable_diffusion import (
         OnnxStableDiffusionImg2ImgPipeline,
         OnnxStableDiffusionInpaintPipeline,
@@ -57,8 +74,19 @@
         StableDiffusionOnnxPipeline,
     )
 
-if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
+try:
+    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+else:
     from .stable_diffusion import StableDiffusionKDiffusionPipeline
 
-if is_transformers_available() and is_flax_available():
+
+try:
+    if not (is_flax_available() and is_transformers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_flax_and_transformers_objects import *  # noqa F403
+else:
     from .stable_diffusion import FlaxStableDiffusionPipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 729a55fa77a3..ac544cbe0c8f 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -8,6 +8,7 @@
 
 from ...utils import (
     BaseOutput,
+    OptionalDependencyNotAvailable,
     is_flax_available,
     is_k_diffusion_available,
     is_onnx_available,
@@ -44,12 +45,20 @@ class StableDiffusionPipelineOutput(BaseOutput):
     from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
     from .safety_checker import StableDiffusionSafetyChecker
 
-if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
-    from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
-else:
+try:
+    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
+else:
+    from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
 
-if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
+try:
+    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *  # noqa F403
+else:
     from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
 
 if is_transformers_available() and is_onnx_available():
diff --git a/src/diffusers/pipelines/versatile_diffusion/__init__.py b/src/diffusers/pipelines/versatile_diffusion/__init__.py
index 1d2caa7e2399..3c4b52080a4a 100644
--- a/src/diffusers/pipelines/versatile_diffusion/__init__.py
+++ b/src/diffusers/pipelines/versatile_diffusion/__init__.py
@@ -1,16 +1,24 @@
-from ...utils import is_torch_available, is_transformers_available, is_transformers_version
+from ...utils import (
+    OptionalDependencyNotAvailable,
+    is_torch_available,
+    is_transformers_available,
+    is_transformers_version,
+)
 
 
-if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
-    from .modeling_text_unet import UNetFlatConditionModel
-    from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
-    from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
-    from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
-    from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
-else:
+try:
+    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from ...utils.dummy_torch_and_transformers_objects import (
         VersatileDiffusionDualGuidedPipeline,
         VersatileDiffusionImageVariationPipeline,
         VersatileDiffusionPipeline,
         VersatileDiffusionTextToImagePipeline,
     )
+else:
+    from .modeling_text_unet import UNetFlatConditionModel
+    from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
+    from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
+    from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
+    from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 424c93bbf331..b2af345782cf 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -13,10 +13,15 @@
 # limitations under the License.
 
 
-from ..utils import is_flax_available, is_scipy_available, is_torch_available
+from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available
 
 
-if is_torch_available():
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_pt_objects import *  # noqa F403
+else:
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
     from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
@@ -34,10 +39,13 @@
     from .scheduling_sde_vp import ScoreSdeVpScheduler
     from .scheduling_utils import SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
-else:
-    from ..utils.dummy_pt_objects import *  # noqa F403
 
-if is_flax_available():
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_flax_objects import *  # noqa F403
+else:
     from .scheduling_ddim_flax import FlaxDDIMScheduler
     from .scheduling_ddpm_flax import FlaxDDPMScheduler
     from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
@@ -46,11 +54,12 @@
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
     from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
-else:
-    from ..utils.dummy_flax_objects import *  # noqa F403
 
-if is_scipy_available() and is_torch_available():
-    from .scheduling_lms_discrete import LMSDiscreteScheduler
-else:
+try:
+    if not (is_torch_available() and is_scipy_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
+else:
+    from .scheduling_lms_discrete import LMSDiscreteScheduler
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index b3f2a9698939..30ead70e43f3 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -26,6 +26,7 @@
     USE_TF,
     USE_TORCH,
     DummyObject,
+    OptionalDependencyNotAvailable,
     is_accelerate_available,
     is_flax_available,
     is_inflect_available,
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index d8b7001bfd78..dc036c82c6d0 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -152,21 +152,6 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
-class AudioDiffusionPipeline(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
 class DanceDiffusionPipeline(metaclass=DummyObject):
     _backends = ["torch"]
 
@@ -257,21 +242,6 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
-class Mel(metaclass=DummyObject):
-    _backends = ["torch"]
-
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["torch"])
-
-    @classmethod
-    def from_config(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        requires_backends(cls, ["torch"])
-
-
 class PNDMPipeline(metaclass=DummyObject):
     _backends = ["torch"]
 
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index cb8ceb97acad..6ebdf7d94e64 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -430,3 +430,7 @@ def is_transformers_version(operation: str, version: str):
     if not _transformers_available:
         return False
     return compare_versions(parse(_transformers_version), operation, version)
+
+
+class OptionalDependencyNotAvailable(BaseException):
+    """An error indicating that an optional dependency of Diffusers was not found in the environment."""
diff --git a/utils/check_dummies.py b/utils/check_dummies.py
index c5664fa77a92..88b2668213e4 100644
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -74,13 +74,15 @@ def read_init():
     backend_specific_objects = {}
     # Go through the end of the file
     while line_index < len(lines):
-        # If the line is an if is_backend_available, we grab all objects associated.
+        # If the line contains is_backend_available, we grab all objects associated with the `else` block
         backend = find_backend(lines[line_index])
         if backend is not None:
-            objects = []
+            while not lines[line_index].startswith("else:"):
+                line_index += 1
             line_index += 1
+            objects = []
             # Until we unindent, add backend objects to the list
-            while not lines[line_index].startswith("else:"):
+            while line_index < len(lines) and len(lines[line_index]) > 1:
                 line = lines[line_index]
                 single_line_import_search = _re_single_line_import.search(line)
                 if single_line_import_search is not None:
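Note on the pattern: every file touched above now uses the same try/except/else import guard as transformers, instead of an if/else around the availability check, so a missing optional backend is surfaced via OptionalDependencyNotAvailable and the wildcard import of the matching dummy module. The sketch below is illustrative only, not code from this diff: librosa is used as the example backend, is_librosa_available is a simplified stand-in for the helper in utils/import_utils.py, and the local Mel dummy stands in for the generated classes in utils/dummy_torch_and_librosa_objects.py.

# Minimal, self-contained sketch of the guard idiom adopted in this diff (assumptions noted above).
import importlib.util


class OptionalDependencyNotAvailable(BaseException):
    """An error indicating that an optional dependency of Diffusers was not found in the environment."""


def is_librosa_available() -> bool:
    # Simplified stand-in: diffusers does a similar lookup (plus version metadata) in utils/import_utils.py.
    return importlib.util.find_spec("librosa") is not None


try:
    if not is_librosa_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Fallback: a dummy object that only raises when someone actually tries to use it,
    # mirroring the DummyObject classes generated by utils/check_dummies.py.
    class Mel:
        def __init__(self, *args, **kwargs):
            raise ImportError("Mel requires the librosa library, but it was not found in your environment.")

else:
    # Real import, only reached when the optional backend is installed.
    from diffusers.pipelines.audio_diffusion import Mel  # noqa: F401

Because the check and the raise both live inside the try block, adding a new optional backend only requires a new dummy module plus one of these blocks, which is exactly what check_dummies.py is updated to parse (it now skips ahead to the `else:` branch before collecting objects).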