Commit 6efd3e9

Merge pull request #1181 from shaltielshmid/fix-module-to-2
Fixed `Module.to()` bugs with `ParameterList` and `ParameterDict`, and with autograd tracking movements between CPU & GPU
2 parents c48ca33 + c0c598f commit 6efd3e9
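
For orientation before the diffs: a minimal sketch of the symptom behind #1148, the autograd half of this merge. The code is illustrative, not taken from the PR, and assumes a CUDA-enabled TorchSharp build.

using TorchSharp;
using static TorchSharp.torch;

var lin = nn.Linear(4, 4);

// Before this commit, a CPU <-> GPU move went through autograd-tracked
// Tensor.to() calls, so a moved parameter could end up a non-leaf tensor.
// Gradients only accumulate on leaves, so training after the move could
// silently break.
lin.to(DeviceType.CUDA);

// After the fix, the move runs under no_grad: lin.weight remains a leaf
// Parameter and its requires_grad flag is preserved.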

File tree

6 files changed: +284, -160 lines

RELEASENOTES.md
Lines changed: 7 additions & 5 deletions

@@ -20,11 +20,13 @@ All distribution classes now implement IDisposable.<br/>

 __Bug Fixes__:

-#1154 : `mu_product` was not initialized in `NAdam` optimizer
-#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error
-#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.
-#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.
-#1174 : Loading CUDA tensor from stream threw an error
+#1154 : `mu_product` was not initialized in `NAdam` optimizer<br/>
+#1170 : Calling `torch.nn.rnn.utils.pad_packed_sequence` with a CUDA tensor and unsorted_indices threw an error<br/>
+#1172 : `optim.LoadStateDict` from an existing `StateDictionary` updated to make sure to copy value and to the right device.<br/>
+#1176 : When specific `Optimizers` load in a conditional tensor, made sure to copy to the right device.<br/>
+#1174 : Loading CUDA tensor from stream threw an error<br/>
+#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.<br/>
+#1148 : Calling `Module.to()` shouldn't be differentiable<br/>

 ## NuGet Version 0.101.2

src/TorchSharp/NN/Module.cs
Lines changed: 45 additions & 155 deletions

@@ -163,66 +163,6 @@ protected internal virtual Module _to(Device device, ScalarType dtype)
     return this;
 }

-protected void _toEpilog(Device device, ScalarType dtype)
-{
-    foreach (var (_, sm) in named_children()) sm._to(device, dtype);
-
-    var alreadyHandled = new HashSet<IntPtr>();
-
-    foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
-        var fieldName = field.ComponentName();
-        var value = field.GetValue(this);
-
-        switch (value) {
-            // This order in which these cases are arranged is significant.
-            case Parameter param when dtype == param.dtype && device.type == param.device_type && device.index == param.device_index:
-                alreadyHandled.Add(param.handle);
-                continue;
-
-            case Parameter param: {
-                var t = param.to(dtype, device);
-                t.retain_grad();
-                var p = new Parameter(t, param.requires_grad);
-                field.SetValue(this, p);
-                ConditionallyRegisterParameter(fieldName, p);
-                alreadyHandled.Add(p.handle);
-                break;
-            }
-
-            case Tensor tensor when (device.type != tensor.device_type || device.index != tensor.device_index): {
-                var t = tensor.to(dtype, device);
-                field.SetValue(this, t);
-                ConditionallyRegisterBuffer(fieldName, t);
-                alreadyHandled.Add(t.handle);
-                break;
-            }
-
-            case Tensor tensor:
-                alreadyHandled.Add(tensor.handle);
-                break;
-        }
-    }
-
-    foreach (var (name, param) in named_parameters(false).ToList()) {
-        if (alreadyHandled.Contains(param.handle)) continue;
-        var t = param.to(dtype, device);
-        ConditionallyRegisterParameter(name, t);
-    }
-
-    foreach (var (name, buffer) in named_buffers(false).ToList()) {
-        if (alreadyHandled.Contains(buffer.handle)) continue;
-        var t = buffer.to(dtype, device);
-        ConditionallyRegisterBuffer(name, t);
-    }
-
-    _deviceType = device.type;
-    _deviceIndex = device.index;
-
-    Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
-}
-
 /// <summary>
 /// Moves the parameters and buffers.
 /// </summary>

@@ -249,63 +189,6 @@ protected internal virtual Module _to(DeviceType deviceType, int deviceIndex = -1)
     return this;
 }

-protected void _toEpilog(DeviceType deviceType, int deviceIndex)
-{
-    foreach (var (_, sm) in named_children()) sm._to(deviceType, deviceIndex);
-
-    var alreadyHandled = new HashSet<IntPtr>();
-
-    foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
-        var fieldName = field.ComponentName();
-        var value = field.GetValue(this);
-
-        switch (value) {
-            // This order in which these cases are arranged is significant.
-            case Parameter param when deviceType == param.device_type && deviceIndex == param.device_index:
-                alreadyHandled.Add(param.handle);
-                continue;
-
-            case Parameter param: {
-                var t = param.to(deviceType, deviceIndex);
-                t.retain_grad();
-                var p = new Parameter(t, param.requires_grad);
-                field.SetValue(this, p);
-                ConditionallyRegisterParameter(fieldName, p);
-                alreadyHandled.Add(p.handle);
-                break;
-            }
-
-            case Tensor tensor when (deviceType != tensor.device_type || deviceIndex != tensor.device_index): {
-                var t = tensor.to(deviceType, deviceIndex);
-                field.SetValue(this, t);
-                ConditionallyRegisterBuffer(fieldName, t);
-                alreadyHandled.Add(t.handle);
-                break;
-            }
-
-            case Tensor tensor:
-                alreadyHandled.Add(tensor.handle);
-                break;
-        }
-    }
-
-    foreach (var (name, param) in named_parameters(false).ToList()) {
-        if (alreadyHandled.Contains(param.handle)) continue;
-        var t = param.to(deviceType, deviceIndex);
-        ConditionallyRegisterParameter(name, t);
-    }
-
-    foreach (var (name, buffer) in named_buffers(false).ToList()) {
-        if (alreadyHandled.Contains(buffer.handle)) continue;
-        var t = buffer.to(deviceType, deviceIndex);
-        ConditionallyRegisterBuffer(name, t);
-    }
-
-    _deviceType = deviceType;
-    _deviceIndex = deviceIndex;
-}
-
 private DeviceType _deviceType = DeviceType.CPU;
 private int _deviceIndex = -1;

@@ -325,55 +208,62 @@ protected internal virtual Module _to(ScalarType dtype)

 protected void _toEpilog(ScalarType dtype)
 {
-    foreach (var (_, sm) in named_children()) sm._to(dtype);
-
-    var alreadyHandled = new HashSet<IntPtr>();
-
-    foreach (var field in GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)) {
-
-        var fieldName = field.ComponentName();
-        var value = field.GetValue(this);
-
-        switch (value) {
-            // This order in which these cases are arranged is significant.
-            case Parameter param when dtype == param.dtype:
-                alreadyHandled.Add(param.handle);
-                continue;
-
-            case Parameter param: {
-                var t = param.to(dtype);
-                t.retain_grad();
-                var p = new Parameter(t, param.requires_grad);
-                field.SetValue(this, p);
-                ConditionallyRegisterParameter(fieldName, p);
-                alreadyHandled.Add(p.handle);
-                break;
-            }
-
-            case Tensor tensor when dtype == tensor.dtype:
-                alreadyHandled.Add(tensor.handle);
-                continue;
-
-            case Tensor tensor: {
-                var t = tensor.to(dtype);
-                field.SetValue(this, t);
-                ConditionallyRegisterBuffer(fieldName, t);
-                alreadyHandled.Add(t.handle);
-                break;
-            }
-        }
-    }
-
-    foreach (var (name, param) in named_parameters(false).ToList()) {
-        if (alreadyHandled.Contains(param.handle)) continue;
-        var t = param.to(dtype);
-        ConditionallyRegisterParameter(name, t);
+    _toEpilog(dtype, null);
+}
+
+protected void _toEpilog(Device device, ScalarType dtype)
+{
+    _toEpilog(dtype, device);
+}
+
+protected void _toEpilog(DeviceType deviceType, int deviceIndex)
+{
+    _toEpilog(null, new Device(deviceType, deviceIndex));
+}
+
+private void _toEpilog(ScalarType? dtype, Device device)
+{
+    foreach (var (_, sm) in named_children()) {
+        if (device is null) sm._to(dtype.Value);
+        else if (dtype is null) sm._to(device.type, device.index);
+        else sm._to(device, dtype.Value);
+    }
+
+    var fieldsByComponentName = GetType().GetFields(BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
+        .ToDictionary(field => field.ComponentName());
+
+    foreach (var (name, param) in named_parameters(false).ToList()) {
+        if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device)) continue;
+
+        // Store the requires_grad flag ahead, since we dispose the parameter after moving
+        bool requiresGrad = param.requires_grad;
+        Parameter p;
+        // When moving the parameter, we don't want the autograd to track this movement on the graph.
+        // In addition, we need the new tensor to be a leaf to accumulate gradients, so if we didn't
+        // disable grad we would need to call .detach() on the moved tensor.
+        using (var d = torch.no_grad())
+            p = new Parameter(param.to(dtype ?? param.dtype, device ?? param.device, disposeAfter: true), requiresGrad);
+        ConditionallyRegisterParameter(name, p);
+
+        // If this parameter is a field, set it
+        if (fieldsByComponentName.TryGetValue(name, out var field))
+            field.SetValue(this, p);
     }

     foreach (var (name, buffer) in named_buffers(false).ToList()) {
-        if (alreadyHandled.Contains(buffer.handle)) continue;
-        var t = buffer.to(dtype);
+        if (!buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+
+        // Buffers don't get grads so we don't need to detach them afterwards
+        var t = buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true);
         ConditionallyRegisterBuffer(name, t);
+
+        if (fieldsByComponentName.TryGetValue(name, out var field))
+            field.SetValue(this, t);
+    }
+
+    if (device is not null) {
+        _deviceType = device.type;
+        _deviceIndex = device.index;
     }
 }
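
The key change above: the three protected `_toEpilog` overloads now funnel into one private implementation with nullable dtype/device, and parameter moves run inside torch.no_grad(). A condensed restatement of the move pattern, using a hypothetical helper name (MoveParameter is not part of the PR):

using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

// Hypothetical helper condensing the pattern used by the new _toEpilog.
static Parameter MoveParameter(Parameter param, Device device)
{
    // Read the flag first; the real code passes disposeAfter: true, so the
    // source tensor is disposed as part of the move.
    bool requiresGrad = param.requires_grad;

    // Without no_grad, to() would be recorded on the graph and the result
    // would be a non-leaf tensor; it would have to be detach()-ed before
    // wrapping it as a Parameter that can accumulate gradients.
    using (no_grad())
        return new Parameter(param.to(device.type, device.index), requiresGrad);
}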

src/TorchSharp/NN/ParameterDict.cs
Lines changed: 31 additions & 0 deletions

@@ -60,6 +60,37 @@ protected override void RegisterComponents()

 private bool _registered = false;

+protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+{
+    base._to(deviceType, deviceIndex);
+    _toEpilog();
+    return this;
+}
+
+protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+{
+    base._to(device, dtype);
+    _toEpilog();
+    return this;
+}
+
+protected internal override Module _to(torch.ScalarType dtype)
+{
+    base._to(dtype);
+    _toEpilog();
+    return this;
+}
+
+void _toEpilog()
+{
+    for (int i = 0; i < _list.Count; i++) {
+        string name = _list[i].Item1;
+        var param = base.get_parameter(name);
+        _list[i] = (name, param);
+        _dict[name] = param;
+    }
+}
+
 /// <summary>
 /// Return the ParameterDict values.
 /// </summary>
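
The override-plus-epilog pair exists because base._to() re-registers fresh Parameter instances, while ParameterDict also caches entries in its own _list and _dict; without the resync those caches would keep pointing at the stale, pre-move parameters (bug #1179). A usage sketch, assuming a CUDA device and the dictionary-style Add API:

using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

var pd = new ParameterDict();
pd.Add("scale", new Parameter(ones(3)));

pd.to(DeviceType.CUDA);

// _toEpilog() refreshed _list/_dict from base.get_parameter(), so the
// dict's own storage agrees with the re-registered, moved parameter:
foreach (var (name, p) in pd.named_parameters())
    Console.WriteLine($"{name}: {p.device_type}");   // scale: CUDA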

src/TorchSharp/NN/ParameterList.cs
Lines changed: 29 additions & 0 deletions

@@ -33,6 +33,35 @@ protected override void RegisterComponents()
     _registered = true;
 }

+protected internal override Module _to(DeviceType deviceType, int deviceIndex = -1)
+{
+    base._to(deviceType, deviceIndex);
+    _toEpilog();
+    return this;
+}
+
+protected internal override Module _to(torch.Device device, torch.ScalarType dtype)
+{
+    base._to(device, dtype);
+    _toEpilog();
+    return this;
+}
+
+protected internal override Module _to(torch.ScalarType dtype)
+{
+    base._to(dtype);
+    _toEpilog();
+    return this;
+}
+
+void _toEpilog()
+{
+    for (int i = 0; i < _list.Count; i++) {
+        _list[i] = base.get_parameter($"{i}");
+    }
+}
+
 public override IEnumerable<(string name, Parameter parameter)> named_parameters(bool recurse = true)
 {
     return Enumerable.Range(0, _list.Count).Select(i => ($"{i}", _list[i]));
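
Same resync idea as in ParameterDict, except that ParameterList registers entries under their index, which is why get_parameter($"{i}") recovers each moved parameter; this matches the named_parameters override in the context lines above. A usage sketch, assuming a CUDA device; the params-array constructor is an assumption about the ParameterList API:

using System;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;

var pl = new ParameterList(new Parameter(randn(4)), new Parameter(randn(4)));

pl.to(DeviceType.CUDA);

// After _toEpilog(), _list holds the re-registered (moved) parameters:
foreach (var (name, p) in pl.named_parameters())
    Console.WriteLine($"{name}: {p.device_type}");   // 0: CUDA, 1: CUDA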

src/TorchSharp/Tensor/TensorExtensionMethods.cs
Lines changed: 61 additions & 0 deletions

@@ -704,5 +704,66 @@ public static long TotalSize(this IEnumerable<long> shape)
     }
     return result;
 }
+
+/// <summary>
+/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+/// </summary>
+/// <param name="tensor">The tensor</param>
+/// <param name="device">The device to move to</param>
+/// <returns>True if the tensor will be copied</returns>
+internal static bool toWillCopy(this Tensor tensor, Device device)
+{
+    return tensor.toWillCopy(device.type, device.index);
+}
+
+/// <summary>
+/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+/// </summary>
+/// <param name="tensor">The tensor</param>
+/// <param name="deviceType">The device type to move to</param>
+/// <param name="deviceIndex">The device index to move to</param>
+/// <returns>True if the tensor will be copied</returns>
+internal static bool toWillCopy(this Tensor tensor, DeviceType deviceType, int deviceIndex)
+{
+    return tensor.device_index != deviceIndex || tensor.device_type != deviceType;
+}
+
+/// <summary>
+/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+/// </summary>
+/// <param name="tensor">The tensor</param>
+/// <param name="dtype">The dtype to move to</param>
+/// <returns>True if the tensor will be copied</returns>
+internal static bool toWillCopy(this Tensor tensor, ScalarType dtype)
+{
+    return tensor.dtype != dtype;
+}
+
+/// <summary>
+/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+/// </summary>
+/// <param name="tensor">The tensor</param>
+/// <param name="dtype">The dtype to move to</param>
+/// <param name="device">The device to move to</param>
+/// <returns>True if the tensor will be copied</returns>
+internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, Device device)
+{
+    return tensor.toWillCopy(dtype, device.type, device.index);
+}
+
+/// <summary>
+/// Checks if calling `Tensor.to()` with the parameters will result in actually copying the tensor.
+/// </summary>
+/// <param name="tensor">The tensor</param>
+/// <param name="dtype">The dtype to move to</param>
+/// <param name="deviceType">The device type to move to</param>
+/// <param name="deviceIndex">The device index to move to</param>
+/// <returns>True if the tensor will be copied</returns>
+internal static bool toWillCopy(this Tensor tensor, ScalarType dtype, DeviceType deviceType, int deviceIndex)
+{
+    return tensor.device_index != deviceIndex || tensor.device_type != deviceType || tensor.dtype != dtype;
+}
 }
}
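
These predicates capture when Tensor.to() is a real copy rather than a no-op; _toEpilog uses them to skip parameters and buffers that would not actually move. A sketch of the contract (the helpers are internal, so expected results are shown as comments):

using TorchSharp;
using static TorchSharp.torch;

var t = zeros(2, 2);                        // float32 on CPU by default

// t.toWillCopy(ScalarType.Float32)         -> false: same dtype, to() is a no-op
// t.toWillCopy(ScalarType.Float64)         -> true:  dtype changes
// t.toWillCopy(DeviceType.CUDA, 0)         -> true:  placement changes
// t.toWillCopy(ScalarType.Float32, CPU)    -> false: nothing changes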
