Skip to content

Commit ba6c180

Browse files
committed
fix fuse / unfuse unet
1 parent 86c7d69 commit ba6c180

File tree

1 file changed

+22
-4
lines changed

1 file changed

+22
-4
lines changed

src/diffusers/loaders.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -669,15 +669,33 @@ def fuse_lora(self, lora_scale=1.0):
669669
self.apply(self._fuse_lora_apply)
670670

671671
def _fuse_lora_apply(self, module):
672-
if hasattr(module, "_fuse_lora"):
673-
module._fuse_lora(self.lora_scale)
672+
if not self.use_peft_backend:
673+
if hasattr(module, "_fuse_lora"):
674+
module._fuse_lora(self.lora_scale)
675+
else:
676+
from peft.tuners.tuners_utils import BaseTunerLayer
677+
678+
if isinstance(module, BaseTunerLayer):
679+
if self.lora_scale != 1.0:
680+
module.scale_layer(self.lora_scale)
681+
682+
module.merge()
674683

675684
def unfuse_lora(self):
676685
self.apply(self._unfuse_lora_apply)
677686

678687
def _unfuse_lora_apply(self, module):
679-
if hasattr(module, "_unfuse_lora"):
680-
module._unfuse_lora()
688+
if not self.use_peft_backend:
689+
if hasattr(module, "_unfuse_lora"):
690+
module._unfuse_lora()
691+
else:
692+
from peft.tuners.tuners_utils import BaseTunerLayer
693+
694+
if isinstance(module, BaseTunerLayer):
695+
if self.lora_scale != 1.0:
696+
module.unscale_layer()
697+
698+
module.unmerge()
681699

682700
def set_adapters(
683701
self,

0 commit comments

Comments
 (0)