Skip to content

Commit dbe0719

Browse files
anton-lpcuencapatrickvonplaten
authored
Fix PyCharm/VSCode static type checking for dummy objects (#1596)
* Fix PyCharm/VSCode static type checking for dummy objects * Re-add dummies * Fix AudioDiffusion imports * fix import * fix import * Update utils/check_dummies.py Co-authored-by: Pedro Cuenca <[email protected]> * Update src/diffusers/utils/import_utils.py * Update src/diffusers/__init__.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/stable_diffusion/__init__.py * fix double import Co-authored-by: Pedro Cuenca <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 03566d8 commit dbe0719

File tree

9 files changed

+148
-87
lines changed

9 files changed

+148
-87
lines changed

src/diffusers/__init__.py

Lines changed: 52 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from .configuration_utils import ConfigMixin
44
from .onnx_utils import OnnxRuntimeModel
55
from .utils import (
6+
OptionalDependencyNotAvailable,
67
is_flax_available,
78
is_inflect_available,
89
is_k_diffusion_available,
10+
is_librosa_available,
911
is_onnx_available,
1012
is_scipy_available,
1113
is_torch_available,
@@ -15,7 +17,12 @@
1517
)
1618

1719

18-
if is_torch_available():
20+
try:
21+
if not is_torch_available():
22+
raise OptionalDependencyNotAvailable()
23+
except OptionalDependencyNotAvailable:
24+
from .utils.dummy_pt_objects import * # noqa F403
25+
else:
1926
from .modeling_utils import ModelMixin
2027
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
2128
from .optimization import (
@@ -29,14 +36,12 @@
2936
)
3037
from .pipeline_utils import DiffusionPipeline
3138
from .pipelines import (
32-
AudioDiffusionPipeline,
3339
DanceDiffusionPipeline,
3440
DDIMPipeline,
3541
DDPMPipeline,
3642
KarrasVePipeline,
3743
LDMPipeline,
3844
LDMSuperResolutionPipeline,
39-
Mel,
4045
PNDMPipeline,
4146
RePaintPipeline,
4247
ScoreSdeVePipeline,
@@ -60,15 +65,22 @@
6065
VQDiffusionScheduler,
6166
)
6267
from .training_utils import EMAModel
63-
else:
64-
from .utils.dummy_pt_objects import * # noqa F403
6568

66-
if is_torch_available() and is_scipy_available():
67-
from .schedulers import LMSDiscreteScheduler
68-
else:
69+
try:
70+
if not (is_torch_available() and is_scipy_available()):
71+
raise OptionalDependencyNotAvailable()
72+
except OptionalDependencyNotAvailable:
6973
from .utils.dummy_torch_and_scipy_objects import * # noqa F403
74+
else:
75+
from .schedulers import LMSDiscreteScheduler
7076

71-
if is_torch_available() and is_transformers_available():
77+
78+
try:
79+
if not (is_torch_available() and is_transformers_available()):
80+
raise OptionalDependencyNotAvailable()
81+
except OptionalDependencyNotAvailable:
82+
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
83+
else:
7284
from .pipelines import (
7385
AltDiffusionImg2ImgPipeline,
7486
AltDiffusionPipeline,
@@ -88,26 +100,43 @@
88100
VersatileDiffusionTextToImagePipeline,
89101
VQDiffusionPipeline,
90102
)
91-
else:
92-
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
93103

94-
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
95-
from .pipelines import StableDiffusionKDiffusionPipeline
96-
else:
104+
try:
105+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
106+
raise OptionalDependencyNotAvailable()
107+
except OptionalDependencyNotAvailable:
97108
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
109+
else:
110+
from .pipelines import StableDiffusionKDiffusionPipeline
98111

99-
if is_torch_available() and is_transformers_available() and is_onnx_available():
112+
try:
113+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
114+
raise OptionalDependencyNotAvailable()
115+
except OptionalDependencyNotAvailable:
116+
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
117+
else:
100118
from .pipelines import (
101119
OnnxStableDiffusionImg2ImgPipeline,
102120
OnnxStableDiffusionInpaintPipeline,
103121
OnnxStableDiffusionInpaintPipelineLegacy,
104122
OnnxStableDiffusionPipeline,
105123
StableDiffusionOnnxPipeline,
106124
)
125+
126+
try:
127+
if not (is_torch_available() and is_librosa_available()):
128+
raise OptionalDependencyNotAvailable()
129+
except OptionalDependencyNotAvailable:
130+
from .utils.dummy_torch_and_librosa_objects import * # noqa F403
107131
else:
108-
from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
132+
from .pipelines import AudioDiffusionPipeline, Mel
109133

110-
if is_flax_available():
134+
try:
135+
if not is_flax_available():
136+
raise OptionalDependencyNotAvailable()
137+
except OptionalDependencyNotAvailable:
138+
from .utils.dummy_flax_objects import * # noqa F403
139+
else:
111140
from .modeling_flax_utils import FlaxModelMixin
112141
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
113142
from .models.vae_flax import FlaxAutoencoderKL
@@ -122,10 +151,11 @@
122151
FlaxSchedulerMixin,
123152
FlaxScoreSdeVeScheduler,
124153
)
125-
else:
126-
from .utils.dummy_flax_objects import * # noqa F403
127154

128-
if is_flax_available() and is_transformers_available():
129-
from .pipelines import FlaxStableDiffusionPipeline
130-
else:
155+
try:
156+
if not (is_flax_available() and is_transformers_available()):
157+
raise OptionalDependencyNotAvailable()
158+
except OptionalDependencyNotAvailable:
131159
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
160+
else:
161+
from .pipelines import FlaxStableDiffusionPipeline

src/diffusers/pipelines/__init__.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ..utils import (
2+
OptionalDependencyNotAvailable,
23
is_flax_available,
34
is_k_diffusion_available,
45
is_librosa_available,
@@ -8,7 +9,12 @@
89
)
910

1011

11-
if is_torch_available():
12+
try:
13+
if not is_torch_available():
14+
raise OptionalDependencyNotAvailable()
15+
except OptionalDependencyNotAvailable:
16+
from ..utils.dummy_pt_objects import * # noqa F403
17+
else:
1218
from .dance_diffusion import DanceDiffusionPipeline
1319
from .ddim import DDIMPipeline
1420
from .ddpm import DDPMPipeline
@@ -18,15 +24,21 @@
1824
from .repaint import RePaintPipeline
1925
from .score_sde_ve import ScoreSdeVePipeline
2026
from .stochastic_karras_ve import KarrasVePipeline
21-
else:
22-
from ..utils.dummy_pt_objects import * # noqa F403
2327

24-
if is_torch_available() and is_librosa_available():
25-
from .audio_diffusion import AudioDiffusionPipeline, Mel
28+
try:
29+
if not (is_torch_available() and is_librosa_available()):
30+
raise OptionalDependencyNotAvailable()
31+
except OptionalDependencyNotAvailable:
32+
from ..utils.dummy_torch_and_librosa_objects import * # noqa F403
2633
else:
27-
from ..utils.dummy_torch_and_librosa_objects import AudioDiffusionPipeline, Mel # noqa F403
34+
from .audio_diffusion import AudioDiffusionPipeline, Mel
2835

29-
if is_torch_available() and is_transformers_available():
36+
try:
37+
if not (is_torch_available() and is_transformers_available()):
38+
raise OptionalDependencyNotAvailable()
39+
except OptionalDependencyNotAvailable:
40+
from ..utils.dummy_torch_and_transformers_objects import * # noqa F403
41+
else:
3042
from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
3143
from .latent_diffusion import LDMTextToImagePipeline
3244
from .paint_by_example import PaintByExamplePipeline
@@ -48,7 +60,12 @@
4860
)
4961
from .vq_diffusion import VQDiffusionPipeline
5062

51-
if is_transformers_available() and is_onnx_available():
63+
try:
64+
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
65+
raise OptionalDependencyNotAvailable()
66+
except OptionalDependencyNotAvailable:
67+
from ..utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403
68+
else:
5269
from .stable_diffusion import (
5370
OnnxStableDiffusionImg2ImgPipeline,
5471
OnnxStableDiffusionInpaintPipeline,
@@ -57,8 +74,19 @@
5774
StableDiffusionOnnxPipeline,
5875
)
5976

60-
if is_torch_available() and is_transformers_available() and is_k_diffusion_available():
77+
try:
78+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
79+
raise OptionalDependencyNotAvailable()
80+
except OptionalDependencyNotAvailable:
81+
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
82+
else:
6183
from .stable_diffusion import StableDiffusionKDiffusionPipeline
6284

63-
if is_transformers_available() and is_flax_available():
85+
86+
try:
87+
if not (is_flax_available() and is_transformers_available()):
88+
raise OptionalDependencyNotAvailable()
89+
except OptionalDependencyNotAvailable:
90+
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
91+
else:
6492
from .stable_diffusion import FlaxStableDiffusionPipeline

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ...utils import (
1010
BaseOutput,
11+
OptionalDependencyNotAvailable,
1112
is_flax_available,
1213
is_k_diffusion_available,
1314
is_onnx_available,
@@ -44,12 +45,20 @@ class StableDiffusionPipelineOutput(BaseOutput):
4445
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
4546
from .safety_checker import StableDiffusionSafetyChecker
4647

47-
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
48-
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
49-
else:
48+
try:
49+
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
50+
raise OptionalDependencyNotAvailable()
51+
except OptionalDependencyNotAvailable:
5052
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
53+
else:
54+
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
5155

52-
if is_transformers_available() and is_torch_available() and is_k_diffusion_available():
56+
try:
57+
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
58+
raise OptionalDependencyNotAvailable()
59+
except OptionalDependencyNotAvailable:
60+
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
61+
else:
5362
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
5463

5564
if is_transformers_available() and is_onnx_available():
Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
1-
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
1+
from ...utils import (
2+
OptionalDependencyNotAvailable,
3+
is_torch_available,
4+
is_transformers_available,
5+
is_transformers_version,
6+
)
27

38

4-
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0"):
5-
from .modeling_text_unet import UNetFlatConditionModel
6-
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
7-
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
8-
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
9-
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
10-
else:
9+
try:
10+
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0.dev0")):
11+
raise OptionalDependencyNotAvailable()
12+
except OptionalDependencyNotAvailable:
1113
from ...utils.dummy_torch_and_transformers_objects import (
1214
VersatileDiffusionDualGuidedPipeline,
1315
VersatileDiffusionImageVariationPipeline,
1416
VersatileDiffusionPipeline,
1517
VersatileDiffusionTextToImagePipeline,
1618
)
19+
else:
20+
from .modeling_text_unet import UNetFlatConditionModel
21+
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
22+
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
23+
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
24+
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline

src/diffusers/schedulers/__init__.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,15 @@
1313
# limitations under the License.
1414

1515

16-
from ..utils import is_flax_available, is_scipy_available, is_torch_available
16+
from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available
1717

1818

19-
if is_torch_available():
19+
try:
20+
if not is_torch_available():
21+
raise OptionalDependencyNotAvailable()
22+
except OptionalDependencyNotAvailable:
23+
from ..utils.dummy_pt_objects import * # noqa F403
24+
else:
2025
from .scheduling_ddim import DDIMScheduler
2126
from .scheduling_ddpm import DDPMScheduler
2227
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
@@ -34,10 +39,13 @@
3439
from .scheduling_sde_vp import ScoreSdeVpScheduler
3540
from .scheduling_utils import SchedulerMixin
3641
from .scheduling_vq_diffusion import VQDiffusionScheduler
37-
else:
38-
from ..utils.dummy_pt_objects import * # noqa F403
3942

40-
if is_flax_available():
43+
try:
44+
if not is_flax_available():
45+
raise OptionalDependencyNotAvailable()
46+
except OptionalDependencyNotAvailable:
47+
from ..utils.dummy_flax_objects import * # noqa F403
48+
else:
4149
from .scheduling_ddim_flax import FlaxDDIMScheduler
4250
from .scheduling_ddpm_flax import FlaxDDPMScheduler
4351
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
@@ -46,11 +54,12 @@
4654
from .scheduling_pndm_flax import FlaxPNDMScheduler
4755
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
4856
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
49-
else:
50-
from ..utils.dummy_flax_objects import * # noqa F403
5157

5258

53-
if is_scipy_available() and is_torch_available():
54-
from .scheduling_lms_discrete import LMSDiscreteScheduler
55-
else:
59+
try:
60+
if not (is_torch_available() and is_scipy_available()):
61+
raise OptionalDependencyNotAvailable()
62+
except OptionalDependencyNotAvailable:
5663
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
64+
else:
65+
from .scheduling_lms_discrete import LMSDiscreteScheduler

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
USE_TF,
2727
USE_TORCH,
2828
DummyObject,
29+
OptionalDependencyNotAvailable,
2930
is_accelerate_available,
3031
is_flax_available,
3132
is_inflect_available,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -152,21 +152,6 @@ def from_pretrained(cls, *args, **kwargs):
152152
requires_backends(cls, ["torch"])
153153

154154

155-
class AudioDiffusionPipeline(metaclass=DummyObject):
156-
_backends = ["torch"]
157-
158-
def __init__(self, *args, **kwargs):
159-
requires_backends(self, ["torch"])
160-
161-
@classmethod
162-
def from_config(cls, *args, **kwargs):
163-
requires_backends(cls, ["torch"])
164-
165-
@classmethod
166-
def from_pretrained(cls, *args, **kwargs):
167-
requires_backends(cls, ["torch"])
168-
169-
170155
class DanceDiffusionPipeline(metaclass=DummyObject):
171156
_backends = ["torch"]
172157

@@ -257,21 +242,6 @@ def from_pretrained(cls, *args, **kwargs):
257242
requires_backends(cls, ["torch"])
258243

259244

260-
class Mel(metaclass=DummyObject):
261-
_backends = ["torch"]
262-
263-
def __init__(self, *args, **kwargs):
264-
requires_backends(self, ["torch"])
265-
266-
@classmethod
267-
def from_config(cls, *args, **kwargs):
268-
requires_backends(cls, ["torch"])
269-
270-
@classmethod
271-
def from_pretrained(cls, *args, **kwargs):
272-
requires_backends(cls, ["torch"])
273-
274-
275245
class PNDMPipeline(metaclass=DummyObject):
276246
_backends = ["torch"]
277247

0 commit comments

Comments
 (0)