Commit c7aab8c

Merge pull request #1173 from shaltielshmid/optimizer-load-clone-tensors
Fix optimizer load state dict copy tensor by reference
2 parents 71ee6d1 + bbc569e commit c7aab8c
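
The fix replaces plain reference assignment of optimizer state tensors with `.to(_parameter.device, copy: true)`, so each loaded tensor becomes an independent copy on the device of the parameter it belongs to. A minimal sketch of the difference, using plain tensors rather than code from the commit (the `torch.ones` / `torch.CPU` setup is illustrative only):

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Illustration only (not code from the commit): why state tensors must be copied
// rather than assigned by reference when loading optimizer state.
var sourceState = torch.ones(3);                      // stands in for e.g. st_state.exp_avg

var byReference = sourceState;                        // old behaviour: same native tensor
var byCopy = sourceState.to(torch.CPU, copy: true);   // new behaviour: independent copy on the
                                                      // owning parameter's device

sourceState.add_(1);                                  // the source keeps changing (or gets disposed)

Console.WriteLine(byReference[0].ToSingle());         // 2 -- follows the source
Console.WriteLine(byCopy[0].ToSingle());              // 1 -- unaffected
```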

16 files changed: +145 / -50 lines

RELEASENOTES.md

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ __Bug Fixes__:
 
 #1154 : `mu_product` was not initialized in `NAdam` optimizer
 #1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
+#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
+#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
 
 ## NuGet Version 0.101.2

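The two new entries cover both load paths: copying state from an existing `StateDictionary` (#1172) and placing conditionally present tensors on the right device when deserializing (#1176). A hedged sketch of the first scenario; the model setup is illustrative, and the `state_dict()` accessor and optimizer-level `LoadStateDict` call are written as in the note and by analogy with PyTorch, so treat the exact member names and casing as assumptions:

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Two identical models; the destination one lives on CUDA when available.
var srcModel = nn.Linear(8, 4);
var dstModel = nn.Linear(8, 4);
if (cuda.is_available()) dstModel.to(torch.CUDA);

var srcOpt = torch.optim.Adam(srcModel.parameters(), lr: 1e-3);
var dstOpt = torch.optim.Adam(dstModel.parameters(), lr: 1e-3);

// ... run a few training steps with srcOpt so exp_avg / exp_avg_sq exist ...

// #1172: before this commit the destination optimizer could end up aliasing the
// source's (CPU) state tensors; now each tensor is copied to its parameter's device.
// Method names follow the release note and may appear as state_dict()/load_state_dict()
// depending on the TorchSharp version.
dstOpt.LoadStateDict(srcOpt.state_dict());
```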
src/TorchSharp/Optimizers/ASGD.cs

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ public override void LoadStateDict(OptimizerState source)
 eta = st_state.eta;
 mu = st_state.mu;
 ax.Dispose();
-ax = st_state.ax;
+ax = st_state.ax.to(_parameter.device, copy: true);
 }
 
 /// <summary>
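
The same one-line pattern is applied in each of the optimizer `State` classes that follow: dispose the tensor currently held, then copy the incoming tensor onto the owning parameter's device. For context, the ASGD method with the lines above the hunk reconstructed (marked as assumptions) and comments added:

```csharp
public override void LoadStateDict(OptimizerState source)
{
    var st_state = source as State;   // assumed: not shown in the hunk, mirrors the sibling optimizers
    step = st_state.step;             // assumed: not shown in the hunk
    eta = st_state.eta;
    mu = st_state.mu;
    ax.Dispose();                                          // release the tensor this state currently holds
    ax = st_state.ax.to(_parameter.device, copy: true);    // clone the source tensor onto this parameter's
                                                           // device instead of keeping a shared reference
}
```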

src/TorchSharp/Optimizers/Adadelta.cs

Lines changed: 2 additions & 2 deletions
@@ -237,8 +237,8 @@ public override void LoadStateDict(OptimizerState source)
 acc_delta.Dispose();
 
 step = st_state.step;
-square_avg = st_state.square_avg;
-acc_delta = st_state.acc_delta;
+square_avg = st_state.square_avg.to(_parameter.device, copy: true);
+acc_delta = st_state.acc_delta.to(_parameter.device, copy: true);
 }
 
 public override bool ApproximatelyEquals(OptimizerState other)

src/TorchSharp/Optimizers/Adagrad.cs

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ public override void LoadStateDict(OptimizerState source)
 var st_state = source as State;
 sum.Dispose();
 step = st_state.step;
-sum = st_state.sum;
+sum = st_state.sum.to(_parameter.device, copy: true);
 }
 
 /// <summary>

src/TorchSharp/Optimizers/Adam.cs

Lines changed: 6 additions & 14 deletions
@@ -250,13 +250,7 @@ public override void LoadStateDict(BinaryReader reader)
 step = reader.ReadInt64();
 exp_avg.Load(reader);
 exp_avg_sq.Load(reader);
-var hasMax = reader.ReadBoolean();
-if (hasMax) {
-    TensorExtensionMethods.Load(ref max_exp_avg_sq, reader);
-} else {
-    max_exp_avg_sq?.Dispose();
-    max_exp_avg_sq = null;
-}
+LoadConditionalStateTensor(reader, ref max_exp_avg_sq, _parameter.device);
 }
 
 /// <summary>
@@ -285,14 +279,12 @@ public override void LoadStateDict(OptimizerState source)
 var st_state = source as State;
 exp_avg.Dispose();
 exp_avg_sq.Dispose();
-if (max_exp_avg_sq is not null) {
-    max_exp_avg_sq.Dispose();
-}
-
+max_exp_avg_sq?.Dispose();
+
 step = st_state.step;
-exp_avg = st_state.exp_avg;
-exp_avg_sq = st_state.exp_avg_sq;
-max_exp_avg_sq = st_state.max_exp_avg_sq;
+exp_avg = st_state.exp_avg.to(_parameter.device, copy: true);
+exp_avg_sq = st_state.exp_avg_sq.to(_parameter.device, copy: true);
+max_exp_avg_sq = st_state.max_exp_avg_sq?.to(_parameter.device, copy: true);
 }
 
 public override bool ApproximatelyEquals(OptimizerState other)
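
`max_exp_avg_sq` only exists when AMSGrad is in use, so the copy goes through the null-conditional operator and a null source stays null. A short hedged illustration (assuming, as in PyTorch, that TorchSharp's `Adam` factory exposes an `amsgrad` flag):

```csharp
using TorchSharp;
using static TorchSharp.torch;

var model = nn.Linear(4, 2);

// Without AMSGrad the per-parameter max_exp_avg_sq stays null, and
// `st_state.max_exp_avg_sq?.to(_parameter.device, copy: true)` keeps it null on load.
var plain = torch.optim.Adam(model.parameters(), lr: 1e-3);

// With AMSGrad the tensor exists and is cloned onto the parameter's device.
var withAmsGrad = torch.optim.Adam(model.parameters(), lr: 1e-3, amsgrad: true);
```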

src/TorchSharp/Optimizers/AdamW.cs

Lines changed: 5 additions & 7 deletions
@@ -286,14 +286,12 @@ public override void LoadStateDict(OptimizerState source)
 var st_state = source as State;
 exp_avg.Dispose();
 exp_avg_sq.Dispose();
-if (max_exp_avg_sq is not null) {
-    max_exp_avg_sq.Dispose();
-}
-
+max_exp_avg_sq?.Dispose();
+
 step = st_state.step;
-exp_avg = st_state.exp_avg;
-exp_avg_sq = st_state.exp_avg_sq;
-max_exp_avg_sq = st_state.max_exp_avg_sq;
+exp_avg = st_state.exp_avg.to(_parameter.device, copy: true);
+exp_avg_sq = st_state.exp_avg_sq.to(_parameter.device, copy: true);
+max_exp_avg_sq = st_state.max_exp_avg_sq?.to(_parameter.device, copy: true);
 }
 
 public override bool ApproximatelyEquals(OptimizerState other)

src/TorchSharp/Optimizers/Adamax.cs

Lines changed: 2 additions & 2 deletions
@@ -255,8 +255,8 @@ public override void LoadStateDict(OptimizerState source)
 exp_inf.Dispose();
 
 step = st_state.step;
-exp_avg = st_state.exp_avg;
-exp_inf = st_state.exp_inf;
+exp_avg = st_state.exp_avg.to(_parameter.device, copy: true);
+exp_inf = st_state.exp_inf.to(_parameter.device, copy: true);
 }
 
 public override bool ApproximatelyEquals(OptimizerState other)

src/TorchSharp/Optimizers/NAdam.cs

Lines changed: 2 additions & 2 deletions
@@ -269,8 +269,8 @@ public override void LoadStateDict(OptimizerState source)
 
 step = st_state.step;
 mu_product = st_state.mu_product;
-exp_avg = st_state.exp_avg;
-exp_avg_sq = st_state.exp_avg_sq;
+exp_avg = st_state.exp_avg.to(_parameter.device, copy: true);
+exp_avg_sq = st_state.exp_avg_sq.to(_parameter.device, copy: true);
 }
 
 /// <summary>

src/TorchSharp/Optimizers/Optimizer.cs

Lines changed: 2 additions & 1 deletion
@@ -555,12 +555,13 @@ public virtual bool ApproximatelyEquals(OptimizerState other)
 /// <param name="device">The device to move all state to.</param>
 public virtual void to(Device device) { }
 
-protected static void LoadConditionalStateTensor(BinaryReader reader, ref Tensor result)
+protected static void LoadConditionalStateTensor(BinaryReader reader, ref Tensor result, Device device)
 {
     var hasTensor = reader.ReadBoolean();
 
     if (hasTensor) {
         TensorExtensionMethods.Load(ref result, reader);
+        result = result.to(device, disposeAfter: true);
     } else {
         if (result is not null)
             result.Dispose();

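The hunk is cut off before the end of the helper; a comment-annotated version with the tail reconstructed (the final `result = null;` is an assumption, mirroring the branch this helper replaces in `Adam.LoadStateDict(BinaryReader)` above):

```csharp
protected static void LoadConditionalStateTensor(BinaryReader reader, ref Tensor result, Device device)
{
    var hasTensor = reader.ReadBoolean();        // the writer records whether the tensor was present

    if (hasTensor) {
        TensorExtensionMethods.Load(ref result, reader);   // deserialize into 'result'
        result = result.to(device, disposeAfter: true);    // move it to the owning parameter's device,
                                                           // disposing the intermediate tensor
    } else {
        if (result is not null)
            result.Dispose();
        result = null;                                     // assumed tail: clear the slot when absent
    }
}
```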
src/TorchSharp/Optimizers/RAdam.cs

Lines changed: 2 additions & 2 deletions
@@ -262,8 +262,8 @@ public override void LoadStateDict(OptimizerState source)
 exp_avg_sq.Dispose();
 
 step = st_state.step;
-exp_avg = st_state.exp_avg;
-exp_avg_sq = st_state.exp_avg_sq;
+exp_avg = st_state.exp_avg.to(_parameter.device, copy: true);
+exp_avg_sq = st_state.exp_avg_sq.to(_parameter.device, copy: true);
 }
 
 /// <summary>
