Skip to content

Commit 9a4293e

Browse files
committed
remove jax for log test
1 parent d2fd0a2 commit 9a4293e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/test_da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_class_jax_tf():
4141
with pytest.raises(TypeError):
4242
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
4343

44-
44+
@pytest.skip_backend("jax")
4545
@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
4646
def test_log_da(nx, class_to_test):
4747

0 commit comments

Comments
 (0)