11from diffusers import AutoencoderKL , UNet2DConditionModel
22from transformers import CLIPTextModel
33from models .stable_diffusion .utils import compile_through_fx
4+ from models .stable_diffusion .resources import models_config
45from models .stable_diffusion .stable_args import args
56import torch
67
7- model_config = {
8- "v2_1" : "stabilityai/stable-diffusion-2-1" ,
9- "v2_1base" : "stabilityai/stable-diffusion-2-1-base" ,
10- "v1_4" : "CompVis/stable-diffusion-v1-4" ,
11- }
128
139# clip has 2 variants of max length 77 or 64.
1410model_clip_max_length = 64 if args .max_length == 64 else 77
1713elif args .variant == "openjourney" :
1814 model_clip_max_length = 64
1915
20- model_variant = {
21- "stablediffusion" : "SD" ,
22- "anythingv3" : "Linaqruf/anything-v3.0" ,
23- "dreamlike" : "dreamlike-art/dreamlike-diffusion-1.0" ,
24- "openjourney" : "prompthero/openjourney" ,
25- "analogdiffusion" : "wavymulder/Analog-Diffusion" ,
26- }
27-
2816model_input = {
2917 "v2_1" : {
3018 "clip" : (torch .randint (1 , 2 , (2 , model_clip_max_length )),),
5846 },
5947}
6048
61- # revision param for from_pretrained defaults to "main" => fp32
62- model_revision = {
63- "stablediffusion" : "fp16" if args .precision == "fp16" else "main" ,
64- "anythingv3" : "diffusers" ,
65- "analogdiffusion" : "main" ,
66- "openjourney" : "main" ,
67- "dreamlike" : "main" ,
68- }
49+ version = args .version if args .variant == "stablediffusion" else "v1_4"
6950
7051
71- def get_clip_mlir (model_name = "clip_text" , extra_args = []):
def get_configs():
    """Resolve the model id and revision for the selected variant.

    Both values are read from ``models_config``: element 0 is indexed with
    the key ``"<variant>/<version>"`` and element 1 with the key
    ``"<variant>/<precision>"``. ``version`` is the module-level value, not
    ``args.version`` directly.

    Returns:
        A ``(model_id, revision)`` tuple.

    Raises:
        Exception: If either key is missing from ``models_config``. The
            underlying ``KeyError`` is attached as the cause.
    """
    model_id_key = f"{args.variant}/{version}"
    revision_key = f"{args.variant}/{args.precision}"
    try:
        model_id = models_config[0][model_id_key]
        revision = models_config[1][revision_key]
    except KeyError as err:
        # Chain the KeyError so the traceback shows which lookup failed.
        raise Exception(
            f"No entry for {model_id_key} or {revision_key} in the models configuration"
        ) from err

    return model_id, revision
8164
82- elif args .variant in [
83- "anythingv3" ,
84- "analogdiffusion" ,
85- "openjourney" ,
86- "dreamlike" ,
87- ]:
88- text_encoder = CLIPTextModel .from_pretrained (
89- model_variant [args .variant ],
90- subfolder = "text_encoder" ,
91- revision = model_revision [args .variant ],
92- )
93- else :
94- raise ValueError (f"{ args .variant } not yet added" )
65+
def get_clip_mlir(model_name="clip_text", extra_args=None):
    """Load the CLIP text encoder and compile it through ``compile_through_fx``.

    The encoder is loaded from the ``text_encoder`` subfolder of the model id
    returned by ``get_configs()``, at the revision returned alongside it.

    Args:
        model_name: Name forwarded to ``compile_through_fx``.
        extra_args: Extra arguments forwarded to ``compile_through_fx``.
            ``None`` (the default) is treated as an empty list.

    Returns:
        Whatever ``compile_through_fx`` returns for the wrapped encoder.
    """
    # Use a fresh list per call instead of a shared mutable default.
    if extra_args is None:
        extra_args = []
    model_id, revision = get_configs()

    class CLIPText(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.text_encoder = CLIPTextModel.from_pretrained(
                model_id,
                subfolder="text_encoder",
                revision=revision,
            )

        def forward(self, input):
            # Only the first element of the encoder output is returned.
            return self.text_encoder(input)[0]

    clip_model = CLIPText()
    shark_clip = compile_through_fx(
        clip_model,
        model_input[version]["clip"],
        model_name=model_name,
        extra_args=extra_args,
    )
    return shark_clip
11289
11390
def get_shark_module(model_key, module, model_name, extra_args):
    """Compile ``module`` with the sample inputs registered under ``model_key``.

    The sample inputs come from ``model_input[version][model_key]``. When
    ``args.precision`` is ``"fp16"``, the module and every input with at
    least one dimension are converted with ``.half().cuda()``; zero-dimensional
    inputs are passed through unchanged.

    Args:
        model_key: Key into ``model_input[version]`` (e.g. ``"vae"``, ``"unet"``).
        module: The ``torch.nn.Module`` to compile.
        model_name: Name forwarded to ``compile_through_fx``.
        extra_args: Extra arguments forwarded to ``compile_through_fx``.

    Returns:
        Whatever ``compile_through_fx`` returns for ``module``.
    """
    sample_inputs = model_input[version][model_key]
    if args.precision == "fp16":
        module = module.half().cuda()
        inputs = tuple(
            tensor.half().cuda() if len(tensor.shape) != 0 else tensor
            for tensor in sample_inputs
        )
    else:
        inputs = sample_inputs

    shark_module = compile_through_fx(
        module,
        inputs,
        model_name=model_name,
        extra_args=extra_args,
    )
    return shark_module
110+
111+
def get_base_vae_mlir(model_name="vae", extra_args=None):
    """Load the VAE and compile its decode path via ``get_shark_module``.

    The VAE is loaded from the ``vae`` subfolder of the model id returned by
    ``get_configs()``, at the revision returned alongside it.

    Args:
        model_name: Name forwarded to ``get_shark_module``.
        extra_args: Extra arguments forwarded to ``get_shark_module``.
            ``None`` (the default) is treated as an empty list.

    Returns:
        Whatever ``get_shark_module`` returns for the wrapped VAE.
    """
    # Use a fresh list per call instead of a shared mutable default.
    if extra_args is None:
        extra_args = []
    model_id, revision = get_configs()

    class BaseVaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_id,
                subfolder="vae",
                revision=revision,
            )

        def forward(self, input):
            x = self.vae.decode(input, return_dict=False)[0]
            # Shift and scale the decoded output, then clamp it to [0, 1].
            return (x / 2 + 0.5).clamp(0, 1)

    vae = BaseVaeModel()
    return get_shark_module("vae", vae, model_name, extra_args)
165130
166131
167132def get_vae_mlir (model_name = "vae" , extra_args = []):
133+ model_id , revision = get_configs ()
134+
168135 class VaeModel (torch .nn .Module ):
169136 def __init__ (self ):
170137 super ().__init__ ()
171138 self .vae = AutoencoderKL .from_pretrained (
172- model_config [args .version ]
173- if args .variant == "stablediffusion"
174- else model_variant [args .variant ],
139+ model_id ,
175140 subfolder = "vae" ,
176- revision = model_revision [ args . variant ] ,
141+ revision = revision ,
177142 )
178143
179144 def forward (self , input ):
@@ -184,52 +149,19 @@ def forward(self, input):
184149 return x .round ()
185150
186151 vae = VaeModel ()
187- if args .variant == "stablediffusion" :
188- if args .precision == "fp16" :
189- vae = vae .half ().cuda ()
190- inputs = tuple (
191- [
192- inputs .half ().cuda ()
193- for inputs in model_input [args .version ]["vae" ]
194- ]
195- )
196- else :
197- inputs = model_input [args .version ]["vae" ]
198- elif args .variant in [
199- "anythingv3" ,
200- "analogdiffusion" ,
201- "openjourney" ,
202- "dreamlike" ,
203- ]:
204- if args .precision == "fp16" :
205- vae = vae .half ().cuda ()
206- inputs = tuple (
207- [inputs .half ().cuda () for inputs in model_input ["v1_4" ]["vae" ]]
208- )
209- else :
210- inputs = model_input ["v1_4" ]["vae" ]
211- else :
212- raise ValueError (f"{ args .variant } not yet added" )
213-
214- shark_vae = compile_through_fx (
215- vae ,
216- inputs ,
217- model_name = model_name ,
218- extra_args = extra_args ,
219- )
220- return shark_vae
152+ return get_shark_module ("vae" , vae , model_name , extra_args )
221153
222154
223155def get_unet_mlir (model_name = "unet" , extra_args = []):
156+ model_id , revision = get_configs ()
157+
224158 class UnetModel (torch .nn .Module ):
225159 def __init__ (self ):
226160 super ().__init__ ()
227161 self .unet = UNet2DConditionModel .from_pretrained (
228- model_config [args .version ]
229- if args .variant == "stablediffusion"
230- else model_variant [args .variant ],
162+ model_id ,
231163 subfolder = "unet" ,
232- revision = model_revision [ args . variant ] ,
164+ revision = revision ,
233165 )
234166 self .in_channels = self .unet .in_channels
235167 self .train (False )
@@ -247,39 +179,4 @@ def forward(self, latent, timestep, text_embedding, guidance_scale):
247179 return noise_pred
248180
249181 unet = UnetModel ()
250- if args .variant == "stablediffusion" :
251- if args .precision == "fp16" :
252- unet = unet .half ().cuda ()
253- inputs = tuple (
254- [
255- inputs .half ().cuda () if len (inputs .shape ) != 0 else inputs
256- for inputs in model_input [args .version ]["unet" ]
257- ]
258- )
259- else :
260- inputs = model_input [args .version ]["unet" ]
261- elif args .variant in [
262- "anythingv3" ,
263- "analogdiffusion" ,
264- "openjourney" ,
265- "dreamlike" ,
266- ]:
267- if args .precision == "fp16" :
268- unet = unet .half ().cuda ()
269- inputs = tuple (
270- [
271- inputs .half ().cuda () if len (inputs .shape ) != 0 else inputs
272- for inputs in model_input ["v1_4" ]["unet" ]
273- ]
274- )
275- else :
276- inputs = model_input ["v1_4" ]["unet" ]
277- else :
278- raise ValueError (f"{ args .variant } is not yet added" )
279- shark_unet = compile_through_fx (
280- unet ,
281- inputs ,
282- model_name = model_name ,
283- extra_args = extra_args ,
284- )
285- return shark_unet
182+ return get_shark_module ("unet" , unet , model_name , extra_args )
0 commit comments