@@ -583,7 +583,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             if device_map is None:
                 param_device = "cpu"
                 state_dict = load_state_dict(model_file, variant=variant)
-                cls.convert_deprecated_attention_blocks(state_dict)
+                model._convert_deprecated_attention_blocks(state_dict)
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                 if len(missing_keys) > 0:
@@ -626,7 +626,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             model = cls.from_config(config, **unused_kwargs)
 
             state_dict = load_state_dict(model_file, variant=variant)
-            cls.convert_deprecated_attention_blocks(state_dict)
+            model._convert_deprecated_attention_blocks(state_dict)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                 model,
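The two hunks above change the call site from the classmethod `cls.convert_deprecated_attention_blocks(state_dict)` to the instance method `model._convert_deprecated_attention_blocks(state_dict)`: the new implementation (last hunk below) no longer scans state-dict key names but asks the instantiated module tree which sub-modules are flagged as deprecated attention blocks, so it needs access to `self`. A minimal sketch of that flag-based lookup, assuming the `_from_deprecated_attn_block` marker is set wherever the replacement attention module is constructed; `ToyBlock`/`ToyModel` are made-up names, not from this diff:

import torch.nn as nn

class ToyBlock(nn.Module):
    # Illustrative stand-in for a module that replaced the old attention block.
    _from_deprecated_attn_block = True  # marker the new helper looks for

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mid = nn.ModuleDict({"attn": ToyBlock()})

model = ToyModel()
# The marker lives on instantiated sub-modules, so discovery has to walk the
# model instance rather than the class -- hence the switch to an instance method.
flagged = [
    name
    for name, module in model.named_modules()
    if getattr(module, "_from_deprecated_attn_block", False)
]
print(flagged)  # ['mid.attn']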
@@ -763,42 +763,6 @@ def _find_mismatched_keys(
 
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
-    @classmethod
-    def convert_deprecated_attention_blocks(cls, state_dict):
-        # We check for the deprecated attention block via the `proj_attn` layer in the state dict
-        # The only other class with a layer called `proj_attn` is in `SelfAttention1d` which is used
-        # by only the top level model, `UNet1DModel`. Since, `UNet1DModel` wont have any of the deprecated
-        # attention blocks, we can just early return.
-        if cls.__name__ == "UNet1DModel":
-            return
-
-        deprecated_attention_block_paths = []
-
-        for k in state_dict.keys():
-            if "proj_attn.weight" in k:
-                index = k.index("proj_attn.weight")
-                path = k[: index - 1]
-                deprecated_attention_block_paths.append(path)
-
-        for path in deprecated_attention_block_paths:
-            # group_norm path stays the same
-
-            # query -> to_q
-            state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
-            state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
-
-            # key -> to_k
-            state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
-            state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
-
-            # value -> to_v
-            state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
-            state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
-
-            # proj_attn -> to_out.0
-            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
-            state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
-
     @property
     def device(self) -> device:
         """
@@ -841,3 +805,34 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+    def _convert_deprecated_attention_blocks(self, state_dict):
+        deprecated_attention_block_paths = []
+
+        def recursive_find_attn_block(name, module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_paths.append(name)
+            for sub_name, sub_module in module.named_children():
+                sub_name = sub_name if name == "" else f"{name}.{sub_name}"
+                recursive_find_attn_block(sub_name, sub_module)
+
+        recursive_find_attn_block("", self)
+
+        for path in deprecated_attention_block_paths:
+            # group_norm path stays the same
+
+            # query -> to_q
+            state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight")
+            state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias")
+
+            # key -> to_k
+            state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight")
+            state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias")
+
+            # value -> to_v
+            state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight")
+            state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias")
+
+            # proj_attn -> to_out.0
+            state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
+            state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
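For reference, a small self-contained sketch of the renaming `_convert_deprecated_attention_blocks` applies to one flagged path; the module path and tensor shapes below are made up for illustration and are not taken from this diff:

import torch

path = "mid_block.attentions.0"  # hypothetical module path of a flagged block

# Old-layout keys as they would appear in a deprecated-attention-block checkpoint.
state_dict = {
    f"{path}.{old}.{param}": torch.zeros(4)
    for old in ("query", "key", "value", "proj_attn")
    for param in ("weight", "bias")
}

# The same renaming the new method performs: query/key/value/proj_attn become
# to_q/to_k/to_v/to_out.0, while group_norm keys are left untouched.
renames = (("query", "to_q"), ("key", "to_k"), ("value", "to_v"), ("proj_attn", "to_out.0"))
for old, new in renames:
    for param in ("weight", "bias"):
        state_dict[f"{path}.{new}.{param}"] = state_dict.pop(f"{path}.{old}.{param}")

print(sorted(state_dict.keys()))
# ['mid_block.attentions.0.to_k.bias', 'mid_block.attentions.0.to_k.weight', ...]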