4242def test_input (op ) -> None :
4343 x = vector ("x" )
4444 out = op (x > 0 )
45+ out .dprint (print_type = True )
46+ assert 0
4547 x_test = mx .array ([1.0 , 2.0 , 3.0 ])
4648
4749 compare_mlx_and_py ([x ], out , [x_test ])
4850
4951
52+ @pytest .mark .skip (reason = "It's crashing the kernel" )
5053def test_mlx_CAReduce ():
5154 a_pt = vector ("a" )
5255 a_pt .tag .test_value = np .r_ [1 , 2 , 3 ].astype (config .floatX )
@@ -78,6 +81,7 @@ def test_mlx_CAReduce():
7881 compare_mlx_and_py ([a_pt ], [x ], [np .c_ [[1 , 2 , 3 ], [1 , 2 , 3 ]].astype (config .floatX )])
7982
8083
84+ @pytest .mark .skip (reason = "It's crashing the kernel" )
8185@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
8286def test_softmax (axis ):
8387 x = matrix ("x" )
@@ -86,6 +90,7 @@ def test_softmax(axis):
8690 compare_mlx_and_py ([x ], [out ], [x_test_value ])
8791
8892
93+ @pytest .mark .skip (reason = "It's crashing the kernel" )
8994@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
9095def test_softmax_grad (axis ):
9196 dy = matrix ("dy" )
@@ -97,6 +102,7 @@ def test_softmax_grad(axis):
97102 compare_mlx_and_py ([dy , sm ], [out ], [dy_test_value , sm_test_value ])
98103
99104
105+ @pytest .mark .skip (reason = "It's crashing the kernel" )
100106@pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
101107def test_logsoftmax (axis ):
102108 x = matrix ("x" )
0 commit comments