Skip to content

Commit 8856878

Browse files
authored
Add flag for enabling rgp from the main.py SD script (huggingface#533)
1 parent a9bac02 commit 8856878

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

shark/examples/shark_inference/stable_diffusion/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ python main.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=
2323
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
2424
```
2525

26+
## Capture an RGP trace
27+
28+
```shell
29+
python shark/examples/shark_inference/stable_diffusion/main.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
30+
```
31+
2632
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
2733

2834
```shell

shark/examples/shark_inference/stable_diffusion/stable_args.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,11 @@
118118
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
119119
)
120120

121+
p.add_argument(
122+
"--enable_rgp",
123+
default=False,
124+
action=argparse.BooleanOptionalAction,
125+
help="flag for inserting debug frames between iterations for use with rgp.",
126+
)
127+
121128
args = p.parse_args()

shark/examples/shark_inference/stable_diffusion/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ def set_iree_runtime_flags():
7070
vulkan_runtime_flags = [
7171
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
7272
]
73+
if args.enable_rgp:
74+
vulkan_runtime_flags += [
75+
f"--enable_rgp=true",
76+
f"--vulkan_debug_utils=true",
77+
]
7378
if "vulkan" in args.device:
7479
set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)
7580

0 commit comments

Comments
 (0)