Conversation

@shaltielshmid (Contributor)

Fixes: #1172

@shaltielshmid (Contributor Author)

Also fixes #1176 by copying the conditional state to the right device.

public override void LoadStateDict(BinaryReader reader)
{
    LoadConditionalStateTensor(reader, ref momentum_buffer);
    momentum_buffer = momentum_buffer.to(_parameter.device);

Contributor:

No need to dispose the old momentum buffer if it's actually moved?

Contributor Author:

Good catch. Will fix.

Contributor Author:

Question: Do you want me to write something like:

if (momentum_buffer.device_type != _parameter.device_type || momentum_buffer.device_index != _parameter.device_index) {
    using var copy = momentum_buffer;
    momentum_buffer = copy.to(_parameter.device);
}

@shaltielshmid (Contributor Author), Dec 6, 2023:

(That's not handled in the regular Optimizer.to(Device device) function. See for example here)

Contributor:

I think it's sufficient to check the Handle before and after the call to `to`. I think.

Contributor:

Maybe not.

Contributor Author:

Haha
In that case, should I go with my proposed solution?

Also, does that mean we should update all the Optimizer.to(...) functions?

Maybe we should add a dispose_if_moved parameter to the to function to make it generic?

Contributor:

It's handled in the Module variants of to(). Follow that pattern.

Contributor:

Maybe it's worth writing an internal utility function that has the 'move or not move' predicate, so it can be reused.
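
(For illustration only: such a reusable helper could look roughly like the sketch below. The name MoveStateTensor, its placement, and the comparison against the owning parameter are assumptions, not code from this PR or existing TorchSharp API; it assumes `using static TorchSharp.torch;` is in scope.)

// Sketch only: move a state tensor to its owning parameter's device when needed,
// disposing the old handle once a new tensor has been produced.
internal static Tensor MoveStateTensor(Tensor state, Tensor parameter)
{
    // Nothing to move if there is no state, or it already lives on the parameter's device.
    if (state is null ||
        (state.device_type == parameter.device_type && state.device_index == parameter.device_index))
        return state;

    var moved = state.to(parameter.device);
    state.Dispose();   // the optimizer no longer references the old handle
    return moved;
}

A call site like LoadStateDict could then replace the manual check with a single call such as momentum_buffer = MoveStateTensor(momentum_buffer, _parameter);.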

Contributor Author:

I checked the Module variants of to(), and there is no dispose going on there, but they do check the device to see whether it changed. I'll follow that pattern.

What do we want to do about all the cases where the .to() function is called but no dispose is handled, like here or here, etc.?

@shaltielshmid force-pushed the optimizer-load-clone-tensors branch from d1b68ad to bbc569e on December 6, 2023 at 22:36
@shaltielshmid (Contributor Author)

@NiklasGustafsson What do you think of my solution?
I added a disposeAfter flag to the to(..) function, which disposes the input tensor after it has been moved to the new tensor.
It turns out we didn't need to check whether the tensor actually moved, because we get a new handle regardless.

I wrote a unit test attempting to move and cast tensors to the same and different types/devices, and made sure the behavior was as expected.

What do you think?

If this solution is good, I'll wait for this to be merged and then use it in the other PR.
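
(For illustration, assuming the new optional parameter is indeed named disposeAfter as described above, the LoadStateDict call site discussed earlier would reduce to something like this sketch; the exact overload signature is whatever the PR defines.)

LoadConditionalStateTensor(reader, ref momentum_buffer);
// Move the loaded buffer to the parameter's device; the pre-move handle is
// disposed inside to(), so no manual device check or using-block is needed.
momentum_buffer = momentum_buffer.to(_parameter.device, disposeAfter: true);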

@NiklasGustafsson (Contributor)

@shaltielshmid -- here's a tip: Once you have created a PR, try to consolidate a number of commits before pushing again. Every push starts a new build on the CI/CD pipeline, which wastes resources.

@shaltielshmid (Contributor Author)

Of course, I apologize.
I generally try to batch up a few commits and only push once the changes are complete.

step = st_state.step;
square_avg = st_state.square_avg;
acc_delta = st_state.acc_delta;
square_avg = st_state.square_avg.to(_parameter.device, copy: true);

Contributor:

Why do we have to set copy to true here?

@shaltielshmid (Contributor Author), Dec 7, 2023:

If the State Dictionary actually comes from another existing optimizer, then the tensors are shared by reference, and we don't want two different optimizers sharing a state tensor.

Contributor Author:

For example:

var lin1 = torch.nn.Linear(10, 10);
var optim1 = torch.optim.Adam(lin1.parameters(), 0.05f);
var lin2 = torch.nn.Linear(10, 10);
var optim2 = torch.optim.Adam(lin2.parameters());
optim2.load_state_dict(optim1.state_dict());
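
(Without copy: true, the state tensors loaded into optim2 here would be the very same native handles still owned by optim1, so disposing one optimizer's state, or continuing to step one of the optimizers, would silently affect the other.)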

Contributor:

In that case, shouldn't the 'copy' argument be passed down from where you know whether it's necessary or not?

Contributor Author:

I can add it in, if you think that's the right way to do it.

The way I see it, the StateDictionary is an object created only by calling Optimizer.state_dict(), and the source Optimizer manages the tensors in that dictionary. If they need to be disposed, the source Optimizer handles that. Therefore, the new Optimizer should receive a fresh copy of the tensors to manage itself.

@shaltielshmid (Contributor Author), Dec 7, 2023:

Meaning, the StateDictionary is just a wrapper interface for accessing the values of an Optimizer.
I think we don't ever want two Optimizers to share the same tensor handle.

That being said, if you disagree and would like me to add a copy parameter to the load_state_dict function, no problem.

Contributor:

Got it. I think that's probably right. If we later conceive of a case where you do want to share state, we can add an argument and set the default to 'true'.

@NiklasGustafsson (Contributor)

You have run the unit tests on CUDA, right?

@shaltielshmid (Contributor Author)

Yup

@NiklasGustafsson merged commit c7aab8c into dotnet:main on Dec 7, 2023
@NiklasGustafsson (Contributor)

Merged



Development

Successfully merging this pull request may close these issues: optim.LoadStateDict from existing StateDict doesn't clone tensors.