Commit 84a7d3d

use _from_deprecated_attn_block check re: @patrickvonplaten
1 parent e43c390 commit 84a7d3d

File tree: 3 files changed (+45, -38 lines)

3 files changed

+45
-38
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 0 deletions
@@ -68,6 +68,7 @@ def __init__(
         eps: float = 1e-5,
         rescale_output_factor: float = 1.0,
         residual_connection: bool = False,
+        _from_deprecated_attn_block=False,
         processor: Optional["AttnProcessor"] = None,
     ):
         super().__init__()
@@ -78,6 +79,10 @@ def __init__(
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection

+        # we make use of this private variable to know whether this class is loaded
+        # with a deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
         self.scale_qk = scale_qk
         self.scale = dim_head**-0.5 if self.scale_qk else 1.0
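
The new keyword only tags the `Attention` instances that stand in for the deprecated `AttentionBlock`; the loader can then find them by walking the module tree. A minimal, self-contained sketch of that lookup (the `DummyAttention` module below is hypothetical; only the flag-walking idea is taken from this commit):

import torch.nn as nn

class DummyAttention(nn.Module):
    # stand-in for diffusers' Attention; only the private flag matters for this example
    def __init__(self, _from_deprecated_attn_block=False):
        super().__init__()
        self._from_deprecated_attn_block = _from_deprecated_attn_block

model = nn.Sequential(nn.Conv2d(3, 3, 1), DummyAttention(_from_deprecated_attn_block=True))

# named_modules() walks the tree the same way recursive_find_attn_block does below
flagged = [name for name, module in model.named_modules()
           if getattr(module, "_from_deprecated_attn_block", False)]
print(flagged)  # ['1'] -> the submodule path whose state-dict keys need converting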

src/diffusers/models/modeling_utils.py

Lines changed: 33 additions & 38 deletions
@@ -583,7 +583,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if device_map is None:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
-                cls.convert_deprecated_attention_blocks(state_dict)
+                model._convert_deprecated_attention_blocks(state_dict)
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                 if len(missing_keys) > 0:
@@ -626,7 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls.from_config(config, **unused_kwargs)

             state_dict = load_state_dict(model_file, variant=variant)
-            cls.convert_deprecated_attention_blocks(state_dict)
+            model._convert_deprecated_attention_blocks(state_dict)

             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                 model,
@@ -763,42 +763,6 @@ def _find_mismatched_keys(

         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

-    @classmethod
-    def convert_deprecated_attention_blocks(cls, state_dict):
-        # We check for the deprecated attention block via the `proj_attn` layer in the state dict
-        # The only other class with a layer called `proj_attn` is in `SelfAttention1d` which is used
-        # by only the top level model, `UNet1DModel`. Since, `UNet1DModel` wont have any of the deprecated
-        # attention blocks, we can just early return.
-        if cls.__name__ == "UNet1DModel":
-            return
-
-        deprecated_attention_block_paths = []
-
-        for k in state_dict.keys():
-            if "proj_attn.weight" in k:
-                index = k.index("proj_attn.weight")
-                path = k[: index - 1]
-                deprecated_attention_block_paths.append(path)
-
-        for path in deprecated_attention_block_paths:
-            # group_norm path stays the same
-
-            # query -> to_q
-            state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
-            state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
-
-            # key -> to_k
-            state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
-            state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
-
-            # value -> to_v
-            state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
-            state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
-
-            # proj_attn -> to_out.0
-            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
-            state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
-
     @property
     def device(self) -> device:
         """
@@ -841,3 +805,34 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+    def _convert_deprecated_attention_blocks(self, state_dict):
+        deprecated_attention_block_paths = []
+
+        def recursive_find_attn_block(name, module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_paths.append(name)
+            for sub_name, sub_module in module.named_children():
+                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
+                recursive_find_attn_block(sub_name, sub_module)
+
+        recursive_find_attn_block("", self)
+
+        for path in deprecated_attention_block_paths:
+            # group_norm path stays the same
+
+            # query -> to_q
+            state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
+            state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
+
+            # key -> to_k
+            state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
+            state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
+
+            # value -> to_v
+            state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
+            state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
+
+            # proj_attn -> to_out.0
+            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
+            state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")

src/diffusers/models/unet_2d_blocks.py

Lines changed: 7 additions & 0 deletions
@@ -437,6 +437,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )
         else:
@@ -725,6 +726,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )

@@ -1078,6 +1080,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )

@@ -1156,6 +1159,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )

@@ -1730,6 +1734,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )

@@ -2068,6 +2073,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )

@@ -2144,6 +2150,7 @@ def __init__(
                     residual_connection=True,
                     bias=True,
                     upcast_softmax=True,
+                    _from_deprecated_attn_block=True,
                 )
             )
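
End to end, setting the flag at these construction sites is what lets `from_pretrained` convert older checkpoints on the fly, without the previous `cls.__name__`-based special casing. A hedged sketch of the user-facing effect (the model id below is a placeholder, not an actual repository):

from diffusers import UNet2DModel

# a checkpoint saved with the old AttentionBlock parameter names ("query", "key", ...)
# should now load without missing/unexpected keys; "some-org/old-unet-checkpoint" is made up
unet = UNet2DModel.from_pretrained("some-org/old-unet-checkpoint")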
