Skip to content

Conversation

@RaulPPelaez
Copy link
Collaborator

I added a new parameter, dtype, to control whether TorchMD_Net uses float or double.
The value can be passed along the rest to create_model, but it defaults to float if not present. Can be either a string or a torch.dtype.

I also added some comments here and there.

To showcase I added a test that checks the correctness of gradients for all models using torch.autograd.gradcheck.

@RaulPPelaez RaulPPelaez requested review from raimis and stefdoerr June 6, 2023 12:28
@raimis raimis requested a review from dav0dea June 6, 2023 12:36
@RaulPPelaez RaulPPelaez mentioned this pull request Jun 6, 2023
@peastman
Copy link
Collaborator

peastman commented Jun 6, 2023

This tensor should also get a dtype:

C_6_R_r = pt.tensor(

@RaulPPelaez RaulPPelaez merged commit 28ce74b into torchmd:main Jun 19, 2023
@RaulPPelaez RaulPPelaez deleted the float64 branch June 20, 2023 07:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants