@@ -143,3 +143,34 @@ def test_pinv():
143143 fgraph = FunctionGraph ([x ], [x_inv ])
144144 x_np = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]], dtype = config .floatX )
145145 compare_jax_and_py (fgraph , [x_np ])
146+
147+
148+ def test_pinv_hermitian ():
149+ A = matrix ("A" , dtype = "complex128" )
150+ A_h_test = np .c_ [[3 , 3 + 2j ], [3 - 2j , 2 ]]
151+ A_not_h_test = A_h_test + 0 + 1j
152+
153+ A_inv = at_nlinalg .pinv (A , hermitian = False )
154+ jax_fn = function ([A ], A_inv , mode = "JAX" )
155+
156+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = False ))
157+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = True ))
158+ assert np .allclose (
159+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = False )
160+ )
161+ assert not np .allclose (
162+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = True )
163+ )
164+
165+ A_inv = at_nlinalg .pinv (A , hermitian = True )
166+ jax_fn = function ([A ], A_inv , mode = "JAX" )
167+
168+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = False ))
169+ assert np .allclose (jax_fn (A_h_test ), np .linalg .pinv (A_h_test , hermitian = True ))
170+ assert not np .allclose (
171+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = False )
172+ )
173+ # Numpy fails differently than JAX when hermitian assumption is violated
174+ assert not np .allclose (
175+ jax_fn (A_not_h_test ), np .linalg .pinv (A_not_h_test , hermitian = True )
176+ )
0 commit comments