Skip to content

Commit 986c126

Browse files
author
Gaurav Shukla
committed
[SHARK][SD] Add support for negative prompts
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 0eee761 commit 986c126

File tree

6 files changed

+52
-13
lines changed

6 files changed

+52
-13
lines changed

shark/examples/shark_inference/stable_diffusion/main.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from utils import get_shark_model, set_iree_runtime_flags
1515
from opt_params import get_unet, get_vae, get_clip
1616
import time
17+
import sys
1718
from model_wrappers import get_vae_mlir
1819
from shark.iree_utils.compile_utils import dump_isas
1920

@@ -39,6 +40,7 @@ def end_profiling(device):
3940
dtype = torch.float32 if args.precision == "fp32" else torch.half
4041

4142
prompt = args.prompts
43+
neg_prompt = args.negative_prompts
4244
height = 512 # default height of Stable Diffusion
4345
width = 512 # default width of Stable Diffusion
4446
if args.version == "v2":
@@ -54,7 +56,12 @@ def end_profiling(device):
5456
args.seed
5557
) # Seed generator to create the inital latent noise
5658

59+
# TODO: Add support for batch_size > 1.
5760
batch_size = len(prompt)
61+
if batch_size != 1:
62+
sys.exit("More than one prompt is not supported yet.")
63+
if batch_size != len(neg_prompt):
64+
sys.exit("prompts and negative prompts must be of same length")
5865

5966
set_iree_runtime_flags()
6067
unet = get_unet()
@@ -103,9 +110,10 @@ def end_profiling(device):
103110
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
104111
max_length = text_input.input_ids.shape[-1]
105112
uncond_input = tokenizer(
106-
[""] * batch_size,
113+
neg_prompt,
107114
padding="max_length",
108115
max_length=max_length,
116+
truncation=True,
109117
return_tensors="pt",
110118
)
111119
uncond_clip_inf_start = time.time()

shark/examples/shark_inference/stable_diffusion/model_wrappers.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
1-
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
1+
from diffusers import AutoencoderKL, UNet2DConditionModel
22
from transformers import CLIPTextModel
33
from utils import compile_through_fx
44
from stable_args import args
55
import torch
66

7-
BATCH_SIZE = len(args.prompts)
8-
97
model_config = {
108
"v2": "stabilityai/stable-diffusion-2",
119
"v2.1base": "stabilityai/stable-diffusion-2-1-base",

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,12 +7,21 @@
77
p.add_argument(
88
"--prompts",
99
nargs="+",
10-
default=["a photograph of an astronaut riding a horse"],
10+
default=["cyberpunk forest by Salvador Dali"],
1111
help="text of which images to be generated.",
1212
)
13+
14+
p.add_argument(
15+
"--negative-prompts",
16+
nargs="+",
17+
default=["trees, green"],
18+
help="text you don't want to see in the generated image.",
19+
)
20+
1321
p.add_argument(
1422
"--device", type=str, default="cpu", help="device to run the model."
1523
)
24+
1625
p.add_argument(
1726
"--steps",
1827
type=int,
@@ -33,6 +42,7 @@
3342
default=42,
3443
help="the seed to use.",
3544
)
45+
3646
p.add_argument(
3747
"--guidance_scale",
3848
type=float,

web/index.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,13 @@ def resource_path(relative_path):
7878
with gr.Group(elem_id="prompt_box_outer"):
7979
prompt = gr.Textbox(
8080
label="Prompt",
81-
value="A photograph of an astronaut riding a horse",
81+
value="cyberpunk forest by Salvador Dali",
82+
lines=1,
83+
elem_id="prompt_box",
84+
)
85+
negative_prompt = gr.Textbox(
86+
label="Negative Prompt",
87+
value="trees, green",
8288
lines=1,
8389
elem_id="prompt_box",
8490
)
@@ -148,6 +154,7 @@ def resource_path(relative_path):
148154
stable_diff_inf,
149155
inputs=[
150156
prompt,
157+
negative_prompt,
151158
steps,
152159
guidance,
153160
seed,
@@ -159,6 +166,7 @@ def resource_path(relative_path):
159166
stable_diff_inf,
160167
inputs=[
161168
prompt,
169+
negative_prompt,
162170
steps,
163171
guidance,
164172
seed,

web/models/stable_diffusion/main.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,11 @@
1111
import time
1212

1313

14-
def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
15-
args.prompt = [prompt]
14+
def set_ui_params(
15+
prompt, negative_prompt, steps, guidance, seed, scheduler_key
16+
):
17+
args.prompts = [prompt]
18+
args.negative_prompts = [negative_prompt]
1619
args.steps = steps
1720
args.guidance = guidance
1821
args.seed = seed
@@ -21,6 +24,7 @@ def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
2124

2225
def stable_diff_inf(
2326
prompt: str,
27+
negative_prompt: str,
2428
steps: int,
2529
guidance: float,
2630
seed: int,
@@ -33,7 +37,9 @@ def stable_diff_inf(
3337
if seed < uint32_min or seed >= uint32_max:
3438
seed = randint(uint32_min, uint32_max)
3539

36-
set_ui_params(prompt, steps, guidance, seed, scheduler_key)
40+
set_ui_params(
41+
prompt, negative_prompt, steps, guidance, seed, scheduler_key
42+
)
3743
dtype = torch.float32 if args.precision == "fp32" else torch.half
3844
generator = torch.manual_seed(
3945
args.seed
@@ -50,7 +56,7 @@ def stable_diff_inf(
5056

5157
start = time.time()
5258
text_input = tokenizer(
53-
args.prompt,
59+
args.prompts,
5460
padding="max_length",
5561
max_length=args.max_length,
5662
truncation=True,
@@ -64,9 +70,10 @@ def stable_diff_inf(
6470
max_length = text_input.input_ids.shape[-1]
6571

6672
uncond_input = tokenizer(
67-
[""],
73+
args.negative_prompts,
6874
padding="max_length",
6975
max_length=max_length,
76+
truncation=True,
7077
return_tensors="pt",
7178
)
7279
uncond_clip_inf_start = time.time()
@@ -127,7 +134,8 @@ def stable_diff_inf(
127134
avg_ms = 1000 * avg_ms / args.steps
128135
total_time = time.time() - start
129136

130-
text_output = f"prompt={args.prompt}"
137+
text_output = f"prompt={args.prompts}"
138+
text_output += f"\nnegative prompt={args.negative_prompts}"
131139
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance}, scheduler={args.scheduler}, seed={args.seed}, size={args.height}x{args.width}, version={args.version}"
132140
text_output += "\nAverage step time: {0:.2f}ms/it".format(avg_ms)
133141
print(f"\nAverage step time: {avg_ms}ms/it")

web/models/stable_diffusion/stable_args.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,17 @@
77
p.add_argument(
88
"--prompts",
99
nargs="+",
10-
default=["a photograph of an astronaut riding a horse"],
10+
default=["cyberpunk forest by Salvador Dali"],
1111
help="text of which images to be generated.",
1212
)
1313

14+
p.add_argument(
15+
"--negative-prompts",
16+
nargs="+",
17+
default=["trees, green"],
18+
help="text you don't want to see in the generated image.",
19+
)
20+
1421
p.add_argument(
1522
"--device", type=str, default="vulkan", help="device to run the model."
1623
)

0 commit comments

Comments
 (0)