Skip to content

Commit 6d4d20e

Browse files
Merge pull request #1183 from shaltielshmid/jit-attributes-to
Added set_attribute + added script module to to include attributes
2 parents 089511d + 8aac38d commit 6d4d20e

File tree

6 files changed

+51
-11
lines changed

6 files changed

+51
-11
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ __Bug Fixes__:
2727
#1174 : Loading CUDA tensor from stream threw an error<br/>
2828
#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.<br/>
2929
#1148 : Calling `Module.to()` shouldn't be differentiable<br/>
30+
#1126 : Calling `ScriptModule.to()` doesn't move attributes<br/>
3031
#1180 : Module.to(ScalarType) has restrictions in PyTorch which aren't restricted in TorchSharp.<br/>
3132

3233
## NuGet Version 0.101.2

src/Native/LibTorchSharp/THSJIT.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,11 @@ void THSJIT_Module_named_buffers(const JITModule module,
151151
}
152152
}
153153

154-
void THSJIT_Module_named_attributes(const JITModule module,
154+
void THSJIT_Module_named_attributes(const JITModule module, bool recurse,
155155
Tensor* (*allocator)(size_t length),
156156
const char** (*allocator2)(size_t length))
157157
{
158-
auto attributes = (*module)->named_attributes();
158+
auto attributes = (*module)->named_attributes(recurse);
159159
Tensor* result = allocator(attributes.size());
160160
const char** names = allocator2(attributes.size());
161161
int i = 0;
@@ -170,6 +170,11 @@ void THSJIT_Module_named_attributes(const JITModule module,
170170
}
171171
}
172172

173+
void THSJIT_Module_set_attribute(const JITModule module, const char *name, Tensor tensor)
174+
{
175+
CATCH((*module)->setattr(name, *tensor););
176+
}
177+
173178
JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
174179
{
175180
auto method = (*module)->get_method(name);

src/Native/LibTorchSharp/THSJIT.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,12 @@ EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
7878
Tensor* (*allocator)(size_t length),
7979
const char** (*allocator2)(size_t length));
8080

81-
EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module,
81+
EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module, bool recurse,
8282
Tensor* (*allocator)(size_t length),
8383
const char** (*allocator2)(size_t length));
8484

85+
EXPORT_API(void) THSJIT_Module_set_attribute(const JITModule module, const char* name, Tensor tensor);
86+
8587
EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);
8688

8789
EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);

src/TorchSharp/JIT/ScriptModule.cs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,24 @@ protected override (string name, Tensor buffer)[] _named_buffers()
4646
return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
4747
}
4848

49-
public (string name, Tensor buffer)[] named_attributes()
49+
public (string name, Tensor buffer)[] named_attributes(bool recurse = true)
5050
{
5151
using var pa = new PinnedArray<IntPtr>();
5252
using var sa = new PinnedArray<IntPtr>();
53-
THSJIT_Module_named_attributes(handle, pa.CreateArray, sa.CreateArray);
53+
THSJIT_Module_named_attributes(handle, recurse, pa.CreateArray, sa.CreateArray);
5454
CheckForErrors();
5555
var ptrArray = pa.Array;
5656
var strArray = sa.Array;
5757

5858
return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
5959
}
6060

61+
public void set_attribute(string name, Tensor buffer)
62+
{
63+
THSJIT_Module_set_attribute(handle, name, buffer.Handle);
64+
CheckForErrors();
65+
}
66+
6167
/// <summary>
6268
/// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself.
6369
/// </summary>
@@ -149,7 +155,7 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
149155
CheckForErrors();
150156

151157
_toEpilog(device, dtype);
152-
158+
_toScriptEpilog(device, dtype);
153159
return this;
154160
}
155161

@@ -172,6 +178,7 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
172178
CheckForErrors();
173179

174180
_toEpilog(deviceType, deviceIndex);
181+
_toScriptEpilog(deviceType, deviceIndex);
175182
}
176183

177184
Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
@@ -192,10 +199,35 @@ protected internal override nn.Module _to(ScalarType dtype)
192199
CheckForErrors();
193200

194201
_toEpilog(dtype);
202+
_toScriptEpilog(dtype);
195203

196204
return this;
197205
}
198206

207+
protected void _toScriptEpilog(ScalarType dtype)
208+
{
209+
_toScriptEpilog(dtype, null);
210+
}
211+
212+
protected void _toScriptEpilog(Device device, ScalarType dtype)
213+
{
214+
_toScriptEpilog(dtype, device);
215+
}
216+
217+
protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex)
218+
{
219+
_toScriptEpilog(null, new Device(deviceType, deviceIndex));
220+
}
221+
222+
private void _toScriptEpilog(ScalarType? dtype, Device device)
223+
{
224+
foreach (var (name, buffer) in named_attributes(recurse: false)) {
225+
if (name is null || !buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;
226+
227+
set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true));
228+
}
229+
}
230+
199231
#if false // These functions "work," but the native code doesn't seem to find any interesting information.
200232

201233
public Type GetInputType(int index)

src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ internal static partial class NativeMethods
2727
internal static extern void THSJIT_Module_named_buffers(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
2828

2929
[DllImport("LibTorchSharp")]
30-
internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
30+
internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool recurse, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
31+
32+
[DllImport("LibTorchSharp")]
33+
internal static extern void THSJIT_Module_set_attribute(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr tensor);
3134

3235
[DllImport("LibTorchSharp")]
3336
internal static extern void THSJIT_Module_named_modules(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);

test/TorchSharpTest/TestTorchTensorBugs.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ public void Validate1047_2()
13951395
Assert.Equal(iterations_1047, xx);
13961396
}
13971397

1398-
[Fact(Skip = "Work in progress")]
1398+
[Fact]
13991399
public void Validate1126()
14001400
{
14011401
var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
@@ -1405,12 +1405,9 @@ public void Validate1126()
14051405

14061406
var model = torch.jit.load<torch.Tensor, torch.Tensor>("shakespeare.pt.zip").to(device);
14071407

1408-
var attributes = model.named_attributes().ToArray();
1409-
14101408
var xs = torch.tensor(new int[blockSize].Select(_ => new Random().Next(vocabSize)).ToArray(), device: device, dtype: torch.int64).unsqueeze(0);
14111409

14121410
var ys = model.forward(xs);
1413-
14141411
}
14151412

14161413
[Fact]

0 commit comments

Comments
 (0)