Skip to content

Commit c18941b

Browse files
[Better scheduler docs] Improve usage examples of schedulers (#890)
* [Better scheduler docs] Improve usage examples of schedulers * finish * fix warnings and add test * finish * more replacements * adapt fast tests hf token * correct more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <[email protected]> * Integrate compatibility with euler Co-authored-by: Pedro Cuenca <[email protected]>
1 parent a1ea8c0 commit c18941b

28 files changed

+402
-53
lines changed

.github/workflows/pr_tests.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ jobs:
4242
python utils/print_env.py
4343
4444
- name: Run all fast tests on CPU
45+
env:
46+
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
4547
run: |
4648
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_cpu tests/
4749
@@ -91,6 +93,8 @@ jobs:
9193
9294
- name: Run all fast tests on MPS
9395
shell: arch -arch arm64 bash {0}
96+
env:
97+
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
9498
run: |
9599
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
96100

README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,7 @@ it before the pipeline and pass it to `from_pretrained`.
142142
```python
143143
from diffusers import LMSDiscreteScheduler
144144

145-
lms = LMSDiscreteScheduler(
146-
beta_start=0.00085,
147-
beta_end=0.012,
148-
beta_schedule="scaled_linear"
149-
)
145+
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
150146

151147
pipe = StableDiffusionPipeline.from_pretrained(
152148
"runwayml/stable-diffusion-v1-5",

docs/source/quicktour.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ you could use it as follows:
121121
```python
122122
>>> from diffusers import LMSDiscreteScheduler
123123

124-
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
124+
>>> scheduler = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
125125

126126
>>> generator = StableDiffusionPipeline.from_pretrained(
127127
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN

examples/dreambooth/train_dreambooth.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,7 @@ def main(args):
469469
eps=args.adam_epsilon,
470470
)
471471

472-
noise_scheduler = DDPMScheduler(
473-
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
474-
)
472+
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
475473

476474
train_dataset = DreamBoothDataset(
477475
instance_data_root=args.instance_data_dir,

examples/text_to_image/train_text_to_image.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,7 @@ def main():
372372
weight_decay=args.adam_weight_decay,
373373
eps=args.adam_epsilon,
374374
)
375-
376-
# TODO (patil-suraj): load scheduler using args
377-
noise_scheduler = DDPMScheduler(
378-
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
379-
)
375+
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
380376

381377
# Get the datasets: you can either provide your own training and evaluation files (see below)
382378
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -609,9 +605,7 @@ def collate_fn(examples):
609605
vae=vae,
610606
unet=unet,
611607
tokenizer=tokenizer,
612-
scheduler=PNDMScheduler(
613-
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
614-
),
608+
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
615609
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
616610
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
617611
)

examples/textual_inversion/textual_inversion.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -419,13 +419,7 @@ def main():
419419
eps=args.adam_epsilon,
420420
)
421421

422-
# TODO (patil-suraj): load scheduler using args
423-
noise_scheduler = DDPMScheduler(
424-
beta_start=0.00085,
425-
beta_end=0.012,
426-
beta_schedule="scaled_linear",
427-
num_train_timesteps=1000,
428-
)
422+
noise_scheduler = DDPMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
429423

430424
train_dataset = TextualInversionDataset(
431425
data_root=args.train_data_dir,
@@ -558,9 +552,7 @@ def main():
558552
vae=vae,
559553
unet=unet,
560554
tokenizer=tokenizer,
561-
scheduler=PNDMScheduler(
562-
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
563-
),
555+
scheduler=PNDMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
564556
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
565557
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
566558
)

src/diffusers/configuration_utils.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
""" ConfigMixin base class and utilities."""
1717
import dataclasses
1818
import functools
19+
import importlib
1920
import inspect
2021
import json
2122
import os
@@ -48,9 +49,13 @@ class ConfigMixin:
4849
[`~ConfigMixin.save_config`] (should be overridden by parent class).
4950
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
5051
overridden by parent class).
52+
- **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that
53+
`from_config` can be used from a class different than the one used to save the config (should be overridden
54+
by parent class).
5155
"""
5256
config_name = None
5357
ignore_for_config = []
58+
_compatible_classes = []
5459

5560
def register_to_config(self, **kwargs):
5661
if self.config_name is None:
@@ -280,9 +285,14 @@ def get_config_dict(
280285

281286
return config_dict
282287

288+
@staticmethod
289+
def _get_init_keys(cls):
290+
return set(dict(inspect.signature(cls.__init__).parameters).keys())
291+
283292
@classmethod
284293
def extract_init_dict(cls, config_dict, **kwargs):
285-
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
294+
# 1. Retrieve expected config attributes from __init__ signature
295+
expected_keys = cls._get_init_keys(cls)
286296
expected_keys.remove("self")
287297
# remove general kwargs if present in dict
288298
if "kwargs" in expected_keys:
@@ -292,9 +302,36 @@ def extract_init_dict(cls, config_dict, **kwargs):
292302
for arg in cls._flax_internal_args:
293303
expected_keys.remove(arg)
294304

305+
# 2. Remove attributes that cannot be expected from expected config attributes
295306
# remove keys to be ignored
296307
if len(cls.ignore_for_config) > 0:
297308
expected_keys = expected_keys - set(cls.ignore_for_config)
309+
310+
# load diffusers library to import compatible and original scheduler
311+
diffusers_library = importlib.import_module(__name__.split(".")[0])
312+
313+
# remove attributes from compatible classes that orig cannot expect
314+
compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes]
315+
# filter out None potentially undefined dummy classes
316+
compatible_classes = [c for c in compatible_classes if c is not None]
317+
expected_keys_comp_cls = set()
318+
for c in compatible_classes:
319+
expected_keys_c = cls._get_init_keys(c)
320+
expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
321+
expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
322+
config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}
323+
324+
# remove attributes from orig class that cannot be expected
325+
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
326+
if orig_cls_name != cls.__name__:
327+
orig_cls = getattr(diffusers_library, orig_cls_name)
328+
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
329+
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
330+
331+
# remove private attributes
332+
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
333+
334+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
298335
init_dict = {}
299336
for key in expected_keys:
300337
if key in kwargs:
@@ -304,23 +341,24 @@ def extract_init_dict(cls, config_dict, **kwargs):
304341
# use value from config dict
305342
init_dict[key] = config_dict.pop(key)
306343

307-
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
308-
344+
# 4. Give nice warning if unexpected values have been passed
309345
if len(config_dict) > 0:
310346
logger.warning(
311347
f"The config attributes {config_dict} were passed to {cls.__name__}, "
312348
"but are not expected and will be ignored. Please verify your "
313349
f"{cls.config_name} configuration file."
314350
)
315351

316-
unused_kwargs = {**config_dict, **kwargs}
317-
352+
# 5. Give nice info if config attributes are initiliazed to default because they have not been passed
318353
passed_keys = set(init_dict.keys())
319354
if len(expected_keys - passed_keys) > 0:
320355
logger.info(
321356
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
322357
)
323358

359+
# 6. Define unused keyword arguments
360+
unused_kwargs = {**config_dict, **kwargs}
361+
324362
return init_dict, unused_kwargs
325363

326364
@classmethod

src/diffusers/pipeline_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
272272
>>> # Download pipeline, but overwrite scheduler
273273
>>> from diffusers import LMSDiscreteScheduler
274274
275-
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
275+
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
276276
>>> pipeline = FlaxDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
277277
```
278278
"""

src/diffusers/pipeline_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
360360
>>> # Download pipeline, but overwrite scheduler
361361
>>> from diffusers import LMSDiscreteScheduler
362362
363-
>>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
363+
>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
364364
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler)
365365
```
366366
"""
@@ -602,7 +602,7 @@ def components(self) -> Dict[str, Any]:
602602
... StableDiffusionInpaintPipeline,
603603
... )
604604
605-
>>> img2text = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
605+
>>> img2text = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
606606
>>> img2img = StableDiffusionImg2ImgPipeline(**img2text.components)
607607
>>> inpaint = StableDiffusionInpaintPipeline(**img2text.components)
608608
```

src/diffusers/pipelines/stable_diffusion/README.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png")
7272
# make sure you're logged in with `huggingface-cli login`
7373
from diffusers import StableDiffusionPipeline, DDIMScheduler
7474

75-
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
75+
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
7676

7777
pipe = StableDiffusionPipeline.from_pretrained(
7878
"runwayml/stable-diffusion-v1-5",
@@ -91,11 +91,7 @@ image.save("astronaut_rides_horse.png")
9191
# make sure you're logged in with `huggingface-cli login`
9292
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
9393

94-
lms = LMSDiscreteScheduler(
95-
beta_start=0.00085,
96-
beta_end=0.012,
97-
beta_schedule="scaled_linear"
98-
)
94+
lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
9995

10096
pipe = StableDiffusionPipeline.from_pretrained(
10197
"runwayml/stable-diffusion-v1-5",

0 commit comments

Comments
 (0)