
Commit be5b25d

[WIP] Add Flax LoRA Support to Dreambooth
I saw @patrickvonplaten is working on LoRA support for the non-Flax Dreambooth script. We've been taking a stab at implementing LoRA support for TPUs, following the patching approach used by @cloneofsimo in cloneofsimo/lora. I've got it successfully patching and training, but the output is currently no good. I'm reaching the end of the time I have allocated for this; I might pick it up in the future, but for now I'm putting this up in case anyone finds it useful!
1 parent 261a448 commit be5b25d
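
For context, the update this commit patches into the UNet's dense layers is the standard LoRA formulation: keep the pretrained weight frozen and add a trainable low-rank residual, y = W x + scale * B(A(x)). Below is a rough standalone sketch in Flax, for illustration only; the module name, rank, and scale here are made up, and the commit's actual implementation is FlaxLinearWithLora in the diff further down.

```python
# A minimal LoRA-style layer (illustration only, not the commit's code): a frozen base
# projection plus a trainable low-rank correction, y = W x + scale * B(A(x)).
import flax.linen as nn
import jax
import jax.numpy as jnp


class LoraDenseSketch(nn.Module):  # hypothetical name
    features: int
    rank: int = 4
    scale: float = 1.0

    @nn.compact
    def __call__(self, x):
        base = nn.Dense(self.features, name="linear")(x)                    # pretrained W (kept frozen)
        down = nn.Dense(self.rank, use_bias=False, name="lora_down")(x)     # A: d_in -> rank
        up = nn.Dense(self.features, use_bias=False, name="lora_up")(down)  # B: rank -> d_out
        return base + self.scale * up


x = jnp.ones((2, 16))
layer = LoraDenseSketch(features=8)
params = layer.init(jax.random.PRNGKey(0), x)  # params["params"] has "linear", "lora_down", "lora_up"
y = layer.apply(params, x)                     # shape (2, 8)
```

In the usual LoRA setup only the low-rank factors are trained while W stays fixed, which is what the masked optimizer in the training-script diff below is trying to arrange.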

File tree: 3 files changed (+155, -3 lines)

examples/dreambooth/train_dreambooth_flax.py

Lines changed: 18 additions & 3 deletions
@@ -22,6 +22,7 @@
     FlaxStableDiffusionPipeline,
     FlaxUNet2DConditionModel,
 )
+from diffusers.experimental.lora.linear_with_lora_flax import FlaxLora
 from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
 from diffusers.utils import check_min_version
 from flax import jax_utils
@@ -97,6 +98,7 @@ def parse_args():
             " class_data_dir, additional images will be sampled with class_prompt."
         ),
     )
+    parser.add_argument("--lora", action="store_true", help="Use LoRA (https://arxiv.org/abs/2106.09685)")
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -444,9 +446,6 @@ def collate_fn(examples):
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
     )
-    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
-    )
 
     # Optimization
     if args.scale_lr:
@@ -467,6 +466,22 @@ def collate_fn(examples):
         adamw,
     )
 
+    if args.lora:
+        unet, unet_params = FlaxLora(FlaxUNet2DConditionModel).from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="unet",
+            dtype=weight_dtype,
+            revision=args.revision,
+        )
+        optimizer = optax.masked(optimizer, mask=unet.get_mask)
+    else:
+        unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path,
+            subfolder="unet",
+            dtype=weight_dtype,
+            revision=args.revision,
+        )
+
     unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)
     text_encoder_state = train_state.TrainState.create(
         apply_fn=text_encoder.__call__, params=text_encoder.params, tx=optimizer
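
One detail worth flagging about the last hunk: optax.masked(optimizer, mask=unet.get_mask) applies the wrapped chain (gradient clipping plus AdamW) only to the leaves where the mask is True, while updates for the False leaves are passed through unmodified rather than zeroed, so the base UNet weights are not necessarily frozen by this alone. If full freezing is the intent, a variant along the following lines routes the frozen leaves to a zero update instead. This is illustrative only and not part of the commit; optax.multi_transform and optax.set_to_zero are real optax APIs, but the label_params helper and the learning rate are placeholders.

```python
# Illustrative alternative (not from the commit): explicitly zero the updates of the
# non-LoRA leaves, so only the lora_up/lora_down parameters ever change.
import jax
import optax

def label_params(params):
    # unet.get_mask(params) returns a boolean pytree; map it to string labels for multi_transform
    return jax.tree_util.tree_map(lambda is_lora: "lora" if is_lora else "frozen", unet.get_mask(params))

optimizer = optax.multi_transform(
    {"lora": optax.adamw(learning_rate=1e-4), "frozen": optax.set_to_zero()},  # placeholder LR
    label_params,
)
```
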
Lines changed: 111 additions & 0 deletions (new file defining FlaxLinearWithLora, FlaxLoraBase, and FlaxLora)
@@ -0,0 +1,111 @@
import copy
from collections import defaultdict
from typing import Dict, List, Type, Union, cast

import flax.linen as nn
import jax
import jax.numpy as jnp
from diffusers.modeling_flax_utils import FlaxModelMixin
from flax.core.frozen_dict import FrozenDict
from flax.traverse_util import flatten_dict, unflatten_dict


class FlaxLinearWithLora(nn.Module):
    out_features: int
    rank: int = 5
    in_features: int = 1
    scale: float = 1.0
    use_bias: bool = True

    def setup(self):
        self.linear = nn.Dense(features=self.out_features, use_bias=self.use_bias)
        self.lora_up = nn.Dense(features=self.out_features, use_bias=False)
        self.lora_down = nn.Dense(features=4, use_bias=False)

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        return self.init(rng, jnp.zeros((self.in_features, self.out_features)))

    def __call__(self, input):
        return self.linear(input) + self.lora_up(self.lora_down(input)) * self.scale


class FlaxLoraBase(nn.Module):
    @staticmethod
    def _get_children(model: nn.Module) -> Dict[str, nn.Module]:
        model._try_setup(shallow=True)
        return {k: v for k, v in model._state.children.items() if isinstance(v, nn.Module)}

    @staticmethod
    def _wrap_dense(params: dict, parent: nn.Module, model: Union[nn.Dense, nn.Module], name: str):
        if not isinstance(model, nn.Dense):
            return params, {}

        params_to_optimize = defaultdict(dict)

        parent._in_setup = True
        lora = FlaxLinearWithLora(
            out_features=model.features,
            use_bias=model.use_bias,
            name=name,
            parent=parent,
        )

        lora_params = lora.init_weights(jax.random.PRNGKey(0)).unfreeze()["params"]
        lora_params["linear"] = params
        lora = lora.bind({"params": lora_params})

        for k, v in parent.__dict__.items():
            if isinstance(v, nn.Module) and v.name == name:
                setattr(model.parent, k, lora)

        parent._in_setup = False

        for n in ["lora_up", "lora_down"]:
            params_to_optimize[n] = {k: True for k in lora_params[n].keys()}
        params_to_optimize["linear"] = {k: False for k in lora_params["linear"].keys()}

        return lora_params, dict(params_to_optimize)

    @staticmethod
    def inject(
        params: Union[dict, FrozenDict],
        model: nn.Module,
        targets: List[str],
        is_target: bool = False,
    ):
        params = params.unfreeze() if isinstance(params, FrozenDict) else copy.copy(params)
        params_to_optimize = {}

        for name, child in FlaxLoraBase._get_children(model).items():
            if is_target:
                results = FlaxLoraBase._wrap_dense(params.get(name, {}), model, child, name)
            elif child.__class__.__name__ in targets:
                results = FlaxLoraBase.inject(params.get(name, {}), child, targets=targets, is_target=True)
            else:
                results = FlaxLoraBase.inject(params.get(name, {}), child, targets=targets)

            params[name], params_to_optimize[name] = results

        return params, params_to_optimize


def FlaxLora(model: Type[nn.Module], targets=["FlaxAttentionBlock"]):
    class _FlaxLora(model):
        def setup(self):
            super().setup()
            params = cast(FlaxModelMixin, self).init_weights(jax.random.PRNGKey(0))
            FlaxLoraBase.inject(params, self, targets=targets)

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            instance, params = cast(Type[FlaxModelMixin], model).from_pretrained(*args, **kwargs)
            params, mask = FlaxLoraBase.inject(params, instance, targets=targets)
            mask_values = flatten_dict(mask)
            instance.get_mask = lambda params: unflatten_dict(
                {k: mask_values.get(k, False) for k in flatten_dict(params, keep_empty_nodes=True).keys()}
            )
            return instance, params

    _FlaxLora.__name__ = f"{model.__name__}Lora"

    return _FlaxLora
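
In this helper, FlaxLora(model).from_pretrained returns the patched model instance plus its params and attaches get_mask: a function that maps a params pytree to a boolean pytree of the same structure, True for the lora_up/lora_down leaves created by the injection and False for the wrapped pretrained weights. A minimal illustration of that mask shape, using a made-up params dict rather than anything from the commit:

```python
# Minimal illustration (not from the commit) of the mask pytree that get_mask is meant to
# produce: True for LoRA leaves, False for the wrapped pretrained weights.
from flax.traverse_util import flatten_dict, unflatten_dict

params = {
    "to_q": {
        "linear": {"kernel": 0.0, "bias": 0.0},  # frozen pretrained weights
        "lora_up": {"kernel": 0.0},              # trainable low-rank factors
        "lora_down": {"kernel": 0.0},
    }
}
flat = flatten_dict(params)  # keys become tuples such as ("to_q", "lora_up", "kernel")
mask = unflatten_dict({k: ("lora_up" in k or "lora_down" in k) for k in flat})
# mask == {"to_q": {"linear": {"kernel": False, "bias": False},
#                   "lora_up": {"kernel": True}, "lora_down": {"kernel": True}}}
```

That boolean pytree is what the training script passes to optax.masked as mask=unet.get_mask.
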
Lines changed: 26 additions & 0 deletions (new file: a standalone script exercising FlaxLora with the Stable Diffusion UNet)
@@ -0,0 +1,26 @@
import os
import pdb

import optax
from diffusers import FlaxUNet2DConditionModel
from diffusers.experimental.lora.linear_with_lora_flax import FlaxLora
from flax.training import train_state
from jax.config import config
from jax.experimental.compilation_cache import compilation_cache as cc


config.update("jax_traceback_filtering", "off")
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

if __name__ == "__main__":
    unet, unet_params = FlaxLora(FlaxUNet2DConditionModel).from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        revision="flax",
    )
    get_mask = unet.get_mask

    optimizer = optax.masked(optax.adamw(1e-6), mask=get_mask)
    unet_state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer)

    pdb.set_trace()
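
The script stops at a pdb breakpoint after building the state. A hypothetical continuation (not in the commit) that checks the TrainState and masked optimizer wire together could apply one dummy gradient step:

```python
# Hypothetical continuation (not in the commit): one dummy gradient step to confirm the
# masked optimizer accepts the params pytree and TrainState.apply_gradients runs.
import jax
import jax.numpy as jnp

zero_grads = jax.tree_util.tree_map(jnp.zeros_like, unet_state.params)
unet_state = unet_state.apply_gradients(grads=zero_grads)
print(type(unet_state.opt_state))  # inspect the optimizer state produced by optax.masked
```
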
