Closed
Description
When loading the state_dict of an optimizer that has conditional state tensors (for example, SGD's momentum buffer, which is only created after the first step), the optimizer does not move those tensors onto the correct device. The next call to step() then crashes with a device mismatch.
The cause is that conditional tensors stay null until they are assigned a value during Step(). When the optimizer loads its state from the stream, the deserialized tensors end up on the CPU, so it has to explicitly copy each one to the device of the parameter it belongs to.
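The fix can be illustrated with a small, self-contained C# model. This is a sketch under stated assumptions, not TorchSharp's actual implementation: `FakeTensor` and `MomentumState` are hypothetical types that only track a device name, and serialization is reduced to a presence flag. The important step is in `Load`: a conditional tensor read back from the stream is moved to its parameter's device instead of being left on the CPU.

```csharp
#nullable enable
using System;
using System.IO;

// Hypothetical stand-in for a tensor that only records its device.
// This is NOT TorchSharp's Tensor type; it exists to illustrate the fix.
public sealed class FakeTensor
{
    public string Device { get; }
    public FakeTensor(string device) => Device = device;

    // Returns a copy on the target device (models a tensor "to(device)" call).
    public FakeTensor To(string device) => new FakeTensor(device);
}

// Hypothetical per-parameter optimizer state with a conditional buffer,
// analogous to SGD's momentum buffer, which only exists after the first step.
public sealed class MomentumState
{
    private readonly FakeTensor _parameter;
    public FakeTensor? MomentumBuffer { get; private set; }

    public MomentumState(FakeTensor parameter) => _parameter = parameter;

    // Simulates Step(): the buffer is created lazily on the parameter's device.
    public void Step() => MomentumBuffer ??= new FakeTensor(_parameter.Device);

    public void Save(BinaryWriter writer)
    {
        // Record whether the conditional tensor exists (tensor data omitted in this model).
        writer.Write(MomentumBuffer is not null);
    }

    public void Load(BinaryReader reader)
    {
        bool present = reader.ReadBoolean();
        if (!present) { MomentumBuffer = null; return; }

        // In this model, a deserialized tensor comes back on the CPU...
        var loaded = new FakeTensor("cpu");
        // ...so move it to the device of the parameter it belongs to.
        MomentumBuffer = loaded.To(_parameter.Device);
    }
}
```

Without the `To(_parameter.Device)` call, the reloaded buffer would stay on "cpu" while the parameter lives on "cuda", which mirrors the crash in the reproduction below.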
Sample code to reproduce:
```csharp
using System.IO;
using TorchSharp;

var lin1 = torch.nn.Linear(10, 10, device: torch.CUDA);
var optim1 = torch.optim.SGD(lin1.parameters(), 0.001, 0.9);

// Take one step so that the momentum buffer is created
torch.nn.functional.mse_loss(lin1.call(torch.rand(10).cuda()), torch.rand(10).cuda()).backward();
optim1.step();

// Save the optimizer state to a memory stream
using var ms = new MemoryStream();
optim1.save_state_dict(new BinaryWriter(ms, System.Text.Encoding.UTF8, true));
ms.Position = 0;

// Create a new optimizer and load the saved state into it
var optim2 = torch.optim.SGD(lin1.parameters(), 0.001, 0.9);
optim2.load_state_dict(new BinaryReader(ms));

// Taking a step now crashes: the loaded momentum buffer is on the CPU, not CUDA
torch.nn.functional.mse_loss(lin1.call(torch.rand(10).cuda()), torch.rand(10).cuda()).backward();
optim2.step();
```