Conditional load tensors in Optimizers aren't loaded onto the correct device #1176

@shaltielshmid

When loading the state_dict of an optimizer that has a conditional tensor (such as the MomentumBuffer in SGD, which is only created after the first step), the optimizer doesn't load the tensor onto the correct device. The next time a step is taken, the program crashes because of a device mismatch.

The cause is that conditional tensors are kept as null until they are assigned a value during Step(). When the optimizer loads the state from the stream, it therefore has to make sure to copy the loaded tensor to the right device.
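A minimal sketch of the kind of fix described above, assuming the loaded buffer and its matching parameter are both available at load time. `OptimizerStateHelpers` and `MoveToDeviceOf` are hypothetical names for illustration, not part of TorchSharp, and this is not necessarily how the library resolves the issue:

```csharp
using TorchSharp;
using static TorchSharp.torch;

public static class OptimizerStateHelpers
{
    // Hypothetical helper (not a TorchSharp API): returns the loaded buffer
    // placed on the same device as the parameter it belongs to.
    public static Tensor MoveToDeviceOf(Tensor loaded, Tensor parameter)
    {
        // Tensor.to(Device) moves the tensor if it is on a different device,
        // e.g. CPU -> CUDA when the parameter lives on the GPU.
        return loaded.to(parameter.device);
    }
}
```

Under these assumptions, the load path would call something like `MoveToDeviceOf(loadedMomentumBuffer, parameter)` before storing the buffer in the optimizer state, so that the next step operates on tensors that share a device.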

Sample code to reproduce:

var lin1 = torch.nn.Linear(10, 10, device: torch.CUDA);
var optim1 = torch.optim.SGD(lin1.parameters(), 0.001, 0.9);

// Create the momentum buffer
torch.nn.functional.mse_loss(lin1.call(torch.rand(10).cuda()), torch.rand(10).cuda()).backward();
optim1.step();

// Save to a memory stream
using var ms = new MemoryStream();
optim1.save_state_dict(new BinaryWriter(ms, System.Text.Encoding.UTF8, true));
ms.Position = 0;

// Create a new optimizer
var optim2 = torch.optim.SGD(lin1.parameters(), 0.001, 0.9);
optim2.load_state_dict(new BinaryReader(ms));

// Try to take a step; this crashes because the loaded momentum buffer is on the CPU, not on CUDA
torch.nn.functional.mse_loss(lin1.call(torch.rand(10).cuda()), torch.rand(10).cuda()).backward();
optim2.step();
