6 changes: 6 additions & 0 deletions RELEASENOTES.md
@@ -2,6 +2,12 @@

Releases, starting with 9/2/2021, are listed with the most recent release at the top.

## NuGet Version 0.101.7

__API Changes__:

#1219: Added support for loading and saving tensors that are >2GB.

## NuGet Version 0.101.6

__API Changes__:
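For reference, a minimal round trip through the new path might look like the sketch below. It sticks to APIs visible in this diff (`Save(BinaryWriter)` and `Load(BinaryReader)`); the path and shape are illustrative, and the trailing `bool` parameter of `Load` is assumed to be optional.

```csharp
using System.IO;
using TorchSharp;
using static TorchSharp.torch;

// Round-trip a tensor through the streaming save/load path. The shape is
// tiny here, but the same code now works for tensors past the 2GB mark.
var original = torch.rand(4, 4);

using (var fs = File.OpenWrite("tensor.bin"))      // illustrative path
using (var writer = new BinaryWriter(fs))
    original.Save(writer);

var restored = torch.zeros_like(original);
using (var fs = File.OpenRead("tensor.bin"))
using (var reader = new BinaryReader(fs))
    restored.Load(reader);
```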
2 changes: 1 addition & 1 deletion src/TorchSharp/NN/Module.cs
@@ -509,7 +509,7 @@ public virtual (IList<string> missing_keys, IList<string> unexpected_keyes) load
                foreach (var key in source.Keys) {
                    if (skip.Contains(key)) continue;
                    if (destination.ContainsKey(key)) {
-                       destination[key].bytes = source[key].bytes;
+                       destination[key].copy_(source[key]);
                    }
                }

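The move from a raw `bytes` assignment to `copy_` matters twice over: `copy_` runs in native code without materializing the parameter as a managed `byte[]` (the source of the 2GB cap), and it handles copies across devices and dtypes. A minimal sketch of the call, with illustrative shapes:

```csharp
using TorchSharp;
using static TorchSharp.torch;

var destination = torch.zeros(3, 4);
var source = torch.rand(3, 4);

// In-place elementwise copy; no intermediate managed buffer is allocated.
using (torch.no_grad())
    destination.copy_(source);
```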
7 changes: 2 additions & 5 deletions src/TorchSharp/Tensor/Factories/Tensor.Factories.cs
@@ -521,11 +521,8 @@ public static Tensor Load(System.IO.BinaryReader reader)
                throw new NotImplementedException("Loading tensors larger than 2GB");

            var tensor = torch.empty(loadedShape, dtype: type);
-
-           var bytes = reader.ReadBytes((int)(totalSize * type.ElementSize()));
-
-           tensor.bytes = bytes;
-
+           tensor.ReadBytesFromStream(reader.BaseStream);

            return tensor;
        }

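The factory now allocates the tensor first and streams the payload straight into its native buffer. The same pattern works against any raw binary source; a sketch, where the file name, shape, and dtype are illustrative and the file must hold exactly `NumberOfElements * ElementSize` bytes:

```csharp
using System.IO;
using TorchSharp;
using static TorchSharp.torch;

// Stream raw float32 data from disk into a preallocated tensor.
var t = torch.empty(new long[] { 1_000_000 }, dtype: torch.float32);
using var fs = File.OpenRead("raw_floats.bin");
t.ReadBytesFromStream(fs);
```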
73 changes: 73 additions & 0 deletions src/TorchSharp/Tensor/Tensor.cs
@@ -4,6 +4,7 @@
using System.ComponentModel;
using System.Diagnostics.Contracts;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
@@ -404,6 +405,78 @@ internal void ValidateType(Type dotnetType)
}
}


        /// <summary>
        /// Writes the bytes of the tensor to a stream. Useful when tensors are larger than 2GB.
        /// </summary>
        /// <param name="stream">Stream to write the bytes to</param>
        /// <param name="bufferSize">The buffer size to use when writing to the stream</param>
        public void WriteBytesToStream(Stream stream, int bufferSize = 1024)
        {
            // Validate, passing 0 as the total size, since the size itself doesn't need checking here.
            _validate(0);

            long totalSize = NumberOfElements * ElementSize;

            unsafe {
                var ptr = NativeMethods.THSTensor_data(handle);
                if (ptr == IntPtr.Zero) { CheckForErrors(); }

                // NOTE: there is no safety in this loop.
                // Copy the tensor's native memory out bufferSize bytes at a time,
                // writing each chunk to the stream.
                byte[] buffer = new byte[bufferSize];
                while (totalSize > 0) {
                    // Size of the current chunk.
                    int curBufferSize = (int)Math.Min(totalSize, bufferSize);
                    var span = new Span<byte>((void*)ptr, curBufferSize);
                    span.CopyTo(buffer);

                    // Write it out.
                    stream.Write(buffer, 0, curBufferSize);

                    // Advance the source pointer and decrease the number of bytes left to write.
                    ptr += curBufferSize;
                    totalSize -= curBufferSize;
                }
            }
        }
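A usage sketch for the writer; the path is illustrative, and the larger buffer simply reduces the number of `Write` calls:

```csharp
var t = torch.rand(1000, 1000);
using var fs = File.OpenWrite("t.raw");         // illustrative path
t.WriteBytesToStream(fs, bufferSize: 1 << 20);  // 1MB chunks
```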

        /// <summary>
        /// Reads the bytes of the tensor from a stream.
        /// </summary>
        /// <param name="stream">Stream to read the bytes from</param>
        /// <param name="bufferSize">The buffer size to use when reading from the stream</param>
        public void ReadBytesFromStream(Stream stream, int bufferSize = 1024)
        {
            long totalSize = NumberOfElements * ElementSize;

            // Validate that this tensor matches the conditions for reading the bytes - pass 0 as total size
            // since we don't need to check that condition.
            _validate(0);

            unsafe {
[Inline review thread]
Contributor Author: Should we check for CPU here? We check in the bytes getter, but not the setter.
Contributor: If it's going to blow up if it's not CPU, then it's better to check and throw an exception with good, specific information about the problem rather than some generic I/O exception later.
Contributor Author: Sure, will add.
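A sketch of the guard the thread converges on; the exception type and message are assumptions, not the merged wording:

```csharp
// Hypothetical CPU-only guard at the top of Read/WriteBytesToStream.
if (device_type != DeviceType.CPU)
    throw new InvalidOperationException(
        "Streaming tensor data is only supported for CPU tensors. " +
        "Move the tensor to the CPU first, then copy_ the result back to its device.");
```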

                var ptr = NativeMethods.THSTensor_data(handle);
                if (ptr == IntPtr.Zero) { CheckForErrors(); }

                // NOTE: there is no safety in this loop.
                // Read bufferSize bytes at a time from the stream, copying each chunk
                // into the tensor's native memory.
                byte[] buffer = new byte[bufferSize];
                while (totalSize > 0) {
                    // Size of the current chunk.
                    int curBufferSize = (int)Math.Min(totalSize, bufferSize);

                    // Stream.Read may return fewer bytes than requested; loop until the
                    // full chunk has been read.
                    int bytesRead = 0;
                    while (bytesRead < curBufferSize) {
                        int n = stream.Read(buffer, bytesRead, curBufferSize - bytesRead);
                        if (n == 0) throw new EndOfStreamException("Unexpected end of stream while reading tensor data.");
                        bytesRead += n;
                    }

                    // Copy the contents over to the tensor's memory.
                    var span = new Span<byte>((void*)ptr, curBufferSize);
                    buffer.AsSpan(0, curBufferSize).CopyTo(span);

                    // Advance the destination pointer and decrease the number of bytes left to read.
                    ptr += curBufferSize;
                    totalSize -= curBufferSize;
                }
            }
        }
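And the matching reader sketch; the source must supply exactly `NumberOfElements * ElementSize` bytes for the destination tensor (path and shape illustrative):

```csharp
var t = torch.zeros(1000, 1000);
using var fs = File.OpenRead("t.raw");          // illustrative path
t.ReadBytesFromStream(fs, bufferSize: 1 << 20);
```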

        /// <summary>
        /// Get or set the contents of a tensor as raw bytes.
        /// </summary>
67 changes: 25 additions & 42 deletions src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -387,13 +387,9 @@ public static void Save(this Tensor tensor, System.IO.BinaryWriter writer)
            // Then, the shape.
            writer.Encode(tensor.shape.Length); // 4 bytes
            foreach (var s in tensor.shape) writer.Encode(s); // n * 8 bytes
-           // Then, the data
-#if NETSTANDARD2_0_OR_GREATER
-           // TODO: NETSTANDARD2_0_OR_GREATER Try to optimize to avoid the allocation
-           writer.Write(tensor.bytes.ToArray()); // ElementSize * NumberOfElements
-#else
-           writer.Write(tensor.bytes); // ElementSize * NumberOfElements
-#endif // NETSTANDARD2_0_OR_GREATER
+
+           // Then, the data
+           tensor.WriteBytesToStream(writer.BaseStream);

            if (copied) tensor.Dispose();
        }
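Because `Save` stages non-CPU tensors through a CPU copy (hence the `copied` flag above), CUDA tensors can be saved directly; a hedged sketch:

```csharp
// Save a CUDA-resident tensor; the CPU staging copy happens inside Save.
if (torch.cuda.is_available()) {
    var t = torch.rand(10, 10, device: torch.CUDA);
    using var writer = new BinaryWriter(File.OpenWrite("cuda_t.bin")); // illustrative path
    t.Save(writer);
}
```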
@@ -441,37 +437,28 @@ public static void Load(this Tensor tensor, System.IO.BinaryReader reader, bool
            // Then, the shape
            var shLen = reader.Decode();
            long[] loadedShape = new long[shLen];
-
-           long totalSize = 1;
            for (int i = 0; i < shLen; ++i) {
                loadedShape[i] = reader.Decode();
-               totalSize *= loadedShape[i];
            }

            if (!skip && !loadedShape.SequenceEqual(tensor.shape))
                // We only care about this if the bytes will be written to the tensor.
                throw new ArgumentException("Mismatched tensor shape while loading. Make sure that the model you are loading into is exactly the same as the origin.");

-           //
-           // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB
-           //
-           if (totalSize > int.MaxValue)
-               throw new NotImplementedException("Loading tensors larger than 2GB");
-
-           // This needs to be done even if the tensor is skipped, since we have to advance the input stream.
-           var bytes = reader.ReadBytes((int)(totalSize * tensor.ElementSize));

            if (!skip) {
                using var d = torch.no_grad(); // so that we can perform operations on leaf tensors

                var device = tensor.device;
                if (device.type != DeviceType.CPU) {
-                   using var copy = tensor.to(CPU);
-                   copy.bytes = bytes;
-                   using var moved = copy.to(device);
-                   tensor.set_(moved);
-               }
-               else {
-                   tensor.bytes = bytes;
-               }
+                   using var temp = torch.zeros_like(tensor, device: torch.CPU);
+                   temp.ReadBytesFromStream(reader.BaseStream);
+
+                   tensor.copy_(temp);
+               } else {
+                   tensor.ReadBytesFromStream(reader.BaseStream);
+               }
+           } else {
+               reader.BaseStream.Seek(tensor.NumberOfElements * tensor.ElementSize, System.IO.SeekOrigin.Current);
            }
        }
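Two details in the new body are worth calling out: non-CPU tensors are filled through a CPU staging tensor plus `copy_`, and the skip path now advances the stream with `Seek` instead of reading and discarding the payload. Loading into an existing CUDA tensor is therefore just (sketch; the path is illustrative):

```csharp
var t = torch.zeros(128, 256, device: torch.CUDA);
using var reader = new BinaryReader(File.OpenRead("weights.bin")); // illustrative path
t.Load(reader);
```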

@@ -510,18 +497,10 @@ public static void Load(ref Tensor? tensor, System.IO.BinaryReader reader, bool
            var shLen = reader.Decode();
            long[] loadedShape = new long[shLen];

-           long totalSize = 1;
            for (int i = 0; i < shLen; ++i) {
                loadedShape[i] = reader.Decode();
-               totalSize *= loadedShape[i];
            }

-           //
-           // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB
-           //
-           if (totalSize > int.MaxValue)
-               throw new NotImplementedException("Loading tensors larger than 2GB");
-
            if (tensor is null) {
                // If the tensor doesn't exist, initialize by zeros unless
                // it's going to be loaded from the stream.
@@ -533,15 +512,19 @@ public static void Load(ref Tensor? tensor, System.IO.BinaryReader reader, bool
                    throw new ArgumentException("Mismatched tensor shape while loading. Make sure that the model you are loading into is exactly the same as the origin.");
            }

-           // This needs to be done even if the tensor is skipped, since we have to advance the input stream.
-           var bytes = reader.ReadBytes((int)(totalSize * tensor.ElementSize));
-
            if (!skip) {
                using var d = torch.no_grad(); // so that we can perform operations on leaf tensors

                var device = tensor.device;
-               if (device.type != DeviceType.CPU)
-                   tensor = tensor.to(CPU, disposeAfter: true);
-               tensor.bytes = bytes;
-               tensor = tensor.to(device, disposeAfter: true);
+               if (device.type != DeviceType.CPU) {
+                   using var temp = torch.zeros_like(tensor, device: torch.CPU);
+                   temp.ReadBytesFromStream(reader.BaseStream);
+                   tensor.copy_(temp);
+               } else {
+                   tensor.ReadBytesFromStream(reader.BaseStream);
+               }
+           } else {
+               reader.BaseStream.Seek(tensor.NumberOfElements * tensor.ElementSize, System.IO.SeekOrigin.Current);
            }
        }

3 changes: 2 additions & 1 deletion src/TorchVision/File.cs
@@ -48,7 +48,8 @@ public static async Task<Tensor> read_file_async(string filename)
        /// <param name="data">One dimensional <c>uint8</c> <cref>Tensor</cref>.</param>
        public static void write_file(string filename, Tensor data)
        {
-           File.WriteAllBytes(filename, data.bytes.ToArray());
+           using var stream = File.OpenWrite(filename);
+           data.WriteBytesToStream(stream);
        }
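A usage sketch for `write_file`, assuming it is exposed as `torchvision.io.write_file`; the file name is illustrative. One caveat that may be worth a review note: unlike the old `File.WriteAllBytes`, `File.OpenWrite` does not truncate an existing file, so writing fewer bytes than the file already holds leaves stale trailing data.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Write a one-dimensional uint8 tensor to disk via the streaming path.
var data = torch.ones(new long[] { 1024 }, dtype: torch.uint8);
torchvision.io.write_file("sample.bin", data);
```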

106 changes: 106 additions & 0 deletions test/TorchSharpTest/TestTorchTensor.cs
@@ -8269,5 +8269,111 @@ Tensor nbins_ratio(int seed, int size)
}
}
}


[Inline review comment]
Contributor Author: The tests are very heavy on the compute, so I marked them with Skip (but I tested them myself).

        [Fact(Skip = "Very heavy on the compute")]

        public void TestSaveAndLoadLarger2GBTensor()
        {
            var tensor = torch.rand((long)int.MaxValue + 128, device: torch.CPU);

            var tempFile = Path.GetTempFileName();
            try {
                // Save to a temp file
                using (var fs = File.OpenWrite(tempFile))
                    tensor.Save(fs);

                // Create a new copy of zeros
                var copyTensor = torch.zeros_like(tensor, device: torch.CPU);

                // Read it in
                using (var fs = File.OpenRead(tempFile))
                    copyTensor.Load(new BinaryReader(fs));

                Assert.Equal(tensor.npstr(), copyTensor.npstr());
            } finally {
                File.Delete(tempFile);
            }
        }

        [Fact(Skip = "Very heavy on the compute")]
        public void TestSaveAndLoadLarger2GBTensorCUDA()
        {
            if (torch.cuda.is_available()) {
                var tensor = torch.rand((long)int.MaxValue + 128, device: torch.CUDA);

                var tempFile = Path.GetTempFileName();
                try {
                    // Save to a temp file
                    using (var fs = File.OpenWrite(tempFile))
                        tensor.Save(fs);

                    // Create a new copy of zeros
                    var copyTensor = torch.zeros_like(tensor, device: torch.CUDA);

                    // Read it in
                    using (var fs = File.OpenRead(tempFile))
                        copyTensor.Load(new BinaryReader(fs));

                    Assert.Equal(tensor.npstr(), copyTensor.npstr());
                } finally {
                    File.Delete(tempFile);
                }
            }
        }


        [Fact(Skip = "Very heavy on the compute")]
        public void TestSaveAndLoadModuleWithLarger2GBTensor()
        {
            // Create a sequential with a parameter slightly larger than 2GB
            var seq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false)));

            var tempFile = Path.GetTempFileName();
            try {
                // Save to a temp file
                using (var fs = File.OpenWrite(tempFile))
                    seq.save(fs);

                // Create a new Sequential, and make sure it ends up equal
                var copySeq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false)));

                // Read it in
                using (var fs = File.OpenRead(tempFile))
                    copySeq.load(fs);

                // Compare results
                Assert.Equal(copySeq.parameters().First().npstr(), seq.parameters().First().npstr());
            } finally {
                File.Delete(tempFile);
            }
        }

        [Fact(Skip = "Very heavy on the compute")]
        public void TestSaveAndLoadModuleWithLarger2GBTensorCUDA()
        {
            if (torch.cuda.is_available()) {
                // Create a sequential with a parameter slightly larger than 2GB
                var seq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))).cuda();

                var tempFile = Path.GetTempFileName();
                try {
                    // Save to a temp file
                    using (var fs = File.OpenWrite(tempFile))
                        seq.save(fs);

                    // Create a new Sequential, and make sure it ends up equal
                    var copySeq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))).cuda();

                    // Read it in
                    using (var fs = File.OpenRead(tempFile))
                        copySeq.load(fs);

                    // Compare results
                    Assert.Equal(copySeq.parameters().First().npstr(), seq.parameters().First().npstr());
                } finally {
                    File.Delete(tempFile);
                }
            }
        }
    }
}