Skip to content

Commit d4e62ce

Browse files
authored
add an import-mlir fallback in case of failure (huggingface#1030)
may not cover all cases. will observet Co-authored-by: dan <[email protected]>
1 parent 9738483 commit d4e62ce

File tree

3 files changed

+54
-22
lines changed

3 files changed

+54
-22
lines changed

apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,37 @@ def from_pretrained(
219219
)
220220
clip, unet, vae = mlir_import()
221221
return cls(vae, clip, get_tokenizer(), unet, scheduler)
222-
223-
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
222+
try:
223+
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
224+
return cls(
225+
get_vae_encode(),
226+
get_vae(),
227+
get_clip(),
228+
get_tokenizer(),
229+
get_unet(),
230+
scheduler,
231+
)
224232
return cls(
225-
get_vae_encode(),
226-
get_vae(),
227-
get_clip(),
228-
get_tokenizer(),
229-
get_unet(),
230-
scheduler,
233+
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
231234
)
232-
return cls(
233-
get_vae(), get_clip(), get_tokenizer(), get_unet(), scheduler
234-
)
235+
except:
236+
print("download pipeline failed, falling back to import_mlir")
237+
mlir_import = SharkifyStableDiffusionModel(
238+
model_id,
239+
ckpt_loc,
240+
custom_vae,
241+
precision,
242+
max_len=max_length,
243+
batch_size=batch_size,
244+
height=height,
245+
width=width,
246+
use_base_vae=use_base_vae,
247+
use_tuned=use_tuned,
248+
)
249+
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
250+
clip, unet, vae, vae_encode = mlir_import()
251+
return cls(
252+
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
253+
)
254+
clip, unet, vae = mlir_import()
255+
return cls(vae, clip, get_tokenizer(), unet, scheduler)

apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def forward(self, noise_pred, sigma, latent, dt):
8787
if sys.platform == "darwin":
8888
iree_flags.append("-iree-stream-fuse-binding=false")
8989

90-
if args.import_mlir:
90+
def _import(self):
9191
scaling_model = ScalingModel()
9292
self.scaling_model = compile_through_fx(
9393
scaling_model,
@@ -105,15 +105,28 @@ def forward(self, noise_pred, sigma, latent, dt):
105105
+ args.precision,
106106
extra_args=iree_flags,
107107
)
108+
109+
if args.import_mlir:
110+
_import(self)
111+
108112
else:
109-
self.scaling_model = get_shark_model(
110-
SCHEDULER_BUCKET,
111-
"euler_scale_model_input_" + args.precision,
112-
iree_flags,
113-
)
114-
self.step_model = get_shark_model(
115-
SCHEDULER_BUCKET, "euler_step_" + args.precision, iree_flags
116-
)
113+
try:
114+
self.scaling_model = get_shark_model(
115+
SCHEDULER_BUCKET,
116+
"euler_scale_model_input_" + args.precision,
117+
iree_flags,
118+
)
119+
self.step_model = get_shark_model(
120+
SCHEDULER_BUCKET,
121+
"euler_step_" + args.precision,
122+
iree_flags,
123+
)
124+
except:
125+
print(
126+
"failed to download model, falling back and using import_mlir"
127+
)
128+
args.import_mlir = True
129+
_import(self)
117130

118131
def scale_model_input(self, sample, timestep):
119132
step_index = (self.timesteps == timestep).nonzero().item()

apps/stable_diffusion/src/utils/resources/model_db.json

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
"stablediffusion/v1_4/vae/fp16/length_77/tuned":"vae_19dec_fp16_tuned",
2323
"stablediffusion/v1_4/vae/fp16/length_77/tuned/cuda":"vae_19dec_fp16_cuda_tuned",
2424
"stablediffusion/v1_4/vae/fp16/length_77/untuned/base":"vae_8dec_fp16",
25-
"stablediffusion/v1_4/vae/fp32/length_77/untuned":"vae_1dec_fp32",
26-
"stablediffusion/v1_4/clip/fp32/length_77/untuned":"clip_18dec_fp32",
2725
"stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
2826
"stablediffusion/v2_1base/unet/fp16/length_77/tuned":"unet2base_8dec_fp16_tuned_v2",
2927
"stablediffusion/v2_1base/unet/fp16/length_77/tuned/cuda":"unet2base_8dec_fp16_cuda_tuned",

0 commit comments

Comments
 (0)