Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ __Bug Fixes__:
#1174 : Loading CUDA tensor from stream threw an error<br/>
#1179 : Calling `Module.to()` with the `ParameterList` and `ParameterDict` module didn't move the parameters stored in the field.<br/>
#1148 : Calling `Module.to()` shouldn't be differentiable<br/>
#1126 : Calling `ScriptModule.to()` doesn't move attributes<br/>
#1180 : Module.to(ScalarType) has restrictions in PyTorch which aren't restricted in TorchSharp.<br/>

## NuGet Version 0.101.2
Expand Down
9 changes: 7 additions & 2 deletions src/Native/LibTorchSharp/THSJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ void THSJIT_Module_named_buffers(const JITModule module,
}
}

void THSJIT_Module_named_attributes(const JITModule module,
void THSJIT_Module_named_attributes(const JITModule module, bool recurse,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length))
{
auto attributes = (*module)->named_attributes();
auto attributes = (*module)->named_attributes(recurse);
Tensor* result = allocator(attributes.size());
const char** names = allocator2(attributes.size());
int i = 0;
Expand All @@ -170,6 +170,11 @@ void THSJIT_Module_named_attributes(const JITModule module,
}
}

void THSJIT_Module_set_attribute(const JITModule module, const char *name, Tensor tensor)
{
CATCH((*module)->setattr(name, *tensor););
}

JITMethod THSJIT_Module_get_method(const JITModule module, const char* name)
{
auto method = (*module)->get_method(name);
Expand Down
4 changes: 3 additions & 1 deletion src/Native/LibTorchSharp/THSJIT.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,12 @@ EXPORT_API(void) THSJIT_Module_named_buffers(const JITModule module,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module,
EXPORT_API(void) THSJIT_Module_named_attributes(const JITModule module, bool recurse,
Tensor* (*allocator)(size_t length),
const char** (*allocator2)(size_t length));

EXPORT_API(void) THSJIT_Module_set_attribute(const JITModule module, const char* name, Tensor tensor);

EXPORT_API(int) THSJIT_Method_num_inputs(const JITMethod method);

EXPORT_API(void) THSJIT_Method_dispose(const JITMethod method);
Expand Down
38 changes: 35 additions & 3 deletions src/TorchSharp/JIT/ScriptModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,24 @@ protected override (string name, Tensor buffer)[] _named_buffers()
return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
}

public (string name, Tensor buffer)[] named_attributes()
public (string name, Tensor buffer)[] named_attributes(bool recurse = true)
{
using var pa = new PinnedArray<IntPtr>();
using var sa = new PinnedArray<IntPtr>();
THSJIT_Module_named_attributes(handle, pa.CreateArray, sa.CreateArray);
THSJIT_Module_named_attributes(handle, recurse, pa.CreateArray, sa.CreateArray);
CheckForErrors();
var ptrArray = pa.Array;
var strArray = sa.Array;

return ptrArray.Select((x, i) => (Marshal.PtrToStringAnsi(strArray[i]), new Tensor(x))).ToArray();
}

public void set_attribute(string name, Tensor buffer)
{
THSJIT_Module_set_attribute(handle, name, buffer.Handle);
CheckForErrors();
}

/// <summary>
/// Returns an enumerable of all modules in the network, yielding both the name of the module as well as the module itself.
/// </summary>
Expand Down Expand Up @@ -149,7 +155,7 @@ protected internal override nn.Module _to(Device device, ScalarType dtype)
CheckForErrors();

_toEpilog(device, dtype);

_toScriptEpilog(device, dtype);
return this;
}

Expand All @@ -172,6 +178,7 @@ protected internal override nn.Module _to(DeviceType deviceType, int deviceIndex
CheckForErrors();

_toEpilog(deviceType, deviceIndex);
_toScriptEpilog(deviceType, deviceIndex);
}

Debug.Assert(_deviceType == DeviceType.CUDA || _deviceIndex == -1);
Expand All @@ -192,10 +199,35 @@ protected internal override nn.Module _to(ScalarType dtype)
CheckForErrors();

_toEpilog(dtype);
_toScriptEpilog(dtype);

return this;
}

protected void _toScriptEpilog(ScalarType dtype)
{
_toScriptEpilog(dtype, null);
}

protected void _toScriptEpilog(Device device, ScalarType dtype)
{
_toScriptEpilog(dtype, device);
}

protected void _toScriptEpilog(DeviceType deviceType, int deviceIndex)
{
_toScriptEpilog(null, new Device(deviceType, deviceIndex));
}

private void _toScriptEpilog(ScalarType? dtype, Device device)
{
foreach (var (name, buffer) in named_attributes(recurse: false)) {
if (name is null || !buffer.toWillCopy(dtype ?? buffer.dtype, device ?? buffer.device)) continue;

set_attribute(name, buffer.to(dtype ?? buffer.dtype, device ?? buffer.device, disposeAfter: true));
}
}

#if false // These functions "work," but the native code doesn't seem to find any interesting information.

public Type GetInputType(int index)
Expand Down
5 changes: 4 additions & 1 deletion src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ internal static partial class NativeMethods
internal static extern void THSJIT_Module_named_buffers(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);

[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
internal static extern void THSJIT_Module_named_attributes(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool recurse, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);

[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_set_attribute(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.LPStr)] string name, IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_named_modules(torch.nn.Module.HType module, AllocatePinnedArray allocator1, AllocatePinnedArray allocator2);
Expand Down
5 changes: 1 addition & 4 deletions test/TorchSharpTest/TestTorchTensorBugs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ public void Validate1047_2()
Assert.Equal(iterations_1047, xx);
}

[Fact(Skip = "Work in progress")]
[Fact]
public void Validate1126()
{
var device = torch.cuda.is_available() ? torch.CUDA : torch.CPU;
Expand All @@ -1405,12 +1405,9 @@ public void Validate1126()

var model = torch.jit.load<torch.Tensor, torch.Tensor>("shakespeare.pt.zip").to(device);

var attributes = model.named_attributes().ToArray();

var xs = torch.tensor(new int[blockSize].Select(_ => new Random().Next(vocabSize)).ToArray(), device: device, dtype: torch.int64).unsqueeze(0);

var ys = model.forward(xs);

}

[Fact]
Expand Down