diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 75fb28142..5b3cdccd0 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -2,6 +2,12 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the top. +## NuGet Version 0.101.7 + +__API Changes__: + +#1219: Added support for loading and saving tensors that are >2GB. + ## NuGet Version 0.101.6 __API Changes__: diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 4ca8a3258..69faf6108 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -509,7 +509,7 @@ public virtual (IList missing_keys, IList unexpected_keyes) load foreach (var key in source.Keys) { if (skip.Contains(key)) continue; if (destination.ContainsKey(key)) { - destination[key].bytes = source[key].bytes; + destination[key].copy_(source[key]); } } diff --git a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs index 9bc1c562f..87e26b58c 100644 --- a/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs +++ b/src/TorchSharp/Tensor/Factories/Tensor.Factories.cs @@ -521,11 +521,8 @@ public static Tensor Load(System.IO.BinaryReader reader) throw new NotImplementedException("Loading tensors larger than 2GB"); var tensor = torch.empty(loadedShape, dtype: type); - - var bytes = reader.ReadBytes((int)(totalSize * type.ElementSize())); - - tensor.bytes = bytes; - + tensor.ReadBytesFromStream(reader.BaseStream); + return tensor; } diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index b8b457063..a2c4f7684 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -4,6 +4,7 @@ using System.ComponentModel; using System.Diagnostics.Contracts; using System.Globalization; +using System.IO; using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -404,6 +405,78 @@ internal void ValidateType(Type dotnetType) } } + + /// + /// Writes the bytes of the tensor to a stream. Useful for when tensors are >2GB. + /// + /// Stream to write the bytes to + /// The buffer size to use when writing to the stream + public void WriteBytesToStream(Stream stream, int bufferSize = 1024) + { + // Validate, but passing 0 as the total size, since we don't need to validate the size + _validate(0); + + long totalSize = NumberOfElements * ElementSize; + + unsafe { + var ptr = NativeMethods.THSTensor_data(handle); + if (ptr == IntPtr.Zero) { CheckForErrors(); } + + // NOTE: there is no safety here in this loop. + // Read in the buffer N bytes at a time, and write them out + byte[] buffer = new byte[bufferSize]; + while (totalSize > 0) { + // Read in the current buffer size + int curBufferSize = (int)Math.Min(totalSize, bufferSize); + var span = new Span((void*)ptr, curBufferSize); + span.CopyTo(buffer); + + // Write it out + stream.Write(buffer, 0, curBufferSize); + + // Increment our pointer and decrease the total size of elements we have to write + ptr += curBufferSize; + totalSize -= curBufferSize; + } + } + } + + /// + /// Reads the bytes of the tensor from a stream. + /// + /// Stream to read the bytes from + /// The buffer size to use when reading from the stream + public void ReadBytesFromStream(Stream stream, int bufferSize = 1024) + { + long totalSize = NumberOfElements * ElementSize; + + // Validate that this tensor matches the conditions for reading the bytes - pass 0 as total size + // since we don't need to check that condition. + _validate(0); + + unsafe { + var ptr = NativeMethods.THSTensor_data(handle); + if (ptr == IntPtr.Zero) { CheckForErrors(); } + + // NOTE: there is no safety here in this loop. + // Read in the buffer N bytes at a time, and write them out + byte[] buffer = new byte[bufferSize]; + while (totalSize > 0) { + // Read in the current buffer size + int curBufferSize = (int)Math.Min(totalSize, bufferSize); + stream.Read(buffer, 0, curBufferSize); + + // Copy the contents over to the span + var span = new Span((void*)ptr, curBufferSize); + buffer.AsSpan(0, curBufferSize).CopyTo(span); + + // Increment our pointer and decrease the total size of elements we have to write + ptr += curBufferSize; + totalSize -= curBufferSize; + } + } + } + /// /// Get or set the contents of a tensor as raw bytes. /// diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs index 09d79028f..846be615b 100644 --- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs +++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs @@ -387,13 +387,9 @@ public static void Save(this Tensor tensor, System.IO.BinaryWriter writer) // Then, the shape. writer.Encode(tensor.shape.Length); // 4 bytes foreach (var s in tensor.shape) writer.Encode(s); // n * 8 bytes - // Then, the data -#if NETSTANDARD2_0_OR_GREATER - // TODO: NETSTANDARD2_0_OR_GREATER Try to optimize to avoid the allocation - writer.Write(tensor.bytes.ToArray()); // ElementSize * NumberOfElements -#else - writer.Write(tensor.bytes); // ElementSize * NumberOfElements -#endif // NETSTANDARD2_0_OR_GREATER + + // Then, the data + tensor.WriteBytesToStream(writer.BaseStream); if (copied) tensor.Dispose(); } @@ -441,37 +437,28 @@ public static void Load(this Tensor tensor, System.IO.BinaryReader reader, bool // Then, the shape var shLen = reader.Decode(); long[] loadedShape = new long[shLen]; - - long totalSize = 1; for (int i = 0; i < shLen; ++i) { loadedShape[i] = reader.Decode(); - totalSize *= loadedShape[i]; } if (!skip && !loadedShape.SequenceEqual(tensor.shape)) // We only care about this if the bytes will be written to the tensor. throw new ArgumentException("Mismatched tensor shape while loading. Make sure that the model you are loading into is exactly the same as the origin."); - // - // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB - // - if (totalSize > int.MaxValue) - throw new NotImplementedException("Loading tensors larger than 2GB"); - - // This needs to be done even if the tensor is skipped, since we have to advance the input stream. - var bytes = reader.ReadBytes((int)(totalSize * tensor.ElementSize)); - if (!skip) { + using var d = torch.no_grad(); // so that we can perform operations on leaf tensors + var device = tensor.device; if (device.type != DeviceType.CPU) { - using var copy = tensor.to(CPU); - copy.bytes = bytes; - using var moved = copy.to(device); - tensor.set_(moved); - } - else { - tensor.bytes = bytes; - } + using var temp = torch.zeros_like(tensor, device: torch.CPU); + temp.ReadBytesFromStream(reader.BaseStream); + + tensor.copy_(temp); + } else { + tensor.ReadBytesFromStream(reader.BaseStream); + } + } else { + reader.BaseStream.Seek(tensor.NumberOfElements * tensor.ElementSize, System.IO.SeekOrigin.Current); } } @@ -510,18 +497,10 @@ public static void Load(ref Tensor? tensor, System.IO.BinaryReader reader, bool var shLen = reader.Decode(); long[] loadedShape = new long[shLen]; - long totalSize = 1; for (int i = 0; i < shLen; ++i) { loadedShape[i] = reader.Decode(); - totalSize *= loadedShape[i]; } - // - // TODO: Fix this so that you can read large tensors. Right now, they are limited to 2GB - // - if (totalSize > int.MaxValue) - throw new NotImplementedException("Loading tensors larger than 2GB"); - if (tensor is null) { // If the tensor doesn't exist, initialize by zeros unless // it's going to be loaded from the stream. @@ -533,15 +512,19 @@ public static void Load(ref Tensor? tensor, System.IO.BinaryReader reader, bool throw new ArgumentException("Mismatched tensor shape while loading. Make sure that the model you are loading into is exactly the same as the origin."); } - // This needs to be done even if the tensor is skipped, since we have to advance the input stream. - var bytes = reader.ReadBytes((int)(totalSize * tensor.ElementSize)); - if (!skip) { + using var d = torch.no_grad(); // so that we can perform operations on leaf tensors + var device = tensor.device; - if (device.type != DeviceType.CPU) - tensor = tensor.to(CPU, disposeAfter: true); - tensor.bytes = bytes; - tensor = tensor.to(device, disposeAfter: true); + if (device.type != DeviceType.CPU) { + using var temp = torch.zeros_like(tensor, device: torch.CPU); + temp.ReadBytesFromStream(reader.BaseStream); + tensor.copy_(temp); + } else { + tensor.ReadBytesFromStream(reader.BaseStream); + } + } else { + reader.BaseStream.Seek(tensor.NumberOfElements * tensor.ElementSize, System.IO.SeekOrigin.Current); } } diff --git a/src/TorchVision/File.cs b/src/TorchVision/File.cs index 73b6ab1d7..ea0c48cff 100644 --- a/src/TorchVision/File.cs +++ b/src/TorchVision/File.cs @@ -48,7 +48,8 @@ public static async Task read_file_async(string filename) /// One dimensional uint8 Tensor. public static void write_file(string filename, Tensor data) { - File.WriteAllBytes(filename, data.bytes.ToArray()); + using var stream = File.OpenWrite(filename); + data.WriteBytesToStream(stream); } /// diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs index 34ede2b58..92d3009e4 100644 --- a/test/TorchSharpTest/TestTorchTensor.cs +++ b/test/TorchSharpTest/TestTorchTensor.cs @@ -8269,5 +8269,111 @@ Tensor nbins_ratio(int seed, int size) } } } + + + [Fact(Skip = "Very heavy on the compute")] + public void TestSaveAndLoadLarger2GBTensor() + { + var tensor = torch.rand((long)int.MaxValue + 128, device: torch.CPU); + + var tempFile = Path.GetTempFileName(); + try { + // Save to memory + using (var fs = File.OpenWrite(tempFile)) + tensor.Save(fs); + + // Create a new copy of zeros + var copyTensor = torch.zeros_like(tensor, device: torch.CPU); + + // Read it in + using (var fs = File.OpenRead(tempFile)) + copyTensor.Load(new BinaryReader(fs)); + + Assert.Equal(tensor.npstr(), copyTensor.npstr()); + } finally { + File.Delete(tempFile); + } + } + + [Fact(Skip = "Very heavy on the compute")] + public void TestSaveAndLoadLarger2GBTensorCUDA() + { + if (torch.cuda.is_available()) { + var tensor = torch.rand((long)int.MaxValue + 128, device: torch.CUDA); + + var tempFile = Path.GetTempFileName(); + try { + // Save to memory + using (var fs = File.OpenWrite(tempFile)) + tensor.Save(fs); + + // Create a new copy of zeros + var copyTensor = torch.zeros_like(tensor, device: torch.CUDA); + + // Read it in + using (var fs = File.OpenRead(tempFile)) + copyTensor.Load(new BinaryReader(fs)); + + Assert.Equal(tensor.npstr(), copyTensor.npstr()); + } finally { + File.Delete(tempFile); + } + } + } + + + [Fact(Skip = "Very heavy on the compute")] + public void TestSaveAndLoadModuleWithLarger2GBTensor() + { + // Create a sequential with a parameter slightly larger than 2GB + var seq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))); + + var tempFile = Path.GetTempFileName(); + try { + // Save to memory + using (var fs = File.OpenWrite(tempFile)) + seq.save(fs); + + // Create a new sequence, and make sure it is equal + var copySeq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))); + + // Read it in + using (var fs = File.OpenRead(tempFile)) + copySeq.load(fs); + + // Compare results + Assert.Equal(copySeq.parameters().First().npstr(), seq.parameters().First().npstr()); + } finally { + File.Delete(tempFile); + } + } + + [Fact(Skip = "Very heavy on the compute")] + public void TestSaveAndLoadModuleWithLarger2GBTensorCUDA() + { + if (torch.cuda.is_available()) { + // Create a sequential with a parameter slightly larger than 2GB + var seq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))).cuda(); + + var tempFile = Path.GetTempFileName(); + try { + // Save to memory + using (var fs = File.OpenWrite(tempFile)) + seq.save(fs); + + // Create a new sequence, and make sure it is equal + var copySeq = nn.Sequential(("lin1", torch.nn.Linear(int.MaxValue / 2048, 2049, false))).cuda(); + + // Read it in + using (var fs = File.OpenRead(tempFile)) + copySeq.load(fs); + + // Compare results + Assert.Equal(copySeq.parameters().First().npstr(), seq.parameters().First().npstr()); + } finally { + File.Delete(tempFile); + } + } + } } } \ No newline at end of file