Skip to content

Commit f4821d0

Browse files
committed
[WEB] Update seed calculation and model versions.
Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent fdf2aa5 commit f4821d0

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

web/index.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from PIL import Image
88
import json
99
import os
10+
from random import randint
11+
from numpy import iinfo
12+
import numpy as np
1013

1114
"""
1215
def debug_event(debug):
@@ -147,7 +150,12 @@ def debug_event(debug):
147150
value="fp16",
148151
choices=["fp16", "fp32"],
149152
)
150-
seed = gr.Textbox(value="42", max_lines=1, label="Seed")
153+
uint32_info = iinfo(np.uint32)
154+
rand_seed = randint(uint32_info.min, uint32_info.max)
155+
seed = gr.Number(
156+
value=rand_seed,
157+
label="Seed",
158+
)
151159
with gr.Row():
152160
cache = gr.Checkbox(label="Cache", value=True)
153161
debug = gr.Checkbox(

web/models/stable_diffusion/main.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from diffusers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
88
from tqdm.auto import tqdm
99
import numpy as np
10+
from numpy import iinfo
11+
from random import randint
1012
from models.stable_diffusion.opt_params import get_unet, get_vae, get_clip
1113
from models.stable_diffusion.arguments import args, schedulers, cache_obj
1214

@@ -23,7 +25,7 @@ def stable_diff_inf(
2325
guidance: float,
2426
height: int,
2527
width: int,
26-
seed: str,
28+
seed: int,
2729
precision: str,
2830
device: str,
2931
cache: bool,
@@ -35,12 +37,9 @@ def stable_diff_inf(
3537

3638
start = time.time()
3739
# set seed value
38-
try:
39-
seed = int(seed)
40-
if seed < 0 or seed > 10000:
41-
seed = int(torch.randint(low=25, high=100, size=()))
42-
except (ValueError, OverflowError) as error:
43-
seed = hash(seed)
40+
uint32_info = iinfo(np.uint32)
41+
if seed < uint32_info.min and seed >= uint32_info.max:
42+
seed = randint(uint32_info.min, uint32_info.max)
4443

4544
args.set_params(
4645
prompt,

web/models/stable_diffusion/opt_params.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_unet(args):
3131
return get_shark_model(args, bucket, model_name, iree_flags)
3232
else:
3333
bucket = "gs://shark_tank/prashant_nod"
34-
model_name = "unet_18nov_fp16"
34+
model_name = "unet_22nov_fp16"
3535
iree_flags += [
3636
"--iree-flow-enable-padding-linalg-ops",
3737
"--iree-flow-linalg-ops-padding-size=32",
@@ -44,7 +44,7 @@ def get_unet(args):
4444
# Tuned model is not present for `fp32` case.
4545
if args.precision == "fp32":
4646
bucket = "gs://shark_tank/prashant_nod"
47-
model_name = "unet_18nov_fp32"
47+
model_name = "unet_22nov_fp32"
4848
iree_flags += [
4949
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
5050
"--iree-flow-enable-padding-linalg-ops",
@@ -77,7 +77,7 @@ def get_vae(args):
7777
)
7878
if args.precision in ["fp16", "int8"]:
7979
bucket = "gs://shark_tank/prashant_nod"
80-
model_name = "vae_18nov_fp16"
80+
model_name = "vae_22nov_fp16"
8181
iree_flags += [
8282
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
8383
"--iree-flow-enable-padding-linalg-ops",
@@ -89,7 +89,7 @@ def get_vae(args):
8989

9090
if args.precision == "fp32":
9191
bucket = "gs://shark_tank/prashant_nod"
92-
model_name = "vae_18nov_fp32"
92+
model_name = "vae_22nov_fp32"
9393
iree_flags += [
9494
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
9595
"--iree-flow-enable-padding-linalg-ops",

0 commit comments

Comments
 (0)