diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index f5527428a..222e467bc 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -27,6 +27,7 @@ __Bug Fixes__:
 #1174 : Loading CUDA tensor from stream threw an error
#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.
#1148 : Calling `Module.to()` shouldn't be differentiable
+#1180 : Module.to(ScalarType) has restrictions in PyTorch which aren't enforced in TorchSharp.
 
 ## NuGet Version 0.101.2
 
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index 031e415ac..0f174b763 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -150,6 +150,9 @@ protected virtual void Dispose(bool disposing)
         /// <param name="dtype">The target element type.</param>
         protected internal virtual Module _to(Device device, ScalarType dtype)
         {
+            if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+                throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
+
             if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
 
             if (device.type == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available.");
@@ -198,6 +201,9 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
         /// </summary>
         protected internal virtual Module _to(ScalarType dtype)
         {
+            if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+                throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
+
             THSNN_Module_to_dtype(handle, (sbyte)dtype);
             CheckForErrors();
 
@@ -238,11 +244,14 @@ private void _toEpilog(ScalarType? dtype, Device device)
                 // Store the requires_grad flag ahead, since we dispose the parameter after moving
                 bool requiresGrad = param.requires_grad;
                 Parameter p;
+                ScalarType paramType =
+                    dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
                 // When moving the parameter, we don't want the autograd to track this movement on the graph.
                 // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
                 // disable grad we would need to call .detach() on the moved tensor.
                 using (var d = torch.no_grad())
-                    p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+                    p = new Parameter(param.to(paramType, device ?? param.device, disposeAfter: true), requiresGrad);
                 ConditionallyRegisterParameter(name, p);
 
                 // If this parameter is a field, set it
@@ -253,8 +262,11 @@ private void _toEpilog(ScalarType? dtype, Device device)
             foreach (var (name, buffer) in named_buffers(false).ToList()) {
                 if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
 
+                ScalarType bufferType =
+                    dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
+
                 // Buffers don't get grads so we don't need to detach them afterwards
-                var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
+                var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true);
                 ConditionallyRegisterBuffer(name, t);
 
                 if (fieldsByComponentName.TryGetValue(name, out var field))
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index c07ec02e0..22b226231 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -3014,6 +3014,20 @@ public void TestDatatypeTo()
             if (lin2.bias is not null) Assert.Equal(torch.float64, lin2.bias!.dtype);
         }
 
+        [Fact]
+        public void TestDatatypeToFail()
+        {
+            var mod = new TestModule3();
+            mod.ValidateDtype(torch.float32);
+            Assert.Multiple(
+                () => Assert.Throws<ArgumentException>(() => mod.to(torch.uint8)),
+                () => Assert.Throws<ArgumentException>(() => mod.to(torch.int8)),
+                () => Assert.Throws<ArgumentException>(() => mod.to(torch.int16)),
+                () => Assert.Throws<ArgumentException>(() => mod.to(torch.int32)),
+                () => Assert.Throws<ArgumentException>(() => mod.to(torch.int64))
+            );
+        }
+
         [Fact]
         public void TestDeviceTo()
         {
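
For context, a minimal usage sketch of the behavior this patch enforces; it is not part of the diff, and the `torch.nn.Linear` module and the specific dtypes are only illustrative stand-ins:

```csharp
using System;
using TorchSharp;

// Illustrative sketch: Module.to(ScalarType) now mirrors PyTorch's restriction
// that only floating-point or complex target dtypes are accepted.
var lin = torch.nn.Linear(8, 4);

// Converting between floating-point (or complex) dtypes still works.
lin.to(torch.float64);

// Integral dtypes are now rejected with an ArgumentException.
try {
    lin.to(torch.int32);
} catch (ArgumentException e) {
    // "nn.Module.to only accepts floating point or complex types, ..."
    Console.WriteLine(e.Message);
}
```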