Skip to content

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented May 13, 2023

Motivation for these changes

Add a jax implementation of pt.linalg.pinv

Implementation details

Not the most elaborate PR:

@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
    def pinv(x):
        return jnp.linalg.pinv(x)

    return pinv

Checklist

Major / Breaking Changes

None

New features

You can compile graphs with pinv to JAX

Bugfixes

None

Documentation

None

Maintenance

None

@ricardoV94
Copy link
Member

Thanks @jessegrabowski

@jessegrabowski jessegrabowski deleted the jax_pinv branch May 13, 2023 23:44
@ricardoV94
Copy link
Member

ricardoV94 commented May 14, 2023

Oops I missed one float32 test that was failing: https://github.com/pymc-devs/pytensor/actions/runs/4968627533/jobs/8891362025?pr=294

@jessegrabowski
Copy link
Member Author

jessegrabowski commented May 14, 2023

I didn't add astype(config.floatX) to the test array (again...). Not sure how to proceed -- I deleted the branch on my fork already so it's not auto-pushing the change into this PR.

@ricardoV94
Copy link
Member

No worries, I opened a fix PR in #296

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants