4 changes: 2 additions & 2 deletions tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -35,13 +35,13 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin


 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2
324 changes: 153 additions & 171 deletions tests/models/test_modeling_common.py
@@ -88,7 +88,12 @@


 if is_peft_available():
+    from peft import LoraConfig
     from peft.tuners.tuners_utils import BaseTunerLayer
+    from peft.utils import get_peft_model_state_dict
+
+    from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+    from diffusers.loaders.peft import PeftAdapterMixin


 def caculate_expected_num_shards(index_map_path):
@@ -1113,177 +1118,6 @@ def test_deprecated_kwargs(self):
                 " from `_deprecated_kwargs = [<deprecated_argument>]`"
             )

-    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
-        from peft import LoraConfig
-        from peft.utils import get_peft_model_state_dict
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        torch.manual_seed(0)
-        output_no_lora = model(**inputs_dict, return_dict=False)[0]
-
-        denoiser_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        torch.manual_seed(0)
-        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
-
-        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
-
-            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
-
-            for k in state_dict_loaded:
-                loaded_v = state_dict_loaded[k]
-                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
-                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
-
-            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        torch.manual_seed(0)
-        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
-
-        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
-        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
-
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_wrong_adapter_name_raises_error(self):
-        from peft import LoraConfig
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=False,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            wrong_name = "foo"
-            with self.assertRaises(ValueError) as err_context:
-                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
-
-            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
-
-    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
-        from peft import LoraConfig
-
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=rank,
-            lora_alpha=lora_alpha,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=use_dora,
-        )
-        model.add_adapter(denoiser_lora_config)
-        metadata = model.peft_config["default"].to_dict()
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            parsed_metadata = model.peft_config["default_0"].to_dict()
-            check_if_dicts_are_equal(metadata, parsed_metadata)
-
-    @torch.no_grad()
-    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_adapter_wrong_metadata_raises_error(self):
-        from peft import LoraConfig
-
-        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
-        from diffusers.loaders.peft import PeftAdapterMixin
-
-        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict).to(torch_device)
-
-        if not issubclass(model.__class__, PeftAdapterMixin):
-            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
-
-        denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
-            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-            init_lora_weights=False,
-            use_dora=False,
-        )
-        model.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
-
-            # Perturb the metadata in the state dict.
-            loaded_state_dict = safetensors.torch.load_file(model_file)
-            metadata = {"format": "pt"}
-            lora_adapter_metadata = denoiser_lora_config.to_dict()
-            lora_adapter_metadata.update({"foo": 1, "bar": 2})
-            for key, value in lora_adapter_metadata.items():
-                if isinstance(value, set):
-                    lora_adapter_metadata[key] = list(value)
-            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
-            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            with self.assertRaises(TypeError) as err_context:
-                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
-
     @require_torch_accelerator
     def test_cpu_offload(self):
         if self.model_class._no_split_modules is None:
@@ -1941,6 +1775,154 @@ def test_passing_dict_device_map_works(self, name, device):
         _ = loaded_model(**inputs_dict)


+class PEFTTesterMixin:
+    @require_peft_backend
+    @pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
+    @torch.no_grad()
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+        assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            state_dict_loaded = safetensors.torch.load_file(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k, loaded_v in state_dict_loaded.items():
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                assert torch.allclose(loaded_v, retrieved_v)
+
+            assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+        assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+        assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+
+    @require_peft_backend
+    def test_lora_wrong_adapter_name_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with pytest.raises(ValueError, match=rf"Adapter name {wrong_name} not found in the model\."):
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+    @require_peft_backend
+    @pytest.mark.parametrize("rank,lora_alpha,use_dora", [(4, 4, True), (4, 8, False), (8, 4, False)])
+    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        metadata = model.peft_config["default"].to_dict()
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            check_if_dicts_are_equal(metadata, parsed_metadata)
+
+    @require_peft_backend
+    def test_lora_adapter_wrong_metadata_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            # Perturb the metadata
+            loaded_state_dict = safetensors.torch.load_file(model_file)
+            metadata = {"format": "pt"}
+            lora_adapter_metadata = denoiser_lora_config.to_dict()
+            lora_adapter_metadata.update({"foo": 1, "bar": 2})
+            for key, value in list(lora_adapter_metadata.items()):
+                if isinstance(value, set):
+                    lora_adapter_metadata[key] = list(value)
+            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            with pytest.raises(TypeError, match=r"`LoraConfig` class could not be instantiated"):
+                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+
+
 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
     identifier = uuid.uuid4()
4 changes: 2 additions & 2 deletions tests/models/transformers/test_models_prior.py
@@ -30,13 +30,13 @@
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin


 enable_full_determinism()


-class PriorTransformerTests(ModelTesterMixin, unittest.TestCase):
+class PriorTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = PriorTransformer
     main_input_name = "hidden_states"

@@ -20,13 +20,13 @@
 from diffusers import AuraFlowTransformer2DModel

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin


 enable_full_determinism()


-class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
+class AuraFlowTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
     model_class = AuraFlowTransformer2DModel
     main_input_name = "hidden_states"
     # We override the items here because the transformer under consideration is small.
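The shared tests above follow a common pytest mixin pattern: the parametrized checks and bare assert statements live in a plain mixin class, and each concrete test class only supplies the model under test. A minimal, self-contained sketch of that pattern follows; it is illustrative only, and AdapterTesterMixin, DummyAdapterModel, and the adapter API used here are invented for the sketch, not diffusers or PEFT code.

# Illustrative sketch of the mixin pattern used by PEFTTesterMixin above.
# Everything here is hypothetical; only the structure mirrors the PR.
import pytest


class AdapterTesterMixin:
    # Shared checks: concrete test classes only need to implement make_model().
    @pytest.mark.parametrize("rank,alpha", [(4, 4), (8, 4)])
    def test_adapter_round_trip(self, rank, alpha):
        model = self.make_model()
        model.add_adapter(rank=rank, alpha=alpha)

        saved = model.save_adapter()
        model.remove_adapter()
        assert model.adapter_config is None

        model.load_adapter(saved)
        assert model.adapter_config == {"rank": rank, "alpha": alpha}


class DummyAdapterModel:
    # Stand-in for a model exposing add/save/remove/load adapter hooks.
    def __init__(self):
        self.adapter_config = None

    def add_adapter(self, rank, alpha):
        self.adapter_config = {"rank": rank, "alpha": alpha}

    def save_adapter(self):
        return dict(self.adapter_config)

    def remove_adapter(self):
        self.adapter_config = None

    def load_adapter(self, saved):
        self.adapter_config = dict(saved)


class TestDummyAdapterModel(AdapterTesterMixin):
    # Opting in is just a matter of mixing the tester in and providing the model.
    def make_model(self):
        return DummyAdapterModel()

Because the mixin's name does not start with "Test", pytest collects only the concrete classes that mix it in, which is the same reason PEFTTesterMixin can hold the shared LoRA checks without being run on its own.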