diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index f5527428a..222e467bc 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -27,6 +27,7 @@ __Bug Fixes__:
#1174 : Loading CUDA tensor from stream threw an error
#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.
#1148 : Calling `Module.to()` shouldn't be differentiable
+#1180 : `Module.to(ScalarType)` is restricted to floating point and complex types in PyTorch, but TorchSharp did not enforce that restriction.
## NuGet Version 0.101.2
diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs
index 031e415ac..0f174b763 100644
--- a/src/TorchSharp/NN/Module.cs
+++ b/src/TorchSharp/NN/Module.cs
@@ -150,6 +150,10 @@ protected virtual void Dispose(bool disposing)
/// The target element type.
protected internal virtual Module _to(Device device, ScalarType dtype)
{
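+ // Mirrors PyTorch's nn.Module.to(), which only accepts a floating point or complex target dtype.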
+ if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+ throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype}");
+
if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
if (device.type == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available.");
@@ -198,6 +202,10 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
///
protected internal virtual Module _to(ScalarType dtype)
{
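+ // Same restriction as the Device/dtype overload above: only floating point or complex target dtypes are allowed, as in PyTorch.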
+ if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
+ throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype}");
+
THSNN_Module_to_dtype(handle, (sbyte)dtype);
CheckForErrors();
@@ -238,11 +246,16 @@ private void _toEpilog(ScalarType? dtype, Device device)
// Store the requires_grad flag ahead, since we dispose the parameter after moving
bool requiresGrad = param.requires_grad;
Parameter p;
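+ // Like PyTorch's Module._apply, only floating point and complex parameters take on the requested dtype;
+ // integral parameters keep their element type but are still moved to the target device.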
+ ScalarType paramType =
+ dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
+
// When moving the parameter, we don't want the autograd to track this movement on the graph.
// In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
// disable grad we would need to call .detach() on the moved tensor.
using (var d = torch.no_grad())
- p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+ p = new Parameter(param.to(paramType, device ?? param.device, disposeAfter: true), requiresGrad);
ConditionallyRegisterParameter(name, p);
// If this parameter is a field, set it
@@ -253,8 +266,12 @@ private void _toEpilog(ScalarType? dtype, Device device)
foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
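+ // Buffers follow the same rule: only floating point and complex buffers are converted to the new dtype.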
+ ScalarType bufferType =
+ dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
+
// Buffers don't get grads so we don't need to detach them afterwards
- var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
+ var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true);
ConditionallyRegisterBuffer(name, t);
if (fieldsByComponentName.TryGetValue(name, out var field))
diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs
index c07ec02e0..22b226231 100644
--- a/test/TorchSharpTest/NN.cs
+++ b/test/TorchSharpTest/NN.cs
@@ -3014,6 +3014,21 @@ public void TestDatatypeTo()
if (lin2.bias is not null) Assert.Equal(torch.float64, lin2.bias!.dtype);
}
+ [Fact]
+ public void TestDatatypeToFail()
+ {
+ var mod = new TestModule3();
+ mod.ValidateDtype(torch.float32);
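+ // Integral target dtypes should be rejected, mirroring PyTorch's nn.Module.to.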
+ Assert.Multiple(
+ () => Assert.Throws<ArgumentException>(() => mod.to(torch.uint8)),
+ () => Assert.Throws<ArgumentException>(() => mod.to(torch.int8)),
+ () => Assert.Throws<ArgumentException>(() => mod.to(torch.int16)),
+ () => Assert.Throws<ArgumentException>(() => mod.to(torch.int32)),
+ () => Assert.Throws<ArgumentException>(() => mod.to(torch.int64))
+ );
+ }
+
[Fact]
public void TestDeviceTo()
{