Skip to content

Commit c48ca33

Browse files
Merge pull request #1175 from shaltielshmid/tensor-load-cpu
Make sure to handle tensor copy when moving to cpu before loading
2 parents c7aab8c + b5cc811 commit c48ca33

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ __Bug Fixes__:
2424
#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
2525
#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
2626
#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
27+
#1174 : Loading CUDA tensor from stream threw an error
2728

2829
## NuGet Version 0.101.2
2930

src/TorchSharp/Tensor/TensorExtensionMethods.cs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,15 @@ public static void Load(this Tensor tensor, System.IO.BinaryReader reader, bool
463463

464464
if (!skip) {
465465
var device = tensor.device;
466-
if (device.type != DeviceType.CPU) tensor.to(CPU);
467-
tensor.bytes = bytes;
468-
tensor.to(device);
466+
if (device.type != DeviceType.CPU) {
467+
using var copy = tensor.to(CPU);
468+
copy.bytes = bytes;
469+
using var moved = copy.to(device);
470+
tensor.set_(moved);
471+
}
472+
else {
473+
tensor.bytes = bytes;
474+
}
469475
}
470476
}
471477

@@ -532,9 +538,10 @@ public static void Load(ref Tensor? tensor, System.IO.BinaryReader reader, bool
532538

533539
if (!skip) {
534540
var device = tensor.device;
535-
if (device.type != DeviceType.CPU) tensor.to(CPU);
541+
if (device.type != DeviceType.CPU)
542+
tensor = tensor.to(CPU, disposeAfter: true);
536543
tensor.bytes = bytes;
537-
tensor.to(device);
544+
tensor = tensor.to(device, disposeAfter: true);
538545
}
539546
}
540547

test/TorchSharpTest/TestTorchTensorBugs.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,5 +1456,26 @@ public void Validate1172_Device()
14561456
Assert.Equal(DeviceType.CUDA, (optim2.state_dict().State[0] as Adam.State)!.exp_avg.device.type);
14571457
}
14581458
}
1459+
1460+
[Fact]
1461+
public void Validate1174()
1462+
{
1463+
if (torch.cuda.is_available()) {
1464+
var tensor1 = torch.ones(10, ScalarType.Int32, device: torch.CUDA);
1465+
1466+
// Save to memory stream
1467+
using var ms = new MemoryStream();
1468+
tensor1.Save(new BinaryWriter(ms, System.Text.Encoding.UTF8, true));
1469+
ms.Position = 0;
1470+
1471+
var tensor2 = torch.zeros(10, ScalarType.Int32, device: torch.CUDA);
1472+
// This used to throw an error, trying to load bytes onto a cuda tensor
1473+
tensor2.Load(new BinaryReader(ms));
1474+
1475+
// Make sure tensor 2 is still on CUDA with the right values
1476+
Assert.Equal(DeviceType.CUDA, tensor2.device.type);
1477+
Assert.True(Enumerable.SequenceEqual(tensor2.data<int>(), Enumerable.Repeat(1, 10)));
1478+
}
1479+
}
14591480
}
14601481
}

0 commit comments

Comments
 (0)