Skip to content

Commit d3d22ce

Browse files
Small modification to enable usage by external scripts (#956)
* Make training code usable by external scripts Add parameter inputs to training and argument parsing function to allow this script to be used by an external call. * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]>
1 parent 8332c1a commit d3d22ce

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
logger = get_logger(__name__)
2727

2828

29-
def parse_args():
29+
def parse_args(input_args):
3030
parser = argparse.ArgumentParser(description="Simple example of a training script.")
3131
parser.add_argument(
3232
"--pretrained_model_name_or_path",
@@ -196,7 +196,11 @@ def parse_args():
196196
)
197197
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
198198

199-
args = parser.parse_args()
199+
if input_args is not None:
200+
args = parser.parse_args(input_args)
201+
else:
202+
args = parser.parse_args()
203+
200204
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
201205
if env_local_rank != -1 and env_local_rank != args.local_rank:
202206
args.local_rank = env_local_rank
@@ -319,8 +323,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
319323
return f"{organization}/{model_id}"
320324

321325

322-
def main():
323-
args = parse_args()
326+
def main(args):
324327
logging_dir = Path(args.output_dir, args.logging_dir)
325328

326329
accelerator = Accelerator(
@@ -653,4 +656,5 @@ def collate_fn(examples):
653656

654657

655658
if __name__ == "__main__":
656-
main()
659+
args = parse_args()
660+
main(args)

0 commit comments

Comments
 (0)