Skip to content

Commit 1138d63

Browse files
authored
Temporary local test for PIL_INTERPOLATION (#1317)
* Temporary local test for PIL_INTERPOLATION * Fix examples too.
1 parent afdd7bb commit 1138d63

File tree

5 files changed

+104
-5
lines changed

5 files changed

+104
-5
lines changed

examples/community/imagic_stable_diffusion.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,31 @@
1717
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
1818
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1919
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
20-
from diffusers.utils import PIL_INTERPOLATION, logging
20+
from diffusers.utils import logging
2121
from tqdm.auto import tqdm
2222
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2323

2424

25+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
26+
from packaging import version
27+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
28+
PIL_INTERPOLATION = {
29+
"linear": PIL.Image.Resampling.BILINEAR,
30+
"bilinear": PIL.Image.Resampling.BILINEAR,
31+
"bicubic": PIL.Image.Resampling.BICUBIC,
32+
"lanczos": PIL.Image.Resampling.LANCZOS,
33+
"nearest": PIL.Image.Resampling.NEAREST,
34+
}
35+
else:
36+
PIL_INTERPOLATION = {
37+
"linear": PIL.Image.LINEAR,
38+
"bilinear": PIL.Image.BILINEAR,
39+
"bicubic": PIL.Image.BICUBIC,
40+
"lanczos": PIL.Image.LANCZOS,
41+
"nearest": PIL.Image.NEAREST,
42+
}
43+
# ------------------------------------------------------------------------------
44+
2545
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2646

2747

examples/community/lpw_stable_diffusion.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,29 @@
1212
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
1313
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
1414
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
15-
from diffusers.utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging
15+
from diffusers.utils import deprecate, is_accelerate_available, logging
1616
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
1717

18+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
19+
from packaging import version
20+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
21+
PIL_INTERPOLATION = {
22+
"linear": PIL.Image.Resampling.BILINEAR,
23+
"bilinear": PIL.Image.Resampling.BILINEAR,
24+
"bicubic": PIL.Image.Resampling.BICUBIC,
25+
"lanczos": PIL.Image.Resampling.LANCZOS,
26+
"nearest": PIL.Image.Resampling.NEAREST,
27+
}
28+
else:
29+
PIL_INTERPOLATION = {
30+
"linear": PIL.Image.LINEAR,
31+
"bilinear": PIL.Image.BILINEAR,
32+
"bicubic": PIL.Image.BICUBIC,
33+
"lanczos": PIL.Image.LANCZOS,
34+
"nearest": PIL.Image.NEAREST,
35+
}
36+
# ------------------------------------------------------------------------------
37+
1838

1939
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
2040

examples/community/lpw_stable_diffusion_onnx.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,28 @@
1010
from diffusers.pipeline_utils import DiffusionPipeline
1111
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
1212
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
13-
from diffusers.utils import PIL_INTERPOLATION, logging
13+
from diffusers.utils import logging
1414
from transformers import CLIPFeatureExtractor, CLIPTokenizer
1515

16+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
17+
from packaging import version
18+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
19+
PIL_INTERPOLATION = {
20+
"linear": PIL.Image.Resampling.BILINEAR,
21+
"bilinear": PIL.Image.Resampling.BILINEAR,
22+
"bicubic": PIL.Image.Resampling.BICUBIC,
23+
"lanczos": PIL.Image.Resampling.LANCZOS,
24+
"nearest": PIL.Image.Resampling.NEAREST,
25+
}
26+
else:
27+
PIL_INTERPOLATION = {
28+
"linear": PIL.Image.LINEAR,
29+
"bilinear": PIL.Image.BILINEAR,
30+
"bicubic": PIL.Image.BICUBIC,
31+
"lanczos": PIL.Image.LANCZOS,
32+
"nearest": PIL.Image.NEAREST,
33+
}
34+
# ------------------------------------------------------------------------------
1635

1736
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1837

examples/textual_inversion/textual_inversion.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,33 @@
1818
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
1919
from diffusers.optimization import get_scheduler
2020
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
21-
from diffusers.utils import PIL_INTERPOLATION
2221
from huggingface_hub import HfFolder, Repository, whoami
2322
from PIL import Image
2423
from torchvision import transforms
2524
from tqdm.auto import tqdm
2625
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2726

27+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
28+
from packaging import version
29+
import PIL
30+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
31+
PIL_INTERPOLATION = {
32+
"linear": PIL.Image.Resampling.BILINEAR,
33+
"bilinear": PIL.Image.Resampling.BILINEAR,
34+
"bicubic": PIL.Image.Resampling.BICUBIC,
35+
"lanczos": PIL.Image.Resampling.LANCZOS,
36+
"nearest": PIL.Image.Resampling.NEAREST,
37+
}
38+
else:
39+
PIL_INTERPOLATION = {
40+
"linear": PIL.Image.LINEAR,
41+
"bilinear": PIL.Image.BILINEAR,
42+
"bicubic": PIL.Image.BICUBIC,
43+
"lanczos": PIL.Image.LANCZOS,
44+
"nearest": PIL.Image.NEAREST,
45+
}
46+
# ------------------------------------------------------------------------------
47+
2848

2949
logger = get_logger(__name__)
3050

examples/textual_inversion/textual_inversion_flax.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
FlaxUNet2DConditionModel,
2424
)
2525
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
26-
from diffusers.utils import PIL_INTERPOLATION
2726
from flax import jax_utils
2827
from flax.training import train_state
2928
from flax.training.common_utils import shard
@@ -34,6 +33,27 @@
3433
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed
3534

3635

36+
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
37+
from packaging import version
38+
import PIL
39+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
40+
PIL_INTERPOLATION = {
41+
"linear": PIL.Image.Resampling.BILINEAR,
42+
"bilinear": PIL.Image.Resampling.BILINEAR,
43+
"bicubic": PIL.Image.Resampling.BICUBIC,
44+
"lanczos": PIL.Image.Resampling.LANCZOS,
45+
"nearest": PIL.Image.Resampling.NEAREST,
46+
}
47+
else:
48+
PIL_INTERPOLATION = {
49+
"linear": PIL.Image.LINEAR,
50+
"bilinear": PIL.Image.BILINEAR,
51+
"bicubic": PIL.Image.BICUBIC,
52+
"lanczos": PIL.Image.LANCZOS,
53+
"nearest": PIL.Image.NEAREST,
54+
}
55+
# ------------------------------------------------------------------------------
56+
3757
logger = logging.getLogger(__name__)
3858

3959

0 commit comments

Comments
 (0)