Commit 9ada90f

[WIP] Add Flax LoRA Support to Dreambooth
I saw @patrickvonplaten is working on LoRA support for the non-Flax Dreambooth script. We've been taking a stab at implementing LoRA support for TPUs, following the patching approach used by @cloneofsimo in cloneofsimo/lora. I've got it successfully patching and training, but the output is currently no good. I'm reaching the end of the time I have allocated for this; I might pick it up again in the future, but for now I'm putting it up in case anyone finds it useful!
1 parent: 247b5fe
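For context, the idea being patched in is sketched below in plain JAX. This is a minimal, illustrative sketch only; the sizes and rank are made up, and it mirrors the defaults of the FlaxLinearWithLora module added in this commit (random small "lora_down", zero "lora_up", scale 1.0), not the actual Flax module code.

# LoRA in one picture: keep the pretrained kernel W frozen and learn a low-rank
# update on top of it. "lora_down" (A) gets a small random init and "lora_up" (B)
# starts at zero, so the patched layer reproduces the original output at init.
import jax
import jax.numpy as jnp

rank, d_in, d_out = 4, 320, 320                      # illustrative sizes
key_w, key_a = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (d_in, d_out))          # frozen pretrained kernel
A = jax.random.normal(key_a, (d_in, rank)) / rank    # "lora_down"
B = jnp.zeros((rank, d_out))                         # "lora_up"
scale = 1.0

x = jnp.ones((1, d_in))
y = x @ W + (x @ A) @ B * scale                      # equals x @ W at initialization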

File tree: 3 files changed, +223 −3 lines
examples/dreambooth/train_dreambooth_flax.py (18 additions, 3 deletions)
@@ -22,6 +22,7 @@
     FlaxStableDiffusionPipeline,
     FlaxUNet2DConditionModel,
 )
+from diffusers.experimental.lora.linear_with_lora_flax import FlaxLora
 from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
 from diffusers.utils import check_min_version
 from flax import jax_utils
@@ -114,6 +115,7 @@ def parse_args():
             " class_data_dir, additional images will be sampled with class_prompt."
         ),
     )
+    parser.add_argument("--lora", action="store_true", help="Use LoRA (https://arxiv.org/abs/2106.09685)")
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -474,9 +476,6 @@ def collate_fn(examples):
         dtype=weight_dtype,
         **vae_kwargs,
     )
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
-    )
 
     # Optimization
     if args.scale_lr:
@@ -497,6 +496,22 @@ def collate_fn(examples):
         adamw,
     )
 
+    if args.lora:
+        unet, unet_params = FlaxLora(FlaxUNet2DConditionModel).from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="unet",
+            dtype=weight_dtype,
+            revision=args.revision,
+        )
+        optimizer = optax.masked(optimizer, mask=unet.get_mask)
+    else:
+        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="unet",
+            dtype=weight_dtype,
+            revision=args.revision,
+        )
+
     unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
     text_encoder_state = train_state.TrainState.create(
         apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
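The masked optimizer is the key piece of the training change, so here is a minimal sketch of how optax.masked behaves, using a toy parameter tree rather than the real UNet pytree (the leaf names only mirror what the LoRA wrapping produces). The inner optimizer is applied only to leaves whose mask is True; updates for the other leaves are passed through untransformed.

import jax.numpy as jnp
import optax

params = {
    "to_q": {
        "linear": {"kernel": jnp.ones((4, 4))},      # pretrained weight, not adapted
        "lora_down": {"kernel": jnp.ones((4, 2))},   # trainable LoRA factor
        "lora_up": {"kernel": jnp.zeros((2, 4))},    # trainable LoRA factor
    }
}

def get_mask(params):
    # Mirrors the shape of the mask unet.get_mask returns: True only for lora_* kernels.
    return {
        "to_q": {
            "linear": {"kernel": False},
            "lora_down": {"kernel": True},
            "lora_up": {"kernel": True},
        }
    }

tx = optax.masked(optax.adamw(1e-4), mask=get_mask)
opt_state = tx.init(params)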
New file (171 additions): the LoRA module imported above as diffusers.experimental.lora.linear_with_lora_flax.
import copy
import dataclasses
from collections import defaultdict
from typing import Dict, List, Type, Union, cast

import flax.linen as nn
import jax
import jax.numpy as jnp
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from flax.core.frozen_dict import FrozenDict
from flax.linen.initializers import zeros
from flax.traverse_util import flatten_dict, unflatten_dict


def replace_module(parent, old_child, new_child):
    for k, v in parent.__dict__.items():
        if isinstance(v, nn.Module) and v.name == old_child.name:
            object.__setattr__(parent, k, new_child)
        elif isinstance(v, tuple):
            for i, c in enumerate(v):
                if isinstance(c, nn.Module) and c.name == old_child.name:
                    object.__setattr__(parent, k, v[:i] + (new_child,) + v[i + 1 :])

    parent._state.children[old_child.name] = new_child
    object.__setattr__(new_child, "parent", old_child.parent)
    object.__setattr__(new_child, "scope", old_child.scope)


class LoRA:
    pass


class FlaxLinearWithLora(nn.Module, LoRA):
    features: int
    in_features: int = -1
    rank: int = 5
    scale: float = 1.0
    use_bias: bool = True

    @nn.compact
    def __call__(self, inputs):
        linear = nn.Dense(features=self.features, use_bias=self.use_bias, name="linear")
        lora_down = nn.Dense(features=self.rank, use_bias=False, name="lora_down")
        lora_up = nn.Dense(features=self.features, use_bias=False, kernel_init=zeros, name="lora_up")

        return linear(inputs) + lora_up(lora_down(inputs)) * self.scale


class FlaxLoraUtils(nn.Module):
    @staticmethod
    def _get_children(model: nn.Module) -> Dict[str, nn.Module]:
        model._try_setup(shallow=True)
        return {k: v for k, v in model._state.children.items() if isinstance(v, nn.Module)}

    @staticmethod
    def _wrap_dense(params: dict, parent: nn.Module, model: Union[nn.Dense, nn.Module], name: str):
        if not isinstance(model, nn.Dense):
            return params, {}

        lora = FlaxLinearWithLora(
            in_features=jnp.shape(params["kernel"])[0],
            features=model.features,
            use_bias=model.use_bias,
            name=name,
            parent=None,
        )

        lora_params = {
            "linear": params,
            "lora_down": {
                "kernel": jax.random.normal(jax.random.PRNGKey(0), (lora.in_features, lora.rank)) * 1.0 / lora.rank
            },
            "lora_up": {"kernel": jnp.zeros((lora.rank, lora.features))},
        }

        params_to_optimize = defaultdict(dict)
        for n in ["lora_up", "lora_down"]:
            params_to_optimize[n] = {k: True for k in lora_params[n].keys()}
        params_to_optimize["linear"] = {k: False for k in lora_params["linear"].keys()}

        return lora_params, dict(params_to_optimize)

    @staticmethod
    def wrap(
        params: Union[dict, FrozenDict],
        model: nn.Module,
        targets: List[str],
        is_target: bool = False,
    ):
        model = model.bind({"params": params})
        if hasattr(model, "init_weights"):
            model.init_weights(jax.random.PRNGKey(0))

        params = params.unfreeze() if isinstance(params, FrozenDict) else copy.copy(params)
        params_to_optimize = {}

        for name, child in FlaxLoraUtils._get_children(model).items():
            if is_target:
                results = FlaxLoraUtils._wrap_dense(params.get(name, {}), model, child, name)
            elif child.__class__.__name__ in targets:
                results = FlaxLoraUtils.wrap(params.get(name, {}), child, targets=targets, is_target=True)
            else:
                results = FlaxLoraUtils.wrap(params.get(name, {}), child, targets=targets)

            params[name], params_to_optimize[name] = results

        return params, params_to_optimize


def wrap_in_lora(model: Type[nn.Module], targets: List[str]):
    class _FlaxLora(model, LoRA):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def wrap(self):
            for attr in self._state.children.values():
                if not isinstance(attr, nn.Module):
                    continue
                if isinstance(attr, LoRA):
                    continue

                if self.__class__.__name__ in targets and isinstance(attr, nn.Dense):
                    instance = FlaxLinearWithLora(
                        features=attr.features,
                        use_bias=attr.use_bias,
                        name=attr.name,
                        parent=None,
                    )
                else:
                    subattrs = {f.name: getattr(attr, f.name) for f in dataclasses.fields(attr) if f.init}
                    subattrs["parent"] = None
                    klass = wrap_in_lora(attr.__class__, targets=targets)
                    instance = klass(**subattrs)

                replace_module(self, attr, instance)

        def setup(self):
            super().setup()
            self.wrap()

    _FlaxLora.__name__ = f"{model.__name__}Lora"
    _FlaxLora.__annotations__ = model.__annotations__
    return _FlaxLora


def FlaxLora(model: Type[nn.Module], targets=["FlaxAttentionBlock", "FlaxGEGLU"]):
    targets = targets + [f"{t}Lora" for t in targets]

    class _LoraFlax(wrap_in_lora(model, targets=targets)):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            instance, params = cast(Type[FlaxModelMixin], model).from_pretrained(*args, **kwargs)
            params, mask = FlaxLoraUtils.wrap(params, instance, targets=targets)
            subattrs = {f.name: getattr(instance, f.name) for f in dataclasses.fields(instance) if f.init}
            instance = cls(**subattrs)
            mask_values = flatten_dict(mask)
            object.__setattr__(
                instance,
                "get_mask",
                lambda params: unflatten_dict(
                    {k: mask_values.get(k, False) for k in flatten_dict(params, keep_empty_nodes=True).keys()}
                ),
            )
            return instance, params

    _LoraFlax.__name__ = f"{model.__name__}WithLora"
    return _LoraFlax
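A quick way to sanity-check FlaxLinearWithLora on its own is the sketch below. It is a hypothetical usage example, assuming the module above is importable at the path used elsewhere in this commit; the shapes are arbitrary. Because lora_up is zero-initialized, the wrapped layer should reproduce a plain nn.Dense with the same "linear" parameters at initialization.

import jax
import jax.numpy as jnp
import flax.linen as nn
from diffusers.experimental.lora.linear_with_lora_flax import FlaxLinearWithLora

layer = FlaxLinearWithLora(features=8, rank=4)
x = jnp.ones((2, 16))
variables = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(variables, x)

# At init the low-rank branch contributes nothing (the lora_up kernel is all zeros),
# so the output matches a plain Dense layer using the same "linear" parameters.
y_linear = nn.Dense(features=8).apply({"params": variables["params"]["linear"]}, x)
assert jnp.allclose(y, y_linear)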
New file (34 additions): a standalone script that smoke-tests the LoRA wrapper on the Stable Diffusion UNet.
import os
import pdb

import jax
import optax
from diffusers import FlaxUNet2DConditionModel
from diffusers.experimental.lora.linear_with_lora_flax import FlaxLinearWithLora, FlaxLora
from flax.training import train_state
from jax.config import config
from jax.experimental.compilation_cache import compilation_cache as cc


config.update("jax_traceback_filtering", "off")
config.update("jax_experimental_subjaxpr_lowering_cache", True)
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

if __name__ == "__main__":
    unet, unet_params = FlaxLora(FlaxUNet2DConditionModel).from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        revision="flax",
    )
    get_mask = unet.get_mask

    assert "lora_up" in unet_params["up_blocks_1"]["attentions_1"]["transformer_blocks_0"]["attn1"]["to_q"].keys()

    optimizer = optax.masked(optax.adamw(1e-6), mask=get_mask)
    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)

    bound = unet.bind({"params": unet_params})
    bound.init_weights(jax.random.PRNGKey(0))

    assert isinstance(bound.up_blocks[1].attentions[1].transformer_blocks[0].attn1.query, FlaxLinearWithLora)
    pdb.set_trace()
