Skip to content

Commit 9bd951b

Browse files
authored
Clean up the v-diffusion install pipeline (huggingface#327)
1 parent c43448a commit 9bd951b

File tree

5 files changed

+39
-4
lines changed

5 files changed

+39
-4
lines changed

tank/pytorch/v_diffusion_pytorch/README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,25 @@ Run the script setup_v_diffusion_pytorch.sh
2727
./v-diffusion-pytorch/cfg_sample.py "New York City, oil on canvas":5 -n 5 -bs 5
2828
```
2929

30+
The runtime device can be specified with `--runtime_device=<device string>`
31+
3032
### Run the v-diffusion model via torch-mlir
3133
```shell
3234
./cfg_sample.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
3335
```
36+
37+
### Run the model stored in the tank
38+
```shell
39+
./cfg_sample_from_mlir.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
40+
```
41+
Note that the current model in the tank requires batch size 1 statically.
42+
43+
### Run the model with preprocessing elements taken out
44+
To run the model without preprocessing copy `cc12m_1.py` to replace the version in `v-diffusion-pytorch`
45+
```shell
46+
cp cc12m_1.py v-diffusion-pytorch/diffusion/models
47+
```
48+
Then run
49+
```shell
50+
./cfg_sample_preprocess.py "New York City, oil on canvas":5 -n 1 -bs 1 --steps 2
51+
```

tank/pytorch/v_diffusion_pytorch/cfg_sample.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ def resize_and_center_crop(image, size):
6767
)
6868
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
6969
p.add_argument("--device", type=str, help="the device to use")
70+
p.add_argument(
71+
"--runtime_device",
72+
type=str,
73+
help="the device to use with SHARK",
74+
default="cpu",
75+
)
7076
p.add_argument(
7177
"--eta",
7278
type=float,
@@ -235,7 +241,7 @@ def strip_overloads(gm):
235241
func_name = "forward"
236242

237243
shark_module = SharkInference(
238-
mlir_model, func_name, device="cuda", mlir_dialect="linalg"
244+
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
239245
)
240246
shark_module.compile()
241247

tank/pytorch/v_diffusion_pytorch/cfg_sample_from_mlir.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def resize_and_center_crop(image, size):
6969
)
7070
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
7171
p.add_argument("--device", type=str, help="the device to use")
72+
p.add_argument(
73+
"--runtime_device",
74+
type=str,
75+
help="the device to use with SHARK",
76+
default="cpu",
77+
)
7278
p.add_argument(
7379
"--eta",
7480
type=float,
@@ -188,7 +194,7 @@ def cfg_model_fn(x, t):
188194
mlir_model, func_name, inputs, golden_out = download_torch_model("v_diffusion")
189195

190196
shark_module = SharkInference(
191-
mlir_model, func_name, device="cpu", mlir_dialect="linalg"
197+
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
192198
)
193199
shark_module.compile()
194200

tank/pytorch/v_diffusion_pytorch/cfg_sample_preprocess.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def resize_and_center_crop(image, size):
6969
)
7070
p.add_argument("--checkpoint", type=str, help="the checkpoint to use")
7171
p.add_argument("--device", type=str, help="the device to use")
72+
p.add_argument(
73+
"--runtime_device",
74+
type=str,
75+
help="the device to use with SHARK",
76+
default="intel-gpu",
77+
)
7278
p.add_argument(
7379
"--eta",
7480
type=float,
@@ -260,7 +266,7 @@ def strip_overloads(gm):
260266
func_name = "forward"
261267

262268
shark_module = SharkInference(
263-
mlir_model, func_name, device="intel-gpu", mlir_dialect="linalg"
269+
mlir_model, func_name, device=args.runtime_device, mlir_dialect="linalg"
264270
)
265271
shark_module.compile()
266272

tank/pytorch/v_diffusion_pytorch/setup_v_diffusion_pytorch.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,3 @@ mkdir checkpoints
2424
wget https://the-eye.eu/public/AI/models/v-diffusion/cc12m_1_cfg.pth -P checkpoints/
2525

2626
cp -r checkpoints/ v-diffusion-pytorch/
27-
cp cc12m_1.py v-diffusion-pytorch/diffusion/models/.

0 commit comments

Comments
 (0)