Skip to content

Commit e67ea31

Browse files
committed
[SHARK][SD] Add --local_tank_cache flag in the stable diffusion
This flag can be used to set local shark_tank cache directory. Signed-Off-by: Gaurav Shukla <[email protected]>
1 parent 986c126 commit e67ea31

File tree

4 files changed

+20
-0
lines changed

4 files changed

+20
-0
lines changed

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@
103103
help="Download and use the tuned version of the model if available",
104104
)
105105

106+
p.add_argument(
107+
"--local_tank_cache",
108+
default="",
109+
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
110+
)
111+
106112
p.add_argument(
107113
"--dump_isa",
108114
default=False,

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def _compile_module(shark_module, model_name, extra_args=[]):
4040
# Downloads the model from shark_tank and returns the shark_module.
4141
def get_shark_model(tank_url, model_name, extra_args=[]):
4242
from shark.shark_downloader import download_model
43+
from shark.parser import shark_args
44+
45+
# Set local shark_tank cache directory.
46+
shark_args.local_tank_cache = args.local_tank_cache
4347

4448
mlir_model, func_name, inputs, golden_out = download_model(
4549
model_name,

web/models/stable_diffusion/stable_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@
110110
help="Download and use the tuned version of the model if available",
111111
)
112112

113+
p.add_argument(
114+
"--local_tank_cache",
115+
default="",
116+
help="Specify where to save downloaded shark_tank artifacts. If this is not set, the default is ~/.local/shark_tank/.",
117+
)
118+
113119
p.add_argument(
114120
"--vulkan_large_heap_block_size",
115121
default="4294967296",

web/models/stable_diffusion/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def _compile_module(args, shark_module, model_name, extra_args=[]):
3838
# Downloads the model from shark_tank and returns the shark_module.
3939
def get_shark_model(args, tank_url, model_name, extra_args=[]):
4040
from shark.shark_downloader import download_model
41+
from shark.parser import shark_args
42+
43+
# Set local shark_tank cache directory.
44+
shark_args.local_tank_cache = args.local_tank_cache
4145

4246
mlir_model, func_name, inputs, golden_out = download_model(
4347
model_name, tank_url=tank_url, frontend="torch"

0 commit comments

Comments
 (0)