Commit c4d2823

[SDXL Lora] Fix last ben sdxl lora (#4797)
* Fix last ben sdxl lora
* Correct typo
* make style
1 parent 4f8853e commit c4d2823

File tree (2 files changed: +43 additions, -14 deletions)

  src/diffusers/loaders.py
  tests/models/test_lora_layers.py

src/diffusers/loaders.py

Lines changed: 27 additions & 14 deletions

@@ -1084,7 +1084,7 @@ def lora_state_dict(
             # Map SDXL blocks correctly.
             if unet_config is not None:
                 # use unet config to remap block numbers
-                state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+                state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

         return state_dict, network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
         return weight_name

     @classmethod
-    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
-        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+    def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        # 1. get all state_dict_keys
+        all_keys = list(state_dict.keys())
+        sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+        # 2. check if needs remapping, if not return original dict
+        is_in_sgm_format = False
+        for key in all_keys:
+            if any(p in key for p in sgm_patterns):
+                is_in_sgm_format = True
+                break
+
+        if not is_in_sgm_format:
+            return state_dict
+
+        # 3. Else remap from SGM patterns
         new_state_dict = {}
         inner_block_map = ["resnets", "attentions", "upsamplers"]

         # Retrieves # of down, mid and up blocks
         input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
-        for layer in state_dict:
-            if "text" not in layer:
+
+        for layer in all_keys:
+            if "text" in layer:
+                new_state_dict[layer] = state_dict.pop(layer)
+            else:
                 layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
-                if "input_blocks" in layer:
+                if sgm_patterns[0] in layer:
                     input_block_ids.add(layer_id)
-                elif "middle_block" in layer:
+                elif sgm_patterns[1] in layer:
                     middle_block_ids.add(layer_id)
-                elif "output_blocks" in layer:
+                elif sgm_patterns[2] in layer:
                     output_block_ids.add(layer_id)
                 else:
-                    raise ValueError("Checkpoint not supported")
+                    raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")

         input_blocks = {
             layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
                 )
                 new_state_dict[new_key] = state_dict.pop(key)

-        if is_all_unet and len(state_dict) > 0:
+        if len(state_dict) > 0:
             raise ValueError("At this point all state dict entries have to be converted.")
-        else:
-            # Remaining is the text encoder state dict.
-            for k, v in state_dict.items():
-                new_state_dict.update({k: v})

         return new_state_dict

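For orientation, here is a standalone sketch (not part of the commit) of the early-return check that the renamed _maybe_map_sgm_blocks_to_diffusers now performs before remapping anything. The example keys are hypothetical but follow the Kohya/SGM-style naming the patterns are meant to catch.

# Sketch of the "maybe" check introduced above: remapping is only attempted
# when the LoRA state dict actually uses SGM-style UNet block names.
# The example keys below are illustrative, not taken from a real checkpoint.
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]

def is_in_sgm_format(state_dict):
    # True if any key contains one of the SGM block patterns.
    return any(p in key for key in state_dict for p in sgm_patterns)

sgm_style = {"lora_unet_input_blocks_1_1_proj_in.lora_down.weight": None}
diffusers_style = {"unet.down_blocks.1.attentions.0.proj_in.lora_down.weight": None}

print(is_in_sgm_format(sgm_style))        # True  -> keys get remapped
print(is_in_sgm_format(diffusers_style))  # False -> dict is returned unchanged

Text-encoder keys (anything containing "text") are now copied over untouched inside the same loop that collects the UNet block ids, which is why the old is_all_unet fallback at the end of the method could be removed.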
tests/models/test_lora_layers.py

Lines changed: 16 additions & 0 deletions

@@ -942,3 +942,19 @@ def test_sdxl_1_0_lora(self):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_sdxl_1_0_last_ben(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "TheLastBen/Papercut_SDXL"
+        lora_filename = "papercut.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))

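Outside the test harness, exercising the fix looks roughly like the sketch below. The repo id and weight name come straight from the new test above; the fp16 dtype, the prompt text, and the output filename are placeholders, not part of the commit.

# Rough end-user equivalent of the new slow test above.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# Kohya-style SDXL LoRA whose loading this commit fixes.
pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors")

image = pipe("papercut", num_inference_steps=30).images[0]  # placeholder prompt
image.save("papercut.png")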