Commit 0902449
Add from_pt argument in .from_pretrained (#527)
* first commit:
  - add `from_pt` argument in the `from_pretrained` function
  - add `modeling_flax_pytorch_utils.py` file
* small nit
  - fix a small nit so we do not enter the second if condition
* major changes
  - modify FlaxUNet modules
  - first conversion script
  - more keys to be matched
* keys match
  - now all keys match
  - change module names for correct matching
  - upsample module name changed
* working v1
  - tests pass with atol and rtol = `4e-02`
* replace unused arg
* make quality
* add small docstring
* add more comments
  - add TODO for embedding layers
* small change
  - use `jnp.expand_dims` for converting `timesteps` in case it is a 0-dimensional array
* add more conditions on conversion
  - add better test to check for key conversion
* make shapes consistent
  - output `img_w x img_h x n_channels` from the VAE
* Revert "make shapes consistent"
  This reverts commit 4cad1ae.
* fix unet shape - channels first!
1 parent (ca74951) · commit 0902449
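For context, a minimal sketch of how the new flag might be used; the repo id below is a placeholder for any checkpoint that only ships PyTorch weights, and the import path assumes the module layout in this commit:

# Hypothetical usage of the new `from_pt` flag (repo id is illustrative).
from diffusers.models.unet_2d_condition_flax import FlaxUNet2DConditionModel

model, params = FlaxUNet2DConditionModel.from_pretrained(
    "org/pytorch-only-unet",  # placeholder repo id
    from_pt=True,             # fetch the PyTorch weights and convert them to Flax params
)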

5 files changed, +198 -42 lines changed
src/diffusers/modeling_flax_pytorch_utils.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch - Flax general utilities."""
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from .utils import logging


logger = logging.get_logger(__name__)


def rename_key(key):
    regex = r"\w+[.]\d+"
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


#####################
# PyTorch => Flax #
#####################


# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""

    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
        and (pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict)
        and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
    ):
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor
    elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
        renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        return renamed_pt_tuple_key, pt_tensor

    # embedding
    if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
        pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        return renamed_pt_tuple_key, pt_tensor

    # conv layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4:
        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        return renamed_pt_tuple_key, pt_tensor

    # linear layer
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
    if pt_tuple_key[-1] == "weight":
        pt_tensor = pt_tensor.T
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm weight
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
    if pt_tuple_key[-1] == "gamma":
        return renamed_pt_tuple_key, pt_tensor

    # old PyTorch layer norm bias
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
    if pt_tuple_key[-1] == "beta":
        return renamed_pt_tuple_key, pt_tensor

    return pt_tuple_key, pt_tensor


def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model, init_key=42):
    # Step 1: Convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    # Step 2: Since the model is stateless, get random Flax params
    random_flax_params = flax_model.init_weights(PRNGKey(init_key))

    random_flax_state_dict = flatten_dict(random_flax_params)
    flax_state_dict = {}

    # Need to change some parameters name to match Flax names
    for pt_key, pt_tensor in pt_state_dict.items():
        renamed_pt_key = rename_key(pt_key)
        pt_tuple_key = tuple(renamed_pt_key.split("."))

        # Correctly rename weight parameters
        flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict)

        if flax_key in random_flax_state_dict:
            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[flax_key] = jnp.asarray(flax_tensor)

    return unflatten_dict(flax_state_dict)
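As a concrete illustration of what the two helpers above do (the key and shapes below are made up):

# `rename_key` rewrites indexed PyTorch module names into Flax-style names, e.g.
#   "down_blocks.0.attentions.1.proj.weight" -> "down_blocks_0.attentions_1.proj.weight"
# `rename_key_and_reshape_tensor` then renames "weight" -> "kernel" for a 4D conv weight
# and reorders PyTorch's (out, in, kh, kw) layout into Flax's (kh, kw, in, out):
import numpy as np

pt_conv_weight = np.zeros((320, 4, 3, 3))                 # (out_channels, in_channels, kh, kw), illustrative shape
flax_conv_kernel = pt_conv_weight.transpose(2, 3, 1, 0)   # -> (3, 3, 4, 320)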

src/diffusers/modeling_flax_utils.py

Lines changed: 39 additions & 22 deletions
@@ -27,7 +27,8 @@
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError

-from .modeling_utils import WEIGHTS_NAME
+from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
+from .modeling_utils import WEIGHTS_NAME, load_state_dict
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging

@@ -245,6 +246,8 @@ def from_pretrained(
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
@@ -272,6 +275,7 @@ def from_pretrained(
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
+        from_pt = kwargs.pop("from_pt", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
@@ -306,10 +310,16 @@ def from_pretrained(
                # Load from a Flax checkpoint
                model_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
            # At this stage we don't have a weight file so we will raise an error.
+            elif from_pt:
+                if not os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
+                    raise EnvironmentError(
+                        f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                    )
+                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
                raise EnvironmentError(
-                    f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
-                    "but there is a file for PyTorch weights."
+                    f"{WEIGHTS_NAME} file found in directory {pretrained_model_name_or_path}. Please load the model"
+                    " using `from_pt=True`."
                )
            else:
                raise EnvironmentError(
@@ -320,7 +330,7 @@ def from_pretrained(
            try:
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
-                    filename=FLAX_WEIGHTS_NAME,
+                    filename=FLAX_WEIGHTS_NAME if not from_pt else WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
@@ -370,25 +380,32 @@ def from_pretrained(
                    f"containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
                )

-        try:
-            with open(model_file, "rb") as state_f:
-                state = from_bytes(cls, state_f.read())
-        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+        if from_pt:
+            # Step 1: Get the pytorch file
+            pytorch_model_file = load_state_dict(model_file)
+
+            # Step 2: Convert the weights
+            state = convert_pytorch_state_dict_to_flax(pytorch_model_file, model)
+        else:
            try:
-                with open(model_file) as f:
-                    if f.read().startswith("version"):
-                        raise OSError(
-                            "You seem to have cloned a repository without having git-lfs installed. Please"
-                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
-                            " folder you cloned."
-                        )
-                    else:
-                        raise ValueError from e
-            except (UnicodeDecodeError, ValueError):
-                raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
-        # make sure all arrays are stored as jnp.ndarray
-        # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
-        # https://github.com/google/flax/issues/1261
+                with open(model_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                try:
+                    with open(model_file) as f:
+                        if f.read().startswith("version"):
+                            raise OSError(
+                                "You seem to have cloned a repository without having git-lfs installed. Please"
+                                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                                " folder you cloned."
+                            )
+                        else:
+                            raise ValueError from e
+                except (UnicodeDecodeError, ValueError):
+                    raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
+            # make sure all arrays are stored as jnp.ndarray
+            # NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
+            # https://github.com/google/flax/issues/1261
        state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

        # flatten dicts
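In short, the new `from_pt` branch replaces direct msgpack deserialization with a two-step conversion. A rough sketch of that path, assuming `model_file` and `model` are already in scope as they are inside `from_pretrained`:

# Sketch of the from_pt branch (variable names mirror the diff above).
from diffusers.modeling_utils import load_state_dict
from diffusers.modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax

pt_state = load_state_dict(model_file)                        # Step 1: read the PyTorch checkpoint into a state dict
state = convert_pytorch_state_dict_to_flax(pt_state, model)   # Step 2: rename/reshape every tensor into the Flax param tree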

src/diffusers/models/attention_flax.py

Lines changed: 25 additions & 11 deletions
@@ -32,7 +32,7 @@ def setup(self):
        self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
        self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

-        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
+        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out_0")

    def reshape_heads_to_batch_dim(self, tensor):
        batch_size, seq_len, dim = tensor.shape
@@ -82,9 +82,9 @@ class FlaxBasicTransformerBlock(nn.Module):

    def setup(self):
        # self attention
-        self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        # cross attention
-        self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+        self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
        self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
@@ -93,12 +93,12 @@ def setup(self):
    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
-        hidden_states = self.self_attn(self.norm1(hidden_states), deterministic=deterministic)
+        hidden_states = self.attn1(self.norm1(hidden_states), deterministic=deterministic)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
-        hidden_states = self.cross_attn(self.norm2(hidden_states), context, deterministic=deterministic)
+        hidden_states = self.attn2(self.norm2(hidden_states), context, deterministic=deterministic)
        hidden_states = hidden_states + residual

        # feed forward
@@ -167,14 +167,28 @@ class FlaxGluFeedForward(nn.Module):
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

+    def setup(self):
+        # The second linear layer needs to be called
+        # net_2 for now to match the index of the Sequential layer
+        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        hidden_states = self.net_0(hidden_states)
+        hidden_states = self.net_2(hidden_states)
+        return hidden_states
+
+
+class FlaxGEGLU(nn.Module):
+    dim: int
+    dropout: float = 0.0
+    dtype: jnp.dtype = jnp.float32
+
    def setup(self):
        inner_dim = self.dim * 4
-        self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
-        self.dense2 = nn.Dense(self.dim, dtype=self.dtype)
+        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
-        hidden_states = self.dense1(hidden_states)
+        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-        hidden_states = hidden_linear * nn.gelu(hidden_gelu)
-        hidden_states = self.dense2(hidden_states)
-        return hidden_states
+        return hidden_linear * nn.gelu(hidden_gelu)
src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 9 additions & 1 deletion
@@ -76,7 +76,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # init input tensors
-        sample_shape = (1, self.sample_size, self.sample_size, self.in_channels)
+        sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
@@ -214,10 +214,17 @@ def __call__(
            When returning a tuple, the first element is the sample tensor.
        """
        # 1. time
+        if not isinstance(timesteps, jnp.ndarray):
+            timesteps = jnp.array([timesteps], dtype=jnp.int32)
+        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+            timesteps = timesteps.astype(dtype=jnp.float32)
+            timesteps = jnp.expand_dims(timesteps, 0)
+
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # 2. pre-process
+        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)

        # 3. down
@@ -251,6 +258,7 @@ def __call__(
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)
+        sample = jnp.transpose(sample, (0, 3, 1, 2))

        if not return_dict:
            return (sample,)
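The net effect of the UNet changes: the Flax model now accepts and returns channels-first tensors like its PyTorch counterpart, transposing to channels-last only around the convolutions, and 0-dimensional timesteps are promoted to a batch of one. A small illustration (sizes are made up):

# Channels-first in, channels-first out; NHWC only inside the model.
import jax.numpy as jnp

sample = jnp.zeros((1, 4, 64, 64))           # (batch, channels, height, width), illustrative sizes
nhwc = jnp.transpose(sample, (0, 2, 3, 1))   # -> (1, 64, 64, 4) before conv_in
nchw = jnp.transpose(nhwc, (0, 3, 1, 2))     # -> (1, 4, 64, 64) after conv_out

# A 0-dimensional timestep becomes a length-1 array before the embedding.
t = jnp.array(10, dtype=jnp.int32)
if t.ndim == 0:
    t = jnp.expand_dims(t, 0)                # -> shape (1,)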
