Skip to content

Conversation

@yueyinqiu
Copy link
Contributor

@yueyinqiu yueyinqiu commented Apr 17, 2024

An error occurred when trying to save the downloaded pretained model:

states = weight.get_state_dict()
exportsd.save_state_dict(states, file)

(where weight is something like torchvision.models.AlexNet_Weights.IMAGENET1K_V1)

(full codes here)

@yueyinqiu yueyinqiu marked this pull request as ready for review April 17, 2024 16:26
@yueyinqiu
Copy link
Contributor Author

yueyinqiu commented Apr 17, 2024

I'm not sure why the tensors are in computation graphs... but the error do occur and it could be solved by adding detach.

@GeorgeS2019
Copy link

@yueyinqiu
/Are we missing unit tests to check for this bug?

@NiklasGustafsson
Copy link
Contributor

@yueyinqiu /Are we missing unit tests to check for this bug?

Right. We have no unit tests for the Python export / import code at all. I've just assumed it works... :-)

@yueyinqiu
Copy link
Contributor Author

yueyinqiu commented Apr 18, 2024

Actually I think a simple unit test could not find this out. In most cases the problem does not occur and I don't know why the pretrained model are in computation graphs...

Also I'm not really sure whether we should modify this, since we do not use cpu() as well. Perhaps this should be done by users before saving. Maybe more checks should be made before merging.

@yueyinqiu yueyinqiu marked this pull request as draft April 18, 2024 00:00
@NiklasGustafsson
Copy link
Contributor

Calling detach seems like a safe, harmless thing to do.

@yueyinqiu
Copy link
Contributor Author

just to make a note here: torchvision.models.AlexNet_Weights.IMAGENET1K_V1.get_state_dict()['features.0.weight'].requires_grad is True

@yueyinqiu
Copy link
Contributor Author

yueyinqiu commented Apr 18, 2024

Calling detach seems like a safe, harmless thing to do.

However it seems that torch.save/torch.load will record whether a tensor requires_grad or not in PyTorch. I'm testing on that now.

@yueyinqiu
Copy link
Contributor Author

yueyinqiu commented Apr 18, 2024

Well I think it's fine. The .net format does not save requires_grad at all. And the current PyTorch's state_dict() implemention is using detach() before putting the tensor into the state dictionary, so definitely requires_grad won't be accessed when loading it. Perhaps it's a behavior exists in some old PyTorch versions.

@yueyinqiu yueyinqiu marked this pull request as ready for review April 18, 2024 14:15
@NiklasGustafsson
Copy link
Contributor

Ready to merge?

@yueyinqiu
Copy link
Contributor Author

Ready to merge?

I suppose yes

@NiklasGustafsson NiklasGustafsson merged commit 9c3a461 into dotnet:main Apr 19, 2024
@yueyinqiu yueyinqiu deleted the detach branch April 19, 2024 01:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants