1 change: 1 addition & 0 deletions RELEASENOTES.md
@@ -27,6 +27,7 @@ __Bug Fixes__:
#1174 : Loading a CUDA tensor from a stream threw an error<br/>
#1179 : Calling `Module.to()` on `ParameterList` and `ParameterDict` modules didn't move the parameters stored in the field.<br/>
#1148 : Calling `Module.to()` shouldn't be differentiable<br/>
#1180 : `Module.to(ScalarType)` is restricted to floating point and complex types in PyTorch, but TorchSharp didn't enforce the same restriction.<br/>

## NuGet Version 0.101.2

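For reference, a minimal sketch (not part of this diff) of how the #1180 fix surfaces to callers, assuming the usual TorchSharp static API (`torch.nn.Linear` and the `torch.ScalarType` shorthands such as `float16` and `int32`) is in scope:

```csharp
// Sketch only: the restriction added by this PR, as seen from user code.
using System;
using TorchSharp;
using static TorchSharp.torch;

var lin = nn.Linear(10, 5);

lin.to(float16);      // allowed: floating point target, weight and bias are converted

try {
    lin.to(int32);    // rejected: integer targets now throw, matching PyTorch
} catch (ArgumentException e) {
    Console.WriteLine(e.Message);   // the message comes from the new check in Module._to
}
```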
16 changes: 14 additions & 2 deletions src/TorchSharp/NN/Module.cs
@@ -150,6 +150,9 @@ protected virtual void Dispose(bool disposing)
/// <param name="dtype">The target element type.</param>
protected internal virtual Module _to(Device device, ScalarType dtype)
{
if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype}");

if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };

if (device.type == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available.");
@@ -198,6 +201,9 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
/// <returns></returns>
protected internal virtual Module _to(ScalarType dtype)
{
if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype}");

THSNN_Module_to_dtype(handle, (sbyte)dtype);
CheckForErrors();

@@ -238,11 +244,14 @@ private void _toEpilog(ScalarType? dtype, Device device)
// Store the requires_grad flag ahead, since we dispose the parameter after moving
bool requiresGrad = param.requires_grad;
Parameter p;
ScalarType paramType =
dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;

// When moving the parameter, we don't want the autograd to track this movement on the graph.
// In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
// disable grad we would need to call .detach() on the moved tensor.
using (var d = torch.no_grad())
p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
p = new Parameter(param.to(paramType, device ?? param.device, disposeAfter: true), requiresGrad);
ConditionallyRegisterParameter(name, p);

// If this parameter is a field, set it
@@ -253,8 +262,11 @@ private void _toEpilog(ScalarType? dtype, Device device)
foreach (var (name, buffer) in named_buffers(false).ToList()) {
if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;

ScalarType bufferType =
dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;

// Buffers don't get grads so we don't need to detach them afterwards
var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true);
ConditionallyRegisterBuffer(name, t);

if (fieldsByComponentName.TryGetValue(name, out var field))
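Taken together, the `_toEpilog` changes above apply one rule: a requested dtype is only applied to floating point or complex parameters and buffers, everything else keeps its current type, and the actual move runs under `torch.no_grad()` so the new `Parameter` is a leaf without an extra `.detach()`. A hedged sketch of that rule as a standalone helper (`PickTargetType` is a hypothetical name, not part of TorchSharp, and it assumes the `IsFloatingPoint()`/`IsComplex()` extensions used in the diff are accessible):

```csharp
// Sketch only: the per-tensor dtype selection used by _toEpilog for parameters and buffers.
using System;
using TorchSharp;
using static TorchSharp.torch;

static ScalarType PickTargetType(ScalarType requested, ScalarType current)
{
    // Only floating point or complex tensors follow the requested dtype;
    // integer and bool parameters/buffers keep their existing type.
    return current.IsFloatingPoint() || current.IsComplex() ? requested : current;
}

Console.WriteLine(PickTargetType(float16, float32));  // Float16: weights, biases, running stats follow the request
Console.WriteLine(PickTargetType(float16, int64));    // Int64: an integer counter buffer keeps its type
```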
14 changes: 14 additions & 0 deletions test/TorchSharpTest/NN.cs
@@ -3014,6 +3014,20 @@ public void TestDatatypeTo()
if (lin2.bias is not null) Assert.Equal(torch.float64, lin2.bias!.dtype);
}

[Fact]
public void TestDatatypeToFail()
{
var mod = new TestModule3();
mod.ValidateDtype(torch.float32);
Assert.Multiple(
() => Assert.Throws<ArgumentException>(() => mod.to(torch.uint8)),
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int8)),
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int16)),
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int32)),
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int64))
);
}

[Fact]
public void TestDeviceTo()
{
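As a hedged companion to `TestDatatypeToFail` (not part of the PR's tests), here is the passing path, assuming `nn.BatchNorm1d` carries its usual int64 `num_batches_tracked` buffer: a floating point target is accepted, floating point parameters and buffers are converted, and the integer counter keeps its dtype per the `_toEpilog` rule above.

```csharp
// Sketch only: the non-throwing side of the new dtype handling.
using System;
using TorchSharp;
using static TorchSharp.torch;

var bn = nn.BatchNorm1d(10);
bn.to(float64);   // allowed: float64 is a floating point type

foreach (var (name, p) in bn.named_parameters())
    Console.WriteLine($"{name}: {p.dtype}");   // weight, bias -> Float64

foreach (var (name, b) in bn.named_buffers())
    Console.WriteLine($"{name}: {b.dtype}");   // running stats -> Float64; num_batches_tracked stays Int64
```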