Skip to content

Commit 83970c2

Browse files
author
Nathan Cassereau
committed
Solve example throwing an error when executed on a GPU
1 parent 818c7ac commit 83970c2

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/backends/plot_sliced_wass_grad_flow_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
loss_iter = []
7575

7676
# generator for random permutations
77-
gen = torch.Generator()
77+
gen = torch.Generator(device=device)
7878
gen.manual_seed(42)
7979

8080
for i in range(nb_iter_max):
@@ -136,7 +136,7 @@ def _update_plot(i):
136136
loss_iter = []
137137

138138
# generator for random permutations
139-
gen = torch.Generator()
139+
gen = torch.Generator(device=device)
140140
gen.manual_seed(42)
141141

142142
alpha = 0.5

0 commit comments

Comments
 (0)