
Commit ee56559

[SD][web] Add a json file for model configuration
This cleans up the model_wrappers.py file.

Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 00e594d commit ee56559

4 files changed: +90 -162 lines
web/models/stable_diffusion/model_wrappers.py

Lines changed: 58 additions & 161 deletions
@@ -1,14 +1,10 @@
 from diffusers import AutoencoderKL, UNet2DConditionModel
 from transformers import CLIPTextModel
 from models.stable_diffusion.utils import compile_through_fx
+from models.stable_diffusion.resources import models_config
 from models.stable_diffusion.stable_args import args
 import torch
 
-model_config = {
-    "v2_1": "stabilityai/stable-diffusion-2-1",
-    "v2_1base": "stabilityai/stable-diffusion-2-1-base",
-    "v1_4": "CompVis/stable-diffusion-v1-4",
-}
 
 # clip has 2 variants of max length 77 or 64.
 model_clip_max_length = 64 if args.max_length == 64 else 77
@@ -17,14 +13,6 @@
 elif args.variant == "openjourney":
     model_clip_max_length = 64
 
-model_variant = {
-    "stablediffusion": "SD",
-    "anythingv3": "Linaqruf/anything-v3.0",
-    "dreamlike": "dreamlike-art/dreamlike-diffusion-1.0",
-    "openjourney": "prompthero/openjourney",
-    "analogdiffusion": "wavymulder/Analog-Diffusion",
-}
-
 model_input = {
     "v2_1": {
         "clip": (torch.randint(1, 2, (2, model_clip_max_length)),),
@@ -58,122 +46,99 @@
     },
 }
 
-# revision param for from_pretrained defaults to "main" => fp32
-model_revision = {
-    "stablediffusion": "fp16" if args.precision == "fp16" else "main",
-    "anythingv3": "diffusers",
-    "analogdiffusion": "main",
-    "openjourney": "main",
-    "dreamlike": "main",
-}
+version = args.version if args.variant == "stablediffusion" else "v1_4"
 
 
-def get_clip_mlir(model_name="clip_text", extra_args=[]):
+def get_configs():
+    model_id_key = f"{args.variant}/{version}"
+    revision_key = f"{args.variant}/{args.precision}"
+    try:
+        model_id = models_config[0][model_id_key]
+        revision = models_config[1][revision_key]
+    except KeyError:
+        raise Exception(
+            f"No entry for {model_id_key} or {revision_key} in the models configuration"
+        )
 
-    text_encoder = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14"
-    )
-    if args.variant == "stablediffusion":
-        if args.version != "v1_4":
-            text_encoder = CLIPTextModel.from_pretrained(
-                model_config[args.version], subfolder="text_encoder"
-            )
+    return model_id, revision
 
-    elif args.variant in [
-        "anythingv3",
-        "analogdiffusion",
-        "openjourney",
-        "dreamlike",
-    ]:
-        text_encoder = CLIPTextModel.from_pretrained(
-            model_variant[args.variant],
-            subfolder="text_encoder",
-            revision=model_revision[args.variant],
-        )
-    else:
-        raise ValueError(f"{args.variant} not yet added")
+
+def get_clip_mlir(model_name="clip_text", extra_args=[]):
+    model_id, revision = get_configs()
 
     class CLIPText(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.text_encoder = text_encoder
+            self.text_encoder = CLIPTextModel.from_pretrained(
+                model_id,
+                subfolder="text_encoder",
+                revision=revision,
+            )
 
         def forward(self, input):
             return self.text_encoder(input)[0]
 
     clip_model = CLIPText()
     shark_clip = compile_through_fx(
         clip_model,
-        model_input[args.version]["clip"],
+        model_input[version]["clip"],
         model_name=model_name,
         extra_args=extra_args,
     )
     return shark_clip
 
 
+def get_shark_module(model_key, module, model_name, extra_args):
+    if args.precision == "fp16":
+        module = module.half().cuda()
+        inputs = tuple(
+            [
+                inputs.half().cuda() if len(inputs.shape) != 0 else inputs
+                for inputs in model_input[version][model_key]
+            ]
+        )
+    else:
+        inputs = model_input[version][model_key]
+
+    shark_module = compile_through_fx(
+        module,
+        inputs,
+        model_name=model_name,
+        extra_args=extra_args,
+    )
+    return shark_module
+
+
 def get_base_vae_mlir(model_name="vae", extra_args=[]):
+    model_id, revision = get_configs()
+
     class BaseVaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.vae = AutoencoderKL.from_pretrained(
-                model_config[args.version]
-                if args.variant == "stablediffusion"
-                else model_variant[args.variant],
+                model_id,
                 subfolder="vae",
-                revision=model_revision[args.variant],
+                revision=revision,
             )
 
         def forward(self, input):
             x = self.vae.decode(input, return_dict=False)[0]
             return (x / 2 + 0.5).clamp(0, 1)
 
     vae = BaseVaeModel()
-    if args.variant == "stablediffusion":
-        if args.precision == "fp16":
-            vae = vae.half().cuda()
-            inputs = tuple(
-                [
-                    inputs.half().cuda()
-                    for inputs in model_input[args.version]["vae"]
-                ]
-            )
-        else:
-            inputs = model_input[args.version]["vae"]
-    elif args.variant in [
-        "anythingv3",
-        "analogdiffusion",
-        "openjourney",
-        "dreamlike",
-    ]:
-        if args.precision == "fp16":
-            vae = vae.half().cuda()
-            inputs = tuple(
-                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
-            )
-        else:
-            inputs = model_input["v1_4"]["vae"]
-    else:
-        raise ValueError(f"{args.variant} not yet added")
-
-    shark_vae = compile_through_fx(
-        vae,
-        inputs,
-        model_name=model_name,
-        extra_args=extra_args,
-    )
-    return shark_vae
+    return get_shark_module("vae", vae, model_name, extra_args)
 
 
 def get_vae_mlir(model_name="vae", extra_args=[]):
+    model_id, revision = get_configs()
+
     class VaeModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.vae = AutoencoderKL.from_pretrained(
-                model_config[args.version]
-                if args.variant == "stablediffusion"
-                else model_variant[args.variant],
+                model_id,
                 subfolder="vae",
-                revision=model_revision[args.variant],
+                revision=revision,
             )
 
         def forward(self, input):
@@ -184,52 +149,19 @@ def forward(self, input):
         return x.round()
 
     vae = VaeModel()
-    if args.variant == "stablediffusion":
-        if args.precision == "fp16":
-            vae = vae.half().cuda()
-            inputs = tuple(
-                [
-                    inputs.half().cuda()
-                    for inputs in model_input[args.version]["vae"]
-                ]
-            )
-        else:
-            inputs = model_input[args.version]["vae"]
-    elif args.variant in [
-        "anythingv3",
-        "analogdiffusion",
-        "openjourney",
-        "dreamlike",
-    ]:
-        if args.precision == "fp16":
-            vae = vae.half().cuda()
-            inputs = tuple(
-                [inputs.half().cuda() for inputs in model_input["v1_4"]["vae"]]
-            )
-        else:
-            inputs = model_input["v1_4"]["vae"]
-    else:
-        raise ValueError(f"{args.variant} not yet added")
-
-    shark_vae = compile_through_fx(
-        vae,
-        inputs,
-        model_name=model_name,
-        extra_args=extra_args,
-    )
-    return shark_vae
+    return get_shark_module("vae", vae, model_name, extra_args)
 
 
 def get_unet_mlir(model_name="unet", extra_args=[]):
+    model_id, revision = get_configs()
+
     class UnetModel(torch.nn.Module):
         def __init__(self):
             super().__init__()
             self.unet = UNet2DConditionModel.from_pretrained(
-                model_config[args.version]
-                if args.variant == "stablediffusion"
-                else model_variant[args.variant],
+                model_id,
                 subfolder="unet",
-                revision=model_revision[args.variant],
+                revision=revision,
            )
             self.in_channels = self.unet.in_channels
             self.train(False)
@@ -247,39 +179,4 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
         return noise_pred
 
     unet = UnetModel()
-    if args.variant == "stablediffusion":
-        if args.precision == "fp16":
-            unet = unet.half().cuda()
-            inputs = tuple(
-                [
-                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
-                    for inputs in model_input[args.version]["unet"]
-                ]
-            )
-        else:
-            inputs = model_input[args.version]["unet"]
-    elif args.variant in [
-        "anythingv3",
-        "analogdiffusion",
-        "openjourney",
-        "dreamlike",
-    ]:
-        if args.precision == "fp16":
-            unet = unet.half().cuda()
-            inputs = tuple(
-                [
-                    inputs.half().cuda() if len(inputs.shape) != 0 else inputs
-                    for inputs in model_input["v1_4"]["unet"]
-                ]
-            )
-        else:
-            inputs = model_input["v1_4"]["unet"]
-    else:
-        raise ValueError(f"{args.variant} is not yet added")
-    shark_unet = compile_through_fx(
-        unet,
-        inputs,
-        model_name=model_name,
-        extra_args=extra_args,
-    )
-    return shark_unet
+    return get_shark_module("unet", unet, model_name, extra_args)
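
Taken together, the change replaces three copies of the per-variant if/elif ladder with a single table lookup plus one shared get_shark_module helper. Below is a minimal standalone sketch of the lookup scheme; the models_config literal inlines a subset of the new model_config.json, and the explicit function parameters stand in for the args globals and module-level version used by the real code:

models_config = [
    {  # "variant/version" -> Hugging Face model id
        "anythingv3/v1_4": "Linaqruf/anything-v3.0",
        "stablediffusion/v2_1": "stabilityai/stable-diffusion-2-1",
    },
    {  # "variant/precision" -> from_pretrained revision (a git branch)
        "anythingv3/fp16": "diffusers",
        "stablediffusion/fp16": "fp16",
    },
]


def get_configs(variant, version, precision):
    # Every variant other than "stablediffusion" is pinned to the "v1_4"
    # model-id key, mirroring the module-level `version =` line above.
    version = version if variant == "stablediffusion" else "v1_4"
    model_id_key = f"{variant}/{version}"
    revision_key = f"{variant}/{precision}"
    try:
        return models_config[0][model_id_key], models_config[1][revision_key]
    except KeyError:
        raise Exception(
            f"No entry for {model_id_key} or {revision_key} in the models configuration"
        )


print(get_configs("anythingv3", "v2_1", "fp16"))
# -> ('Linaqruf/anything-v3.0', 'diffusers')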

web/models/stable_diffusion/opt_params.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def get_params(bucket_key, model_key, model, is_tuned, precision):
         ]
     except KeyError:
         raise Exception(
-            f"{bucket}/{model_key} is not present in the models database"
+            f" there is no entry for {model_key} in the models database"
         )
 
     if (

web/models/stable_diffusion/resources.py

Lines changed: 10 additions & 0 deletions
@@ -29,3 +29,13 @@ def resource_path(relative_path):
 
 if len(models_db) != 3:
     sys.exit("Error: Unable to load models database.")
+
+
+models_config = []
+modelconfig_loc = resource_path("resources/model_config.json")
+if os.path.exists(modelconfig_loc):
+    with open(modelconfig_loc, encoding="utf-8") as fopen:
+        models_config = json.load(fopen)
+
+if len(models_config) != 2:
+    sys.exit("Error: Unable to load models configuration.")
resources/model_config.json (new file)

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+[
+    {
+        "stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4",
+        "stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base",
+        "stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1",
+        "anythingv3/v1_4":"Linaqruf/anything-v3.0",
+        "analogdiffusion/v1_4":"wavymulder/Analog-Diffusion",
+        "openjourney/v1_4":"prompthero/openjourney",
+        "dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0"
+    },
+    {
+        "stablediffusion/fp16":"fp16",
+        "stablediffusion/fp32":"main",
+        "anythingv3/fp16":"diffusers",
+        "anythingv3/fp32":"diffusers",
+        "analogdiffusion/fp16":"main",
+        "analogdiffusion/fp32":"main",
+        "openjourney/fp16":"main",
+        "openjourney/fp32":"main"
+    }
+]
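
To illustrate how one pair of entries is consumed downstream (a sketch using the diffusers from_pretrained API shown in the model_wrappers.py diff; fetching the weights requires network access to the Hugging Face Hub):

from diffusers import AutoencoderKL

# model id from models_config[0]["stablediffusion/v2_1"],
# revision from models_config[1]["stablediffusion/fp16"]
vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    subfolder="vae",
    revision="fp16",
)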
