Skip to content

Commit 5973e43

Browse files
committed
add more conditions on conversion
- add better test to check for keys conversion
1 parent 1facd9f commit 5973e43

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/diffusers/modeling_flax_pytorch_utils.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
# limitations under the License.
1515
""" PyTorch - Flax general utilities."""
1616
import re
17-
from typing import Tuple
18-
19-
import numpy as np
2017

2118
import jax.numpy as jnp
2219
from flax.traverse_util import flatten_dict, unflatten_dict
@@ -40,23 +37,29 @@ def rename_key(key):
4037
# PyTorch => Flax #
4138
#####################
4239

43-
# Inspired from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
44-
def rename_key_and_reshape_tensor(
45-
pt_tuple_key: Tuple[str],
46-
pt_tensor: np.ndarray,
47-
) -> (Tuple[str], np.ndarray):
40+
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
41+
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
42+
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
4843
"""Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
4944

50-
# # conv norm or layer norm
51-
# This is not really stable since any module that has the name 'scale'
52-
# Will be affected. Maybe just check pt_tuple_key[-2] ?
45+
# conv norm or layer norm
5346
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
54-
if any("norm" in str_ for str_ in pt_tuple_key) and pt_tuple_key[-1] == "weight":
47+
if (
48+
any("norm" in str_ for str_ in pt_tuple_key)
49+
and (pt_tuple_key[-1] == "bias")
50+
and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
51+
and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
52+
):
53+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
54+
return renamed_pt_tuple_key, pt_tensor
55+
elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
56+
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
5557
return renamed_pt_tuple_key, pt_tensor
5658

5759
# embedding
58-
# For now the embedding layers are not converted
59-
# TODO: figure out how to detect embedding layers
60+
if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
61+
pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
62+
return renamed_pt_tuple_key, pt_tensor
6063

6164
# conv layer
6265
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
@@ -99,7 +102,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
99102
pt_tuple_key = tuple(renamed_pt_key.split("."))
100103

101104
# Correctly rename weight parameters
102-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor)
105+
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)
103106

104107
if flax_key in random_flax_state_dict:
105108
if flax_tensor.shape != random_flax_state_dict[flax_key].shape:

0 commit comments

Comments
 (0)