diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 222e467bc..afc1b5bb2 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -27,6 +27,7 @@ __Bug Fixes__:
 #1174 : Loading CUDA tensor from stream threw an error
 #1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.
 #1148 : Calling `Module.to()` shouldn't be differentiable
+#1126 : Calling `ScriptModule.to()` doesn't move attributes
 #1180 : Module.to(ScalarType) has restrictions in PyTorch which aren't restricted in TorchSharp.

 ## NuGet Version 0.101.2
diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp
index 886006204..097b0dd81 100644
--- a/src/Native/LibTorchSharp/THSJIT.cpp
+++ b/src/Native/LibTorchSharp/THSJIT.cpp
@@ -151,11 +151,11 @@ void THSJIT_Module_named_buffers(const JITModule module,
     }
 }

-void THSJIT_Module_named_attributes(const JITModule module,
+void THSJIT_Module_named_attributes(const JITModule module, bool recurse,
     Tensor* (*allocator)(size_t length),
     const char** (*allocator2)(size_t length))
 {
-    auto attributes = (*module)->named_attributes();
+    auto attributes = (*module)->named_attributes(recurse);
     Tensor* result = allocator(attributes.size());
     const char** names = allocator2(attributes.size());
     int i = 0;
@@ -170,6 +170,11 @@ void THSJIT_Module_named_attributes(const JITModule module,
     }
 }

+void THSJIT_Module_set_attribute(const JITModule module, const char *name, Tensor tensor)
+{
+    CATCH((*module)->setattr(name, *tensor););
+}
+
 JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
 {
     auto method = (*module)->get_method(name);
diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h
index 40981c468..c6ebe8d65 100644
--- a/src/Native/LibTorchSharp/THSJIT.h
+++ b/src/Native/LibTorchSharp/THSJIT.h
@@ -78,10 +78,12 @@ EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
     Tensor* (*allocator)(size_t length),
     const char** (*allocator2)(size_t length));

-EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module,
+EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module, bool recurse,
     Tensor* (*allocator)(size_t length),
     const char** (*allocator2)(size_t length));

+EXPORT_API(void) THSJIT_Module_set_attribute(const JITModule module, const char* name, Tensor tensor);
+
 EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);

 EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);
diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs
index 8deca127f..862346607 100644
--- a/src/TorchSharp/JIT/ScriptModule.cs
+++ b/src/TorchSharp/JIT/ScriptModule.cs
@@ -46,11 +46,11 @@ protected override (string name, Tensor buffer)[] _named_buffers()
                 return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
             }

-            public (string name, Tensor buffer)[] named_attributes()
+            public (string name, Tensor buffer)[] named_attributes(bool recurse = true)
             {
                 using var pa = new PinnedArray<IntPtr>();
                 using var sa = new PinnedArray<IntPtr>();
-                THSJIT_Module_named_attributes(handle, pa.CreateArray, sa.CreateArray);
+                THSJIT_Module_named_attributes(handle, recurse, pa.CreateArray, sa.CreateArray);
                 CheckForErrors();
                 var ptrArray = pa.Array;
                 var strArray = sa.Array;
@@ -58,6 +58,12 @@ protected override (string name, Tensor buffer)[] _named_buffers()
                 return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
             }

+            public void set_attribute(string name, Tensor buffer)
+            {
+                THSJIT_Module_set_attribute(handle, name, buffer.Handle);
+                CheckForErrors();
+            }
+
             /// <summary>
             /// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself.
             /// </summary>
@@ -149,7 +155,7 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
                 CheckForErrors();

                 _toEpilog(device, dtype);
-
+                _toScriptEpilog(device, dtype);

                 return this;
             }
@@ -172,6 +178,7 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
                     CheckForErrors();

                     _toEpilog(deviceType, deviceIndex);
+                    _toScriptEpilog(deviceType, deviceIndex);
                 }

                 Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
@@ -192,10 +199,35 @@ protected internal override nn.Module _to(ScalarType dtype)
                 CheckForErrors();

                 _toEpilog(dtype);
+                _toScriptEpilog(dtype);

                 return this;
             }

+            protected void _toScriptEpilog(ScalarType dtype)
+            {
+                _toScriptEpilog(dtype, null);
+            }
+
+            protected void _toScriptEpilog(Device device, ScalarType dtype)
+            {
+                _toScriptEpilog(dtype, device);
+            }
+
+            protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex)
+            {
+                _toScriptEpilog(null, new Device(deviceType, deviceIndex));
+            }
+
+            private void _toScriptEpilog(ScalarType? dtype, Device device)
+            {
+                foreach (var (name, buffer) in named_attributes(recurse: false)) {
+                    if (name is null || !buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
+
+                    set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true));
+                }
+            }
+
 #if false  // These functions "work," but the native code doesn't seem to find any interesting information.

             public Type GetInputType(int index)
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
index 18e27347e..07d09c1e7 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -27,7 +27,10 @@ internal static partial class NativeMethods
         internal static extern void THSJIT_Module_named_buffers(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);

         [DllImport("LibTorchSharp")]
-        internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
+        internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool recurse, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern void THSJIT_Module_set_attribute(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr tensor);

         [DllImport("LibTorchSharp")]
         internal static extern void THSJIT_Module_named_modules(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs
index fab1ae36c..30bd49120 100644
--- a/test/TorchSharpTest/TestTorchTensorBugs.cs
+++ b/test/TorchSharpTest/TestTorchTensorBugs.cs
@@ -1395,7 +1395,7 @@ public void Validate1047_2()
             Assert.Equal(iterations_1047, xx);
         }

-        [Fact(Skip = "Work in progress")]
+        [Fact]
         public void Validate1126()
         {
             var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
@@ -1405,12 +1405,9 @@ public void Validate1126()

             var model = torch.jit.load("shakespeare.pt.zip").to(device);

-            var attributes = model.named_attributes().ToArray();
-
             var xs = torch.tensor(new int[blockSize].Select(_ => new Random().Next(vocabSize)).ToArray(), device: device, dtype: torch.int64).unsqueeze(0);
             var ys = model.forward(xs);
-
         }

         [Fact]
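
Illustrative usage note (not part of the diff): a minimal sketch of how the ScriptModule surface touched by this change could be exercised from user code, assuming a TorchScript archive named "model.pt" exists on disk; the file name and the float32 conversion are assumptions for the example, not values from the change.

// Hedged sketch: named_attributes(recurse), set_attribute, and the attribute-moving .to() from #1126.
using System;
using TorchSharp;
using static TorchSharp.torch;

var device = cuda.is_available() ? CUDA : CPU;

// jit.load returns a ScriptModule; with this fix, .to() also moves tensor attributes.
var sm = jit.load("model.pt");   // hypothetical file name
sm.to(device);

// Enumerate only this module's own attributes (recurse: false); the default recurses into submodules.
foreach (var (name, tensor) in sm.named_attributes(recurse: false))
    Console.WriteLine($"{name}: {tensor.dtype} on {tensor.device}");

// Overwrite an attribute through the new set_attribute binding, here with a
// dtype-converted copy (float32 is only an example).
foreach (var (name, tensor) in sm.named_attributes(recurse: false))
    sm.set_attribute(name, tensor.to(float32, tensor.device));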