Skip to content

Commit 895795d

Browse files
committed
Fix push_to_hub for dreambooth and textual_inversion
1 parent 367a671 commit 895795d

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from accelerate.logging import get_logger
1414
from accelerate.utils import set_seed
1515
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
16+
from diffusers.hub_utils import push_to_hub
1617
from diffusers.optimization import get_scheduler
1718
from huggingface_hub import HfFolder, Repository, whoami
1819
from PIL import Image
@@ -575,9 +576,7 @@ def collate_fn(examples):
575576
pipeline.save_pretrained(args.output_dir)
576577

577578
if args.push_to_hub:
578-
repo.push_to_hub(
579-
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
580-
)
579+
push_to_hub(args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True)
581580

582581
accelerator.end_training()
583582

examples/textual_inversion/textual_inversion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from accelerate.logging import get_logger
1818
from accelerate.utils import set_seed
1919
from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
20+
from diffusers.hub_utils import push_to_hub
2021
from diffusers.optimization import get_scheduler
2122
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
2223
from huggingface_hub import HfFolder, Repository, whoami
@@ -569,9 +570,7 @@ def main():
569570
save_progress(text_encoder, placeholder_token_id, accelerator, args)
570571

571572
if args.push_to_hub:
572-
repo.push_to_hub(
573-
args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True
574-
)
573+
push_to_hub(args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True)
575574

576575
accelerator.end_training()
577576

0 commit comments

Comments
 (0)