Commit 0eee761

Author: Gaurav Shukla (committed)
[WEB] Launch only one SD version at a time
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 5ddce74 commit 0eee761

File tree

3 files changed: +26 -73 lines changed

3 files changed

+26
-73
lines changed

web/index.py

Lines changed: 0 additions & 7 deletions
@@ -99,11 +99,6 @@ def resource_path(relative_path):
             step=0.1,
             label="Guidance Scale",
         )
-        version = gr.Radio(
-            label="Version",
-            value="v2.1base",
-            choices=["v1.4", "v2.1base"],
-        )
         with gr.Row():
             scheduler_key = gr.Dropdown(
                 label="Scheduler",
@@ -157,7 +152,6 @@ def resource_path(relative_path):
             guidance,
             seed,
             scheduler_key,
-            version,
         ],
         outputs=[generated_img, std_output],
     )
@@ -169,7 +163,6 @@ def resource_path(relative_path):
             guidance,
             seed,
             scheduler_key,
-            version,
         ],
         outputs=[generated_img, std_output],
     )
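With the version radio removed from the UI, the click handlers pass only the prompt, steps, guidance, seed, and scheduler inputs; the active SD version is fixed for the whole session by the launch flag. A minimal sketch of the resulting Gradio wiring, assuming component names and defaults that are illustrative rather than copied from the file:

# Sketch only: component names and defaults here are assumptions, not the
# exact contents of web/index.py. The point is that no `version` input is
# wired into the click handler any more.
import gradio as gr
from models.stable_diffusion.main import stable_diff_inf

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    steps = gr.Slider(1, 100, value=50, step=1, label="Steps")
    guidance = gr.Slider(0, 20, value=7.5, step=0.1, label="Guidance Scale")
    seed = gr.Number(value=-1, label="Seed")
    scheduler_key = gr.Dropdown(
        label="Scheduler",
        choices=["PNDM", "LMSDiscrete", "DDIM", "DPMSolverMultistep", "EulerDiscrete"],
        value="PNDM",
    )
    generated_img = gr.Image(label="Generated image")
    std_output = gr.Textbox(label="Output log")
    generate = gr.Button("Generate")

    # The SD version is read from args.version inside stable_diff_inf;
    # the UI no longer selects it.
    generate.click(
        stable_diff_inf,
        inputs=[prompt, steps, guidance, seed, scheduler_key],
        outputs=[generated_img, std_output],
    )

demo.launch()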

web/models/stable_diffusion/cache_objects.py

Lines changed: 17 additions & 45 deletions
@@ -11,79 +11,51 @@
 from models.stable_diffusion.stable_args import args
 
 
+model_config = {
+    "v2": "stabilityai/stable-diffusion-2",
+    "v2.1base": "stabilityai/stable-diffusion-2-1-base",
+    "v1.4": "CompVis/stable-diffusion-v1-4",
+}
+
 schedulers = dict()
 schedulers["PNDM"] = PNDMScheduler.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    model_config[args.version],
     subfolder="scheduler",
 )
 schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    model_config[args.version],
     subfolder="scheduler",
 )
 schedulers["DDIM"] = DDIMScheduler.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    model_config[args.version],
     subfolder="scheduler",
 )
 schedulers["DPMSolverMultistep"] = DPMSolverMultistepScheduler.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    model_config[args.version],
     subfolder="scheduler",
 )
 schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
-    "CompVis/stable-diffusion-v1-4",
+    model_config[args.version],
     subfolder="scheduler",
 )
 
-schedulers2 = dict()
-schedulers2["PNDM"] = PNDMScheduler.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base",
-    subfolder="scheduler",
-)
-schedulers2["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base",
-    subfolder="scheduler",
-)
-schedulers2["DDIM"] = DDIMScheduler.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base",
-    subfolder="scheduler",
-)
-schedulers2[
-    "DPMSolverMultistep"
-] = DPMSolverMultistepScheduler.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base",
-    subfolder="scheduler",
-)
-schedulers2["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base",
-    subfolder="scheduler",
-)
 
 # set iree-runtime flags
 set_iree_runtime_flags(args)
-args.version = "v1.4"
 
 cache_obj = dict()
-
-# cache tokenizer
-cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
-    "openai/clip-vit-large-patch14"
-)
-
 # cache vae, unet and clip.
 (
     cache_obj["vae"],
     cache_obj["unet"],
     cache_obj["clip"],
 ) = (get_vae(args), get_unet(args), get_clip(args))
 
-args.version = "v2.1base"
 # cache tokenizer
-cache_obj["tokenizer2"] = CLIPTokenizer.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
+cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
+    "openai/clip-vit-large-patch14"
 )
-
-# cache vae, unet and clip.
-(
-    cache_obj["vae2"],
-    cache_obj["unet2"],
-    cache_obj["clip2"],
-) = (get_vae(args), get_unet(args), get_clip(args))
+if args.version == "v2.1base":
+    cache_obj["tokenizer"] = CLIPTokenizer.from_pretrained(
+        "stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer"
+    )
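The rewritten cache builds only one set of objects, keyed off the launch-time version: `model_config` maps the `--version` flag to a single Hugging Face repo id, every scheduler is loaded from that repo, and the duplicated `schedulers2`/`*2` caches disappear. A standalone sketch of the same pattern, assuming the `diffusers` scheduler API used above (the argument parsing here is illustrative):

# Sketch of the single-version pattern; the argument parsing is illustrative,
# while the diffusers calls mirror the ones used in cache_objects.py.
import argparse
from diffusers import PNDMScheduler, DDIMScheduler

model_config = {
    "v2": "stabilityai/stable-diffusion-2",
    "v2.1base": "stabilityai/stable-diffusion-2-1-base",
    "v1.4": "CompVis/stable-diffusion-v1-4",
}

parser = argparse.ArgumentParser()
parser.add_argument("--version", default="v2.1base", choices=list(model_config))
args = parser.parse_args()

repo = model_config[args.version]

# Only the selected version's schedulers are loaded; launching with a
# different --version builds a different (single) cache.
schedulers = {
    "PNDM": PNDMScheduler.from_pretrained(repo, subfolder="scheduler"),
    "DDIM": DDIMScheduler.from_pretrained(repo, subfolder="scheduler"),
}
print(f"Loaded schedulers for {repo}: {sorted(schedulers)}")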

web/models/stable_diffusion/main.py

Lines changed: 9 additions & 21 deletions
@@ -4,21 +4,19 @@
 from models.stable_diffusion.cache_objects import (
     cache_obj,
     schedulers,
-    schedulers2,
 )
 from models.stable_diffusion.stable_args import args
 from random import randint
 import numpy as np
 import time
 
 
-def set_ui_params(prompt, steps, guidance, seed, scheduler_key, version):
+def set_ui_params(prompt, steps, guidance, seed, scheduler_key):
     args.prompt = [prompt]
     args.steps = steps
     args.guidance = guidance
     args.seed = seed
     args.scheduler = scheduler_key
-    args.version = version
 
 
 def stable_diff_inf(
@@ -27,7 +25,6 @@ def stable_diff_inf(
     guidance: float,
     seed: int,
     scheduler_key: str,
-    version: str,
 ):
 
     # Handle out of range seeds.
@@ -36,29 +33,20 @@ def stable_diff_inf(
     if seed < uint32_min or seed >= uint32_max:
         seed = randint(uint32_min, uint32_max)
 
-    set_ui_params(prompt, steps, guidance, seed, scheduler_key, version)
+    set_ui_params(prompt, steps, guidance, seed, scheduler_key)
     dtype = torch.float32 if args.precision == "fp32" else torch.half
     generator = torch.manual_seed(
         args.seed
     )  # Seed generator to create the inital latent noise
     guidance_scale = torch.tensor(args.guidance).to(torch.float32)
     # Initialize vae and unet models.
-    if args.version == "v2.1base":
-        vae, unet, clip, tokenizer = (
-            cache_obj["vae2"],
-            cache_obj["unet2"],
-            cache_obj["clip2"],
-            cache_obj["tokenizer2"],
-        )
-        scheduler = schedulers2[args.scheduler]
-    else:
-        vae, unet, clip, tokenizer = (
-            cache_obj["vae"],
-            cache_obj["unet"],
-            cache_obj["clip"],
-            cache_obj["tokenizer"],
-        )
-        scheduler = schedulers[args.scheduler]
+    vae, unet, clip, tokenizer = (
+        cache_obj["vae"],
+        cache_obj["unet"],
+        cache_obj["clip"],
+        cache_obj["tokenizer"],
+    )
+    scheduler = schedulers[args.scheduler]
 
     start = time.time()
     text_input = tokenizer(
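With the branch on `args.version` gone, callers use the five-argument signature and the version comes only from the launch flag. A hedged usage sketch, assuming `stable_diff_inf` returns the generated image and log text that the Gradio outputs are wired to:

# Usage sketch: parameter names follow the diff above; the return
# unpacking assumes the (image, log) pair wired to the Gradio outputs.
from models.stable_diffusion.main import stable_diff_inf

image, log = stable_diff_inf(
    prompt="a photograph of an astronaut riding a horse",
    steps=50,
    guidance=7.5,
    seed=42,
    scheduler_key="PNDM",
)
print(log)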
