|
10 | 10 | import argparse |
11 | 11 | from glob import glob |
12 | 12 | import shutil |
| 13 | +import requests |
13 | 14 |
|
14 | 15 | model_config_dicts = get_json_file( |
15 | 16 | os.path.join( |
|
19 | 20 | ) |
20 | 21 |
|
21 | 22 |
|
| 23 | +def get_inpaint_inputs(): |
| 24 | + os.mkdir("./test_images/inputs") |
| 25 | + img_url = ( |
| 26 | + "https://huggingface.co/datasets/diffusers/test-arrays/resolve" |
| 27 | + "/main/stable_diffusion_inpaint/input_bench_image.png" |
| 28 | + ) |
| 29 | + mask_url = ( |
| 30 | + "https://huggingface.co/datasets/diffusers/test-arrays/resolve" |
| 31 | + "/main/stable_diffusion_inpaint/input_bench_mask.png" |
| 32 | + ) |
| 33 | + img = requests.get(img_url) |
| 34 | + mask = requests.get(mask_url) |
| 35 | + open("./test_images/inputs/image.png", "wb").write(img.content) |
| 36 | + open("./test_images/inputs/mask.png", "wb").write(mask.content) |
| 37 | + |
| 38 | + |
22 | 39 | def test_loop(device="vulkan", beta=False, extra_flags=[]): |
23 | 40 | # Get golden values from tank |
24 | 41 | shutil.rmtree("./test_images", ignore_errors=True) |
25 | 42 | os.mkdir("./test_images") |
26 | 43 | os.mkdir("./test_images/golden") |
| 44 | + get_inpaint_inputs() |
27 | 45 | hf_model_names = model_config_dicts[0].values() |
28 | 46 | tuned_options = ["--no-use_tuned", "--use_tuned"] |
29 | 47 | import_options = ["--import_mlir", "--no-import_mlir"] |
30 | 48 | prompt_text = "--prompt=cyberpunk forest by Salvador Dali" |
| 49 | + inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" |
31 | 50 | if os.name == "nt": |
32 | 51 | prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' |
| 52 | + inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' |
33 | 53 | if beta: |
34 | 54 | extra_flags.append("--beta_models=True") |
35 | 55 | for import_opt in import_options: |
36 | 56 | for model_name in hf_model_names: |
37 | 57 | for use_tune in tuned_options: |
38 | | - command = [ |
39 | | - executable, # executable is the python from the venv used to run this |
40 | | - "apps/stable_diffusion/scripts/txt2img.py", |
41 | | - "--device=" + device, |
42 | | - prompt_text, |
43 | | - "--negative_prompts=" + '""', |
44 | | - "--seed=42", |
45 | | - import_opt, |
46 | | - "--output_dir=" |
47 | | - + os.path.join(os.getcwd(), "test_images", model_name), |
48 | | - "--hf_model_id=" + model_name, |
49 | | - use_tune, |
50 | | - ] |
| 58 | + command = ( |
| 59 | + [ |
| 60 | + executable, # executable is the python from the venv used to run this |
| 61 | + "apps/stable_diffusion/scripts/txt2img.py", |
| 62 | + "--device=" + device, |
| 63 | + prompt_text, |
| 64 | + "--negative_prompts=" + '""', |
| 65 | + "--seed=42", |
| 66 | + import_opt, |
| 67 | + "--output_dir=" |
| 68 | + + os.path.join(os.getcwd(), "test_images", model_name), |
| 69 | + "--hf_model_id=" + model_name, |
| 70 | + use_tune, |
| 71 | + ] |
| 72 | + if "inpainting" not in model_name |
| 73 | + else [ |
| 74 | + "python", |
| 75 | + "apps/stable_diffusion/scripts/inpaint.py", |
| 76 | + "--device=" + device, |
| 77 | + inpaint_prompt_text, |
| 78 | + "--negative_prompts=" + '""', |
| 79 | + "--img_path=./test_images/inputs/image.png", |
| 80 | + "--mask_path=./test_images/inputs/mask.png", |
| 81 | + "--seed=42", |
| 82 | + "--import_mlir", |
| 83 | + "--output_dir=" |
| 84 | + + os.path.join(os.getcwd(), "test_images", model_name), |
| 85 | + "--hf_model_id=" + model_name, |
| 86 | + use_tune, |
| 87 | + ] |
| 88 | + ) |
51 | 89 | command += extra_flags |
52 | 90 | if os.name == "nt": |
53 | 91 | command = " ".join(command) |
|
0 commit comments