Skip to content

Commit be73dfb

Browse files
committed
Merge branch 'main' of github.com:piEsposito/diffusers into main
2 parents 028f824 + 009281e commit be73dfb

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

src/diffusers/configuration_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def extract_init_dict(cls, config_dict, **kwargs):
323323

324324
# remove attributes from orig class that cannot be expected
325325
orig_cls_name = config_dict.pop("_class_name", cls.__name__)
326-
if orig_cls_name != cls.__name__:
326+
if orig_cls_name != cls.__name__ and hasattr(diffusers_library, orig_cls_name):
327327
orig_cls = getattr(diffusers_library, orig_cls_name)
328328
unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
329329
config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}

src/diffusers/models/attention.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,9 @@ def __init__(
244244
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
245245
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
246246

247-
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
247+
self.to_out = nn.ModuleList([])
248+
self.to_out.append(nn.Linear(inner_dim, query_dim))
249+
self.to_out.append(nn.Dropout(dropout))
248250

249251
def reshape_heads_to_batch_dim(self, tensor):
250252
batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ def forward(self, hidden_states, context=None, mask=None):
283285
else:
284286
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
285287

286-
return self.to_out(hidden_states)
288+
# linear proj
289+
hidden_states = self.to_out[0](hidden_states)
290+
# dropout
291+
hidden_states = self.to_out[1](hidden_states)
292+
return hidden_states
287293

288294
def _attention(self, query, key, value):
289295
# TODO: use baddbmm for better performance
@@ -354,12 +360,19 @@ def __init__(
354360
super().__init__()
355361
inner_dim = int(dim * mult)
356362
dim_out = dim_out if dim_out is not None else dim
357-
project_in = GEGLU(dim, inner_dim)
363+
self.net = nn.ModuleList([])
358364

359-
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
365+
# project in
366+
self.net.append(GEGLU(dim, inner_dim))
367+
# project dropout
368+
self.net.append(nn.Dropout(dropout))
369+
# project out
370+
self.net.append(nn.Linear(inner_dim, dim_out))
360371

361372
def forward(self, hidden_states):
362-
return self.net(hidden_states)
373+
for module in self.net:
374+
hidden_states = module(hidden_states)
375+
return hidden_states
363376

364377

365378
# feedforward

src/diffusers/models/unet_2d_blocks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1053,7 +1053,6 @@ def __init__(
10531053
cross_attention_dim=1280,
10541054
attention_type="default",
10551055
output_scale_factor=1.0,
1056-
downsample_padding=1,
10571056
add_upsample=True,
10581057
):
10591058
super().__init__()

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,18 @@ def __init__(
9191
new_config["steps_offset"] = 1
9292
scheduler._internal_dict = FrozenDict(new_config)
9393

94-
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
94+
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
9595
deprecation_message = (
96-
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
97-
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
98-
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
99-
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
100-
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
96+
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
97+
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
98+
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
99+
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
100+
" Hub, it would be very nice if you could open a Pull request for the"
101+
" `scheduler/scheduler_config.json` file"
101102
)
102-
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
103+
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False)
103104
new_config = dict(scheduler.config)
104-
new_config["clip_sample"] = False
105+
new_config["skip_prk_steps"] = True
105106
scheduler._internal_dict = FrozenDict(new_config)
106107

107108
if safety_checker is None:

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -640,13 +640,14 @@ def test_stable_diffusion(self):
640640
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
641641

642642
def test_stable_diffusion_fast_ddim(self):
643-
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", device_map="auto")
643+
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
644+
645+
sd_pipe = StableDiffusionPipeline.from_pretrained(
646+
"CompVis/stable-diffusion-v1-1", scheduler=scheduler, device_map="auto"
647+
)
644648
sd_pipe = sd_pipe.to(torch_device)
645649
sd_pipe.set_progress_bar_config(disable=None)
646650

647-
scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler")
648-
sd_pipe.scheduler = scheduler
649-
650651
prompt = "A painting of a squirrel eating a burger"
651652
generator = torch.Generator(device=torch_device).manual_seed(0)
652653

0 commit comments

Comments
 (0)