Skip to content

Commit c5c9399

Browse files
correct paths for tests
1 parent 836f3f3 commit c5c9399

File tree

1 file changed

+20
-28
lines changed

1 file changed

+20
-28
lines changed

tests/test_modeling_utils.py

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -365,9 +365,7 @@ def prepare_init_args_and_inputs_for_common(self):
365365
return init_dict, inputs_dict
366366

367367
def test_from_pretrained_hub(self):
368-
model, loading_info = UNet2DModel.from_pretrained(
369-
"/home/patrick/google_checkpoints/unet-ldm-dummy-update", output_loading_info=True
370-
)
368+
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
371369

372370
self.assertIsNotNone(model)
373371
self.assertEqual(len(loading_info["missing_keys"]), 0)
@@ -378,7 +376,7 @@ def test_from_pretrained_hub(self):
378376
assert image is not None, "Make sure output is not None"
379377

380378
def test_output_pretrained(self):
381-
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/unet-ldm-dummy-update")
379+
model = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update")
382380
model.eval()
383381

384382
torch.manual_seed(0)
@@ -472,9 +470,7 @@ def prepare_init_args_and_inputs_for_common(self):
472470
return init_dict, inputs_dict
473471

474472
def test_from_pretrained_hub(self):
475-
model, loading_info = UNet2DModel.from_pretrained(
476-
"/home/patrick/google_checkpoints/ncsnpp-celebahq-256", output_loading_info=True
477-
)
473+
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
478474
self.assertIsNotNone(model)
479475
self.assertEqual(len(loading_info["missing_keys"]), 0)
480476

@@ -487,7 +483,7 @@ def test_from_pretrained_hub(self):
487483
assert image is not None, "Make sure output is not None"
488484

489485
def test_output_pretrained_ve_mid(self):
490-
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-celebahq-256")
486+
model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
491487
model.to(torch_device)
492488

493489
torch.manual_seed(0)
@@ -512,7 +508,7 @@ def test_output_pretrained_ve_mid(self):
512508
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
513509

514510
def test_output_pretrained_ve_large(self):
515-
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-ffhq-ve-dummy-update")
511+
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
516512
model.to(torch_device)
517513

518514
torch.manual_seed(0)
@@ -582,9 +578,7 @@ def test_training(self):
582578
pass
583579

584580
def test_from_pretrained_hub(self):
585-
model, loading_info = VQModel.from_pretrained(
586-
"/home/patrick/google_checkpoints/vqgan-dummy", output_loading_info=True
587-
)
581+
model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True)
588582
self.assertIsNotNone(model)
589583
self.assertEqual(len(loading_info["missing_keys"]), 0)
590584

@@ -594,7 +588,7 @@ def test_from_pretrained_hub(self):
594588
assert image is not None, "Make sure output is not None"
595589

596590
def test_output_pretrained(self):
597-
model = VQModel.from_pretrained("/home/patrick/google_checkpoints/vqgan-dummy")
591+
model = VQModel.from_pretrained("fusing/vqgan-dummy")
598592
model.eval()
599593

600594
torch.manual_seed(0)
@@ -655,9 +649,7 @@ def test_training(self):
655649
pass
656650

657651
def test_from_pretrained_hub(self):
658-
model, loading_info = AutoencoderKL.from_pretrained(
659-
"/home/patrick/google_checkpoints/autoencoder-kl-dummy", output_loading_info=True
660-
)
652+
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
661653
self.assertIsNotNone(model)
662654
self.assertEqual(len(loading_info["missing_keys"]), 0)
663655

@@ -667,7 +659,7 @@ def test_from_pretrained_hub(self):
667659
assert image is not None, "Make sure output is not None"
668660

669661
def test_output_pretrained(self):
670-
model = AutoencoderKL.from_pretrained("/home/patrick/google_checkpoints/autoencoder-kl-dummy")
662+
model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy")
671663
model.eval()
672664

673665
torch.manual_seed(0)
@@ -715,7 +707,7 @@ def test_from_pretrained_save_pretrained(self):
715707

716708
@slow
717709
def test_from_pretrained_hub(self):
718-
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
710+
model_path = "google/ddpm-cifar10-32"
719711

720712
ddpm = DDPMPipeline.from_pretrained(model_path)
721713
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
@@ -733,7 +725,7 @@ def test_from_pretrained_hub(self):
733725

734726
@slow
735727
def test_output_format(self):
736-
model_path = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
728+
model_path = "google/ddpm-cifar10-32"
737729

738730
pipe = DDIMPipeline.from_pretrained(model_path)
739731

@@ -754,7 +746,7 @@ def test_output_format(self):
754746

755747
@slow
756748
def test_ddpm_cifar10(self):
757-
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
749+
model_id = "google/ddpm-cifar10-32"
758750

759751
unet = UNet2DModel.from_pretrained(model_id)
760752
scheduler = DDPMScheduler.from_config(model_id)
@@ -773,7 +765,7 @@ def test_ddpm_cifar10(self):
773765

774766
@slow
775767
def test_ddim_lsun(self):
776-
model_id = "/home/patrick/google_checkpoints/ddpm-ema-bedroom-256"
768+
model_id = "google/ddpm-ema-bedroom-256"
777769

778770
unet = UNet2DModel.from_pretrained(model_id)
779771
scheduler = DDIMScheduler.from_config(model_id)
@@ -791,7 +783,7 @@ def test_ddim_lsun(self):
791783

792784
@slow
793785
def test_ddim_cifar10(self):
794-
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
786+
model_id = "google/ddpm-cifar10-32"
795787

796788
unet = UNet2DModel.from_pretrained(model_id)
797789
scheduler = DDIMScheduler(tensor_format="pt")
@@ -809,7 +801,7 @@ def test_ddim_cifar10(self):
809801

810802
@slow
811803
def test_pndm_cifar10(self):
812-
model_id = "/home/patrick/google_checkpoints/ddpm-cifar10-32"
804+
model_id = "google/ddpm-cifar10-32"
813805

814806
unet = UNet2DModel.from_pretrained(model_id)
815807
scheduler = PNDMScheduler(tensor_format="pt")
@@ -826,7 +818,7 @@ def test_pndm_cifar10(self):
826818

827819
@slow
828820
def test_ldm_text2img(self):
829-
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
821+
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
830822

831823
prompt = "A painting of a squirrel eating a burger"
832824
generator = torch.manual_seed(0)
@@ -842,7 +834,7 @@ def test_ldm_text2img(self):
842834

843835
@slow
844836
def test_ldm_text2img_fast(self):
845-
ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
837+
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
846838

847839
prompt = "A painting of a squirrel eating a burger"
848840
generator = torch.manual_seed(0)
@@ -856,13 +848,13 @@ def test_ldm_text2img_fast(self):
856848

857849
@slow
858850
def test_score_sde_ve_pipeline(self):
859-
model = UNet2DModel.from_pretrained("/home/patrick/google_checkpoints/ncsnpp-church-256")
851+
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
860852

861853
torch.manual_seed(0)
862854
if torch.cuda.is_available():
863855
torch.cuda.manual_seed_all(0)
864856

865-
scheduler = ScoreSdeVeScheduler.from_config("/home/patrick/google_checkpoints/ncsnpp-church-256")
857+
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
866858

867859
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
868860

@@ -877,7 +869,7 @@ def test_score_sde_ve_pipeline(self):
877869

878870
@slow
879871
def test_ldm_uncond(self):
880-
ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
872+
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
881873

882874
generator = torch.manual_seed(0)
883875
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]

0 commit comments

Comments
 (0)