Commit 089511d

Merge pull request #1182 from NiklasGustafsson/bugs

Validate the target data type when converting a Module using `to()`

2 parents 6efd3e9 + 49746af

File tree

3 files changed: +29 −2 lines changed

- RELEASENOTES.md
- src/TorchSharp/NN/Module.cs
- test/TorchSharpTest/NN.cs

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@ __Bug Fixes__:
 #1174 : Loading CUDA tensor from stream threw an error<br/>
 #1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.<br/>
 #1148 : Calling `Module.to()` shouldn't be differentiable<br/>
+#1180 : Module.to(ScalarType) has restrictions in PyTorch which aren't restricted in TorchSharp.<br/>

 ## NuGet Version 0.101.2

src/TorchSharp/NN/Module.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ protected virtual void Dispose(bool disposing)
150150
/// <param name="dtype">The target element type.</param>
151151
protected internal virtual Module _to(Device device, ScalarType dtype)
152152
{
153+
if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
154+
throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
155+
153156
if (device.type != DeviceType.CUDA) { device = new Device(device.type, -1); };
154157

155158
if (device.type == DeviceType.CUDA && !torch.cuda.is_available()) throw new InvalidOperationException("CUDA is not available.");
@@ -198,6 +201,9 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -
198201
/// <returns></returns>
199202
protected internal virtual Module _to(ScalarType dtype)
200203
{
204+
if (!dtype.IsFloatingPoint() && !dtype.IsComplex())
205+
throw new ArgumentException($"nn.Module.to only accepts floating point or complex types, but got desired dtype={dtype.ToString()}");
206+
201207
THSNN_Module_to_dtype(handle, (sbyte)dtype);
202208
CheckForErrors();
203209

@@ -238,11 +244,14 @@ private void _toEpilog(ScalarType? dtype, Device device)
238244
// Store the requires_grad flag ahead, since we dispose the parameter after moving
239245
bool requiresGrad = param.requires_grad;
240246
Parameter p;
247+
ScalarType paramType =
248+
dtype != null && (param.dtype.IsFloatingPoint() || param.dtype.IsComplex()) ? dtype.Value : param.dtype;
249+
241250
// When moving the parameter, we don't want the autograd to track this movement on the graph.
242251
// In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
243252
// disable grad we would need to call .detach() on the moved tensor.
244253
using (var d = torch.no_grad())
245-
p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
254+
p = new Parameter(param.to(paramType, device ?? param.device, disposeAfter: true), requiresGrad);
246255
ConditionallyRegisterParameter(name, p);
247256

248257
// If this parameter is a field, set it
@@ -253,8 +262,11 @@ private void _toEpilog(ScalarType? dtype, Device device)
253262
foreach (var (name, buffer) in named_buffers(false).ToList()) {
254263
if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
255264

265+
ScalarType bufferType =
266+
dtype != null && (buffer.dtype.IsFloatingPoint() || buffer.dtype.IsComplex()) ? dtype.Value : buffer.dtype;
267+
256268
// Buffers don't get grads so we don't need to detach them afterwards
257-
var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
269+
var t = buffer.to(bufferType, device ?? buffer.device, disposeAfter: true);
258270
ConditionallyRegisterBuffer(name, t);
259271

260272
if (fieldsByComponentName.TryGetValue(name, out var field))

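Taken together, the hunks above give `Module.to(ScalarType)` the same contract as PyTorch's `nn.Module.to()`: only floating-point or complex target types are accepted, and integral parameters and buffers keep their original dtype when the rest of the module is converted. The following is a minimal sketch of the resulting behavior, not code from this commit; the `Demo` module and its `weight`/`step` members are hypothetical names, and it relies on TorchSharp's public register_parameter/register_buffer API.

// Sketch only: `Demo`, `weight`, and `step` are illustrative names.
using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

class Demo : nn.Module
{
    public Demo() : base(nameof(Demo))
    {
        // One floating-point parameter and one integral buffer.
        register_parameter("weight", new Parameter(randn(2, 2)));  // float32
        register_buffer("step", zeros(1, dtype: int64));           // int64
    }
}

class Program
{
    static void Main()
    {
        var m = new Demo();

        // Floating-point targets still work; the int64 buffer keeps its
        // dtype because of the paramType/bufferType logic in _toEpilog.
        m.to(float64);

        // Integral targets are now rejected up front.
        try {
            m.to(int32);
        } catch (ArgumentException e) {
            // "nn.Module.to only accepts floating point or complex types, ..."
            Console.WriteLine(e.Message);
        }
    }
}

This mirrors PyTorch, where passing a dtype to `nn.Module.to()` converts only the floating-point/complex tensors in the module and raises an error for integral targets.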
test/TorchSharpTest/NN.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3014,6 +3014,20 @@ public void TestDatatypeTo()
30143014
if (lin2.bias is not null) Assert.Equal(torch.float64, lin2.bias!.dtype);
30153015
}
30163016

3017+
[Fact]
3018+
public void TestDatatypeToFail()
3019+
{
3020+
var mod = new TestModule3();
3021+
mod.ValidateDtype(torch.float32);
3022+
Assert.Multiple(
3023+
() => Assert.Throws<ArgumentException>(() => mod.to(torch.uint8)),
3024+
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int8)),
3025+
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int16)),
3026+
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int32)),
3027+
() => Assert.Throws<ArgumentException>(() => mod.to(torch.int64))
3028+
);
3029+
}
3030+
30173031
[Fact]
30183032
public void TestDeviceTo()
30193033
{
