diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 222e467bc..afc1b5bb2 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -27,6 +27,7 @@ __Bug Fixes__:
#1174 : Loading a CUDA tensor from a stream threw an error
#1179 : Calling `Module.to()` on a `ParameterList` or `ParameterDict` module didn't move the parameters stored in the field.
#1148 : Calling `Module.to()` shouldn't be differentiable
+#1126 : Calling `ScriptModule.to()` didn't move attributes
#1180 : `Module.to(ScalarType)` is restricted in PyTorch in ways that weren't enforced in TorchSharp.
## NuGet Version 0.101.2
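
For context on #1126: before this change, moving a loaded `ScriptModule` only converted its parameters and buffers, so tensor attributes stored in the TorchScript archive stayed on their original device and dtype. A minimal sketch of the scenario the fix covers, assuming `model.pt` is a hypothetical TorchScript archive containing tensor attributes:

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical archive name; any scripted model with tensor attributes will do.
var device = cuda.is_available() ? CUDA : CPU;
var model = jit.load("model.pt");
model.to(device);

// With the fix, tensor attributes follow the module to the target device.
foreach (var (name, tensor) in model.named_attributes())
    Console.WriteLine($"{name} -> {tensor.device}");
```
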
diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp
index 886006204..097b0dd81 100644
--- a/src/Native/LibTorchSharp/THSJIT.cpp
+++ b/src/Native/LibTorchSharp/THSJIT.cpp
@@ -151,11 +151,11 @@ void THSJIT_Module_named_buffers(const JITModule module,
}
}
-void THSJIT_Module_named_attributes(const JITModule module,
+void THSJIT_Module_named_attributes(const JITModule module, bool recurse,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length))
{
- auto attributes = (*module)->named_attributes();
+ auto attributes = (*module)->named_attributes(recurse);
Tensor* result = allocator(attributes.size());
const char** names = allocator2(attributes.size());
int i = 0;
@@ -170,6 +170,11 @@ void THSJIT_Module_named_attributes(const JITModule module,
}
}
+void THSJIT_Module_set_attribute(const JITModule module, const char *name, Tensor tensor)
+{
+ CATCH((*module)->setattr(name, *tensor););
+}
+
JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
{
auto method = (*module)->get_method(name);
diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h
index 40981c468..c6ebe8d65 100644
--- a/src/Native/LibTorchSharp/THSJIT.h
+++ b/src/Native/LibTorchSharp/THSJIT.h
@@ -78,10 +78,12 @@ EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));
-EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module,
+EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module, bool recurse,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));
+EXPORT_API(void) THSJIT_Module_set_attribute(const JITModule module, const char* name, Tensor tensor);
+
EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);
EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);
diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs
index 8deca127f..862346607 100644
--- a/src/TorchSharp/JIT/ScriptModule.cs
+++ b/src/TorchSharp/JIT/ScriptModule.cs
@@ -46,11 +46,11 @@ protected override (string name, Tensor buffer)[] _named_buffers()
return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
}
- public (string name, Tensor buffer)[] named_attributes()
+ public (string name, Tensor buffer)[] named_attributes(bool recurse = true)
{
using var pa = new PinnedArray<IntPtr>();
using var sa = new PinnedArray<IntPtr>();
- THSJIT_Module_named_attributes(handle, pa.CreateArray, sa.CreateArray);
+ THSJIT_Module_named_attributes(handle, recurse, pa.CreateArray, sa.CreateArray);
CheckForErrors();
var ptrArray = pa.Array;
var strArray = sa.Array;
@@ -58,6 +58,12 @@ protected override (string name, Tensor buffer)[] _named_buffers()
return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
}
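+ /// <summary>
+ /// Sets a named tensor attribute on the underlying TorchScript module.
+ /// </summary>
+ /// <param name="name">The name of the attribute.</param>
+ /// <param name="buffer">The tensor to store in the attribute.</param>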
+ public void set_attribute(string name, Tensor buffer)
+ {
+ THSJIT_Module_set_attribute(handle, name, buffer.Handle);
+ CheckForErrors();
+ }
+
/// <summary>
/// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself.
/// </summary>
@@ -149,7 +155,7 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
CheckForErrors();
_toEpilog(device, dtype);
-
+ _toScriptEpilog(device, dtype);
return this;
}
@@ -172,6 +178,7 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
CheckForErrors();
_toEpilog(deviceType, deviceIndex);
+ _toScriptEpilog(deviceType, deviceIndex);
}
Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
@@ -192,10 +199,35 @@ protected internal override nn.Module _to(ScalarType dtype)
CheckForErrors();
_toEpilog(dtype);
+ _toScriptEpilog(dtype);
return this;
}
+ protected void _toScriptEpilog(ScalarType dtype)
+ {
+ _toScriptEpilog(dtype, null);
+ }
+
+ protected void _toScriptEpilog(Device device, ScalarType dtype)
+ {
+ _toScriptEpilog(dtype, device);
+ }
+
+ protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex)
+ {
+ _toScriptEpilog(null, new Device(deviceType, deviceIndex));
+ }
+
+ private void _toScriptEpilog(ScalarType? dtype, Device device)
+ {
+ foreach (var (name, buffer) in named_attributes(recurse: false)) {
+ if (name is null || !buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+
+ set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true));
+ }
+ }
+
#if false // These functions "work," but the native code doesn't seem to find any interesting information.
public Type GetInputType(int index)
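
The epilog above walks only the module's own tensor attributes (`recurse: false`) and writes the converted tensors back through `set_attribute`, mirroring what `_toEpilog` does for parameters and buffers. A small sketch of driving the same calls by hand, assuming `model` is the `ScriptModule` from the earlier snippet and `"pos_emb"` is a hypothetical tensor attribute name:

```csharp
using System.Linq;

// Fetch one attribute by its (illustrative) name and store back a converted copy.
var current = model.named_attributes(recurse: false)
                   .First(a => a.name == "pos_emb").buffer;
model.set_attribute("pos_emb", current.to(ScalarType.Float16, CPU));
```
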
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
index 18e27347e..07d09c1e7 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -27,7 +27,10 @@ internal static partial class NativeMethods
internal static extern void THSJIT_Module_named_buffers(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
[DllImport("LibTorchSharp")]
- internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
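+ // Native 'bool' is a single byte, so marshal it as U1 rather than the default 4-byte BOOL.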
+ internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool recurse, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSJIT_Module_set_attribute(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr tensor);
[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_named_modules(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs
index fab1ae36c..30bd49120 100644
--- a/test/TorchSharpTest/TestTorchTensorBugs.cs
+++ b/test/TorchSharpTest/TestTorchTensorBugs.cs
@@ -1395,7 +1395,7 @@ public void Validate1047_2()
Assert.Equal(iterations_1047, xx);
}
- [Fact(Skip = "Work in progress")]
+ [Fact]
public void Validate1126()
{
var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
@@ -1405,12 +1405,9 @@ public void Validate1126()
var model = torch.jit.load("shakespeare.pt.zip").to(device);
- var attributes = model.named_attributes().ToArray();
-
var xs = torch.tensor(new int[blockSize].Select(_ => new Random().Next(vocabSize)).ToArray(), device: device, dtype: torch.int64).unsqueeze(0);
var ys = model.forward(xs);
-
}
[Fact]
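
The re-enabled test exercises `forward` after moving the model; a follow-up assertion (not part of this patch) could also verify attribute placement directly:

```csharp
// Sketch only: every tensor attribute should end up on the same device type as the model.
foreach (var (name, tensor) in model.named_attributes())
    Assert.Equal(device.type, tensor.device.type);
```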