diff --git a/src/Examples/AdversarialExampleGeneration.cs b/src/Examples/AdversarialExampleGeneration.cs index a68733103..276dddcf5 100644 --- a/src/Examples/AdversarialExampleGeneration.cs +++ b/src/Examples/AdversarialExampleGeneration.cs @@ -114,7 +114,7 @@ private static Tensor Attack(Tensor image, double ε, Tensor data_grad) private static double Test( MNIST.Model model, - Loss criterion, + Loss criterion, double ε, Device device, Dataset dataset, diff --git a/src/Examples/AlexNet.cs b/src/Examples/AlexNet.cs index 7bb4f1fed..321f1ec79 100644 --- a/src/Examples/AlexNet.cs +++ b/src/Examples/AlexNet.cs @@ -8,11 +8,11 @@ namespace TorchSharp.Examples /// /// Modified version of original AlexNet to fix CIFAR10 32x32 images. /// - class AlexNet : Module + class AlexNet : Module { - private readonly Module features; - private readonly Module avgPool; - private readonly Module classifier; + private readonly Module features; + private readonly Module avgPool; + private readonly Module classifier; public AlexNet(string name, int numClasses, torch.Device device = null) : base(name) { diff --git a/src/Examples/CIFAR10.cs b/src/Examples/CIFAR10.cs index 5cb383a3c..2bb9a0d4f 100644 --- a/src/Examples/CIFAR10.cs +++ b/src/Examples/CIFAR10.cs @@ -56,7 +56,7 @@ internal static void Main(string[] args) Console.WriteLine($"\tCreating the model..."); - Module model = null; + Module model = null; switch (modelName.ToLower()) { case "alexnet": @@ -134,9 +134,9 @@ internal static void Main(string[] args) } private static void Train( - Module model, + Module model, torch.optim.Optimizer optimizer, - Loss loss, + Loss loss, DataLoader dataLoader, int epoch, long batchSize, @@ -182,8 +182,8 @@ private static void Train( } private static void Test( - Module model, - Loss loss, + Module model, + Loss loss, DataLoader dataLoader, long size) { diff --git a/src/Examples/MNIST.cs b/src/Examples/MNIST.cs index 3cdca7e6a..dded22016 100644 --- a/src/Examples/MNIST.cs +++ b/src/Examples/MNIST.cs @@ -96,26 +96,26 @@ internal static void TrainingLoop(string dataset, Device device, Model model, Da model.save(dataset + ".model.bin"); } - internal class Model : Module + internal class Model : Module { - private Module conv1 = Conv2d(1, 32, 3); - private Module conv2 = Conv2d(32, 64, 3); - private Module fc1 = Linear(9216, 128); - private Module fc2 = Linear(128, 10); + private Module conv1 = Conv2d(1, 32, 3); + private Module conv2 = Conv2d(32, 64, 3); + private Module fc1 = Linear(9216, 128); + private Module fc2 = Linear(128, 10); // These don't have any parameters, so the only reason to instantiate // them is performance, since they will be used over and over. - private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); + private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 }); - private Module relu1 = ReLU(); - private Module relu2 = ReLU(); - private Module relu3 = ReLU(); + private Module relu1 = ReLU(); + private Module relu2 = ReLU(); + private Module relu3 = ReLU(); - private Module dropout1 = Dropout(0.25); - private Module dropout2 = Dropout(0.5); + private Module dropout1 = Dropout(0.25); + private Module dropout2 = Dropout(0.5); - private Module flatten = Flatten(); - private Module logsm = LogSoftmax(1); + private Module flatten = Flatten(); + private Module logsm = LogSoftmax(1); public Model(string name, torch.Device device = null) : base(name) { @@ -151,7 +151,7 @@ public override Tensor forward(Tensor input) private static void Train( Model model, torch.optim.Optimizer optimizer, - Loss loss, + Loss loss, DataLoader dataLoader, int epoch, long size) @@ -191,7 +191,7 @@ private static void Train( private static void Test( Model model, - Loss loss, + Loss loss, DataLoader dataLoader, long size) { diff --git a/src/Examples/MobileNet.cs b/src/Examples/MobileNet.cs index 57b4be266..e1c3187c8 100644 --- a/src/Examples/MobileNet.cs +++ b/src/Examples/MobileNet.cs @@ -14,7 +14,7 @@ namespace TorchSharp.Examples /// With an unaugmented CIFAR-10 data set, the author of this saw training converge /// at roughly 75% accuracy on the test set, over the course of 1500 epochs. /// - class MobileNet : Module + class MobileNet : Module { // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE @@ -22,13 +22,13 @@ class MobileNet : Module private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 }; private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 }; - private readonly Module layers; + private readonly Module layers; public MobileNet(string name, int numClasses, Device device = null) : base(name) { if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length."); - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false))); modules.Add(($"bnrm2d-first", BatchNorm2d(32))); @@ -46,7 +46,7 @@ public MobileNet(string name, int numClasses, Device device = null) : base(name) this.to(device); } - private void MakeLayers(List<(string, Module)> modules, long in_planes) + private void MakeLayers(List<(string, Module)> modules, long in_planes) { for (var i = 0; i < strides.Length; i++) { diff --git a/src/Examples/ResNet.cs b/src/Examples/ResNet.cs index 74745b324..b266fdbcb 100644 --- a/src/Examples/ResNet.cs +++ b/src/Examples/ResNet.cs @@ -10,12 +10,12 @@ namespace TorchSharp.Examples /// /// Modified version of ResNet to classify CIFAR10 32x32 images. /// - class ResNet : Module + class ResNet : Module { // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE - private readonly Module layers; + private readonly Module layers; private int in_planes = 64; public static ResNet ResNet18(int numClasses, Device device = null) @@ -68,9 +68,9 @@ public static ResNet ResNet152(int numClasses, Device device = null) device); } - public ResNet(string name, Func block, int expansion, IList num_blocks, int numClasses, Device device = null) : base(name) + public ResNet(string name, Func> block, int expansion, IList num_blocks, int numClasses, Device device = null) : base(name) { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); modules.Add(($"conv2d-first", Conv2d(3, 64, kernelSize: 3, stride: 1, padding: 1, bias: false))); modules.Add(($"bnrm2d-first", BatchNorm2d(64))); @@ -91,7 +91,7 @@ public ResNet(string name, Func block, int expansion this.to(device); } - private void MakeLayer(List<(string, Module)> modules, Func block, int expansion, int planes, int num_blocks, int stride) + private void MakeLayer(List<(string, Module)> modules, Func> block, int expansion, int planes, int num_blocks, int stride) { var strides = new List(); strides.Add(stride); @@ -109,11 +109,11 @@ public override Tensor forward(Tensor input) return layers.forward(input); } - class BasicBlock : Module + class BasicBlock : Module { public BasicBlock (string name, int in_planes, int planes, int stride) : base(name) { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); modules.Add(($"{name}-conv2d-1", Conv2d(in_planes, planes, kernelSize: 3, stride: stride, padding: 1, bias: false))); modules.Add(($"{name}-bnrm2d-1", BatchNorm2d(planes))); @@ -146,15 +146,15 @@ public override Tensor forward(Tensor t) public static int expansion = 1; - private readonly Module layers; - private readonly Module shortcut; + private readonly Module layers; + private readonly Module shortcut; } - class Bottleneck : Module + class Bottleneck : Module { public Bottleneck(string name, int in_planes, int planes, int stride) : base(name) { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); modules.Add(($"{name}-conv2d-1", Conv2d(in_planes, planes, kernelSize: 1, bias: false))); modules.Add(($"{name}-bnrm2d-1", BatchNorm2d(planes))); @@ -187,8 +187,8 @@ public override Tensor forward(Tensor t) public static int expansion = 4; - private readonly Module layers; - private readonly Module shortcut; + private readonly Module layers; + private readonly Module shortcut; } } } diff --git a/src/Examples/SequenceToSequence.cs b/src/Examples/SequenceToSequence.cs index 4add405a5..d243f67a0 100644 --- a/src/Examples/SequenceToSequence.cs +++ b/src/Examples/SequenceToSequence.cs @@ -104,7 +104,7 @@ internal static void Main(string[] args) Console.WriteLine($"\nEnd of training | time: {totalTime.Elapsed.TotalSeconds:0.0}s | loss: {tst_loss:0.00}\n"); } - private static void train(int epoch, Tensor train_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer) + private static void train(int epoch, Tensor train_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer) { model.train(); @@ -149,7 +149,7 @@ private static void train(int epoch, Tensor train_data, TransformerModel model, } } - private static double evaluate(Tensor eval_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer) + private static double evaluate(Tensor eval_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer) { model.eval(); @@ -211,9 +211,9 @@ static Tensor Batchify(Tensor data, int batch_size) return (data, target); } - class TransformerModel : Module + class TransformerModel : Module { - private Module transformer_encoder; + private Modules.TransformerEncoder transformer_encoder; private PositionalEncoding pos_encoder; private Modules.Embedding encoder; private Modules.Linear decoder; @@ -252,11 +252,6 @@ private void InitWeights() init.uniform_(decoder.weight, -initrange, initrange); } - public override Tensor forward(Tensor t) - { - throw new NotImplementedException("single-argument forward()"); - } - public override Tensor forward(Tensor t, Tensor mask) { var src = pos_encoder.forward(encoder.forward(t) * MathF.Sqrt(ninputs)); @@ -271,9 +266,9 @@ protected override Module _to(DeviceType deviceType, int deviceIndex = -1) } } - class PositionalEncoding : Module + class PositionalEncoding : Module { - private Module dropout; + private Module dropout; private Tensor pe; public PositionalEncoding(long dmodel, double dropout, int maxLen = 5000) : base("PositionalEncoding") diff --git a/src/Examples/SpeechCommands.cs b/src/Examples/SpeechCommands.cs index b7bcc176a..ebe39831e 100644 --- a/src/Examples/SpeechCommands.cs +++ b/src/Examples/SpeechCommands.cs @@ -123,7 +123,7 @@ private static void Train( M5 model, ITransform transform, torch.optim.Optimizer optimizer, - Loss criteria, + Loss criteria, DataLoader dataLoader, int epoch, long size) @@ -160,7 +160,7 @@ private static void Train( private static void Test( M5 model, ITransform transform, - Loss criteria, + Loss criteria, DataLoader dataLoader, long size) { @@ -217,21 +217,21 @@ private class BatchItem public torch.Tensor label; } - internal class M5 : Module + internal class M5 : Module { - private readonly Module conv1; - private readonly Module bn1; - private readonly Module pool1; - private readonly Module conv2; - private readonly Module bn2; - private readonly Module pool2; - private readonly Module conv3; - private readonly Module bn3; - private readonly Module pool3; - private readonly Module conv4; - private readonly Module bn4; - private readonly Module pool4; - private readonly Module fc1; + private readonly Module conv1; + private readonly Module bn1; + private readonly Module pool1; + private readonly Module conv2; + private readonly Module bn2; + private readonly Module pool2; + private readonly Module conv3; + private readonly Module bn3; + private readonly Module pool3; + private readonly Module conv4; + private readonly Module bn4; + private readonly Module pool4; + private readonly Module fc1; public M5(string name, int n_input = 1, int n_output = 35, int stride = 16, int n_channel = 32) : base(name) { diff --git a/src/Examples/TextClassification.cs b/src/Examples/TextClassification.cs index 9bbc587dc..d456dc130 100644 --- a/src/Examples/TextClassification.cs +++ b/src/Examples/TextClassification.cs @@ -109,7 +109,7 @@ internal static void Main(string[] args) } } - static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, TextClassificationModel model, Loss criterion, torch.optim.Optimizer optimizer) + static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, TextClassificationModel model, Loss criterion, torch.optim.Optimizer optimizer) { model.train(); @@ -145,7 +145,7 @@ static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, T } } - static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClassificationModel model, Loss criterion) + static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClassificationModel model, Loss criterion) { model.eval(); @@ -166,7 +166,7 @@ static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClas } } - class TextClassificationModel : Module + class TextClassificationModel : Module { private Modules.EmbeddingBag embedding; private Modules.Linear fc; @@ -194,7 +194,7 @@ public override Tensor forward(Tensor t) throw new NotImplementedException(); } - public override Tensor forward(Tensor input, Tensor offsets) + public Tensor forward(Tensor input, Tensor offsets) { return fc.forward(embedding.forward(input, offsets)); } diff --git a/src/Examples/VGG.cs b/src/Examples/VGG.cs index de952341d..e3e8b813e 100644 --- a/src/Examples/VGG.cs +++ b/src/Examples/VGG.cs @@ -13,7 +13,7 @@ namespace TorchSharp.Examples /// With an unaugmented CIFAR-10 data set, the author of this saw training converge /// at roughly 85% accuracy on the test set, after 50 epochs using VGG-16. /// - class VGG : Module + class VGG : Module { // The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py // Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE @@ -25,11 +25,11 @@ class VGG : Module { "VGG19", new long[] { 64, 64, 0, 128, 128, 0, 256, 256, 256, 256, 0, 512, 512, 512, 512, 0, 512, 512, 512, 512, 0 } } }; - private readonly Module layers; + private readonly Module layers; public VGG(string name, int numClasses, Device device = null) : base(name) { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); var channels = _channels[name]; diff --git a/src/FSharp.Examples/AlexNet.fs b/src/FSharp.Examples/AlexNet.fs index f4858ebf1..956ffc1fe 100644 --- a/src/FSharp.Examples/AlexNet.fs +++ b/src/FSharp.Examples/AlexNet.fs @@ -44,32 +44,32 @@ let getDataFiles sourceDir targetDir = Utils.Decompress.ExtractTGZ(Path.Combine(sourceDir, "cifar-10-binary.tar.gz"), targetDir) type Model(name,device:torch.Device) as this = - inherit Module(name) - - let features = Sequential(("c1", Conv2d(3L, 64L, kernelSize=3L, stride=2L, padding=1L) :> Module), - ("r1", ReLU(inplace=true) :> Module), - ("mp1", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), - ("c2", Conv2d(64L, 192L, kernelSize=3L, padding=1L) :> Module), - ("r2", ReLU(inplace=true) :> Module), - ("mp2", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), - ("c3", Conv2d(192L, 384L, kernelSize=3L, padding=1L) :> Module), - ("r3", ReLU(inplace=true) :> Module), - ("c4", Conv2d(384L, 256L, kernelSize=3L, padding=1L) :> Module), - ("r4", ReLU(inplace=true) :> Module), - ("c5", Conv2d(256L, 256L, kernelSize=3L, padding=1L) :> Module), - ("r5", ReLU(inplace=true) :> Module), - ("mp3", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), - ("avg", AdaptiveAvgPool2d([|2L; 2L|]) :> Module)) - - let classifier = Sequential(("d1", Dropout() :> Module), - ("l1", Linear(256L * 2L * 2L, 4096L) :> Module), - ("r6", ReLU(inplace=true) :> Module), - ("d2", Dropout() :> Module), - ("l2", Linear(4096L, 4096L) :> Module), - ("r7", ReLU(inplace=true) :> Module), - ("d3", Dropout() :> Module), - ("l3", Linear(4096L, numClasses) :> Module), - ("logsm", LogSoftmax(1L) :> Module)) + inherit Module(name) + + let features = Sequential(("c1", Conv2d(3L, 64L, kernelSize=3L, stride=2L, padding=1L) :> Module), + ("r1", ReLU(inplace=true) :> Module), + ("mp1", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("c2", Conv2d(64L, 192L, kernelSize=3L, padding=1L) :> Module), + ("r2", ReLU(inplace=true) :> Module), + ("mp2", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("c3", Conv2d(192L, 384L, kernelSize=3L, padding=1L) :> Module), + ("r3", ReLU(inplace=true) :> Module), + ("c4", Conv2d(384L, 256L, kernelSize=3L, padding=1L) :> Module), + ("r4", ReLU(inplace=true) :> Module), + ("c5", Conv2d(256L, 256L, kernelSize=3L, padding=1L) :> Module), + ("r5", ReLU(inplace=true) :> Module), + ("mp3", MaxPool2d(kernelSize=[|2L; 2L|]) :> Module), + ("avg", AdaptiveAvgPool2d([|2L; 2L|]) :> Module)) + + let classifier = Sequential(("d1", Dropout() :> Module), + ("l1", Linear(256L * 2L * 2L, 4096L) :> Module), + ("r6", ReLU(inplace=true) :> Module), + ("d2", Dropout() :> Module), + ("l2", Linear(4096L, 4096L) :> Module), + ("r7", ReLU(inplace=true) :> Module), + ("d3", Dropout() :> Module), + ("l3", Linear(4096L, numClasses) :> Module), + ("logsm", LogSoftmax(1L) :> Module)) do this.RegisterComponents() diff --git a/src/FSharp.Examples/MNIST.fs b/src/FSharp.Examples/MNIST.fs index f3a357efc..de72b2bb8 100644 --- a/src/FSharp.Examples/MNIST.fs +++ b/src/FSharp.Examples/MNIST.fs @@ -44,7 +44,7 @@ let hasCUDA = TorchText.Datasets.cuda_is_available() //torch.cuda.is_available() let device = if hasCUDA then torch.CUDA else torch.CPU type Model(name,device:torch.Device) as this = - inherit Module(name) + inherit Module(name) let conv1 = Conv2d(1L, 32L, 3L) let conv2 = Conv2d(32L, 64L, 3L) diff --git a/src/FSharp.Examples/SequenceToSequence.fs b/src/FSharp.Examples/SequenceToSequence.fs index 3794af1cc..5f52564f1 100644 --- a/src/FSharp.Examples/SequenceToSequence.fs +++ b/src/FSharp.Examples/SequenceToSequence.fs @@ -53,7 +53,7 @@ let loss = torch.nn.CrossEntropyLoss(reduction=Reduction.Mean) let criterion x y = loss.forward(x,y) type PositionalEncoding(dmodel, maxLen) as this = - inherit Module("PositionalEncoding") + inherit Module("PositionalEncoding") let dropout = Dropout(dropout) let mutable pe = torch.zeros([| maxLen; dmodel|]) @@ -79,7 +79,7 @@ type PositionalEncoding(dmodel, maxLen) as this = dropout.forward(x) type TransformerModel(ntokens, device:torch.Device) as this = - inherit Module("Transformer") + inherit Module("Transformer") let pos_encoder = new PositionalEncoding(emsize, 5000L) let encoder_layers = TransformerEncoderLayer(emsize, nheads, nhidden, dropout) @@ -101,8 +101,6 @@ type TransformerModel(ntokens, device:torch.Device) as this = if device.``type`` = DeviceType.CUDA then this.``to``(device) |> ignore - override _.forward(input) = raise (NotImplementedException("single-argument forward()")) - override _.forward(t, mask) = let src = pos_encoder.forward(encoder.forward(t) * sqrEmSz) let enc = transformer_encoder.forward(src, mask) diff --git a/src/FSharp.Examples/TextClassification.fs b/src/FSharp.Examples/TextClassification.fs index 8300ce8ea..dc6783487 100644 --- a/src/FSharp.Examples/TextClassification.fs +++ b/src/FSharp.Examples/TextClassification.fs @@ -46,7 +46,7 @@ let loss = torch.nn.CrossEntropyLoss() let criterion x y = loss.forward(x,y) type TextClassificationModel(vocabSize, embedDim, nClasses, device:torch.Device) as this = - inherit Module("Transformer") + inherit Module("Transformer") let embedding = EmbeddingBag(vocabSize, embedDim, sparse=false) let fc = Linear(embedDim, nClasses) @@ -63,8 +63,6 @@ type TextClassificationModel(vocabSize, embedDim, nClasses, device:torch.Device) if device.``type`` = DeviceType.CUDA then this.``to``(device) |> ignore - override _.forward(input) = raise (NotImplementedException("single-argument forward()")) - override _.forward(input, offsets) = embedding.forward(input, offsets) --> fc diff --git a/src/Native/LibTorchSharp/THSJIT.cpp b/src/Native/LibTorchSharp/THSJIT.cpp index 5ea945589..e1c2d7e87 100644 --- a/src/Native/LibTorchSharp/THSJIT.cpp +++ b/src/Native/LibTorchSharp/THSJIT.cpp @@ -145,9 +145,46 @@ JITMethod THSJIT_Module_get_method(const JITModule module, const char* name) return new std::shared_ptr(copy); } -Tensor THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length) +void THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode) { - CATCH_TENSOR((*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)).toTensor()); + *typeCode = 0; + + CATCH( + auto result = (*module)->forward(toTensors((torch::Tensor**)tensorPtrs, length)); + + // TypeCode: + // + // 0 -- Not supported + // 1 -- Single tensor + // 2 -- Tuple of tensors + // 3 -- List of tensors + + if (result.isTensor()) { + Tensor* output = allocator(1); + output[0] = ResultTensor(result.toTensor()); + *typeCode = 1; + return; + } + if (result.isTensorList()) { + auto list = result.toTensorList(); + *typeCode = 3; + Tensor* output = allocator(list.size()); + for (size_t i = 0; i < list.size(); i++) + output[i] = ResultTensor(list[i]); + return; + } + if (result.isTuple()) { + auto tuple = result.toTuple(); + auto list = tuple->elements(); + auto sz = list.size(); + *typeCode = 2; + Tensor* output = allocator(list.size()); + for (size_t i = 0; i < list.size(); i++) + // Assuming that all elements are tensors. + output[i] = ResultTensor(list[i].toTensor()); + return; + } + ) } void THSJIT_Module_dispose(const JITModule module) diff --git a/src/Native/LibTorchSharp/THSJIT.h b/src/Native/LibTorchSharp/THSJIT.h index f8e8c2206..0b3d0c5b9 100644 --- a/src/Native/LibTorchSharp/THSJIT.h +++ b/src/Native/LibTorchSharp/THSJIT.h @@ -25,7 +25,7 @@ EXPORT_API(void) THSJIT_Module_dispose(const JITModule module); EXPORT_API(int) THSJIT_Module_num_inputs(const JITModule method); EXPORT_API(int) THSJIT_Module_num_outputs(const JITModule method); -EXPORT_API(Tensor) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length); +EXPORT_API(void) THSJIT_Module_forward(const JITModule module, const Tensor* tensorPtrs, const int length, Tensor* (*allocator)(size_t length), int8_t* typeCode); EXPORT_API(int) THSJIT_Module_is_training(JITModule module); EXPORT_API(void) THSJIT_Module_train(JITModule module, bool on); diff --git a/src/TorchSharp/JIT/ScriptModule.cs b/src/TorchSharp/JIT/ScriptModule.cs index d56f2849f..7835e5258 100644 --- a/src/TorchSharp/JIT/ScriptModule.cs +++ b/src/TorchSharp/JIT/ScriptModule.cs @@ -6,6 +6,7 @@ using System.Reflection; using System.Runtime.InteropServices; using static TorchSharp.torch; +using System.Net; namespace TorchSharp { @@ -268,91 +269,125 @@ private Type GetType(Type type) #endif [DllImport("LibTorchSharp")] - private static extern IntPtr THSJIT_Module_forward(HType module, IntPtr tensors, int length); + private static extern void THSJIT_Module_forward(HType module, IntPtr tensors, int length, AllocatePinnedArray allocator, out sbyte typeCode); - /// - /// Invoke the 'forward' function of the script with one tensor as its argument - /// - /// The input tensor - /// - public unsafe override Tensor forward(Tensor tensor) + public object forward(params object[] objs) { - var tensorRefs = stackalloc[] { tensor.Handle }; - var res = THSJIT_Module_forward(handle, (IntPtr)tensorRefs, 1); - if (res == IntPtr.Zero) - CheckForErrors(); - return new Tensor(res); + if (!objs.All(o => typeof(Tensor).IsAssignableFrom(o.GetType()))) { + throw new NotImplementedException("ScriptModule.forward() taking non-tensors as input arguments"); + } + + IntPtr[] ptrArray = null; + sbyte typeCode = 0; + + using (var parray = new PinnedArray()) { + + var tensors = objs.Select(o => (Tensor)o).ToArray(); + var count = tensors.Length; + var tensorRefs = new IntPtr[count]; + for (var i = 0; i < tensors.Length; i++) tensorRefs[i] = tensors[i].Handle; + + THSJIT_Module_forward(handle, parray.CreateArray(tensorRefs), count, parray.CreateArray, out typeCode); + torch.CheckForErrors(); + ptrArray = parray.Array; + } + + + switch (typeCode) { + default: + // Nothing. + throw new NotImplementedException("ScriptModule.forward() returning something else than a tensor, a tuple of tensors, or list of tensors."); + case 1: + // Tensor + return new Tensor(ptrArray[0]); + case 2: + // Tuple + switch (ptrArray.Length) { + case 1: + return new Tensor(ptrArray[0]); + case 2: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1])); + case 3: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2])); + case 4: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3])); + case 5: + return (new Tensor(ptrArray[0]), new Tensor(ptrArray[1]), new Tensor(ptrArray[2]), new Tensor(ptrArray[3]), new Tensor(ptrArray[4])); + default: { + // Too long a tuple, return as a list, instead. + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } + case 3: { + // List of tensors + var result = new Tensor[ptrArray.Length]; + for (var i = 0; i < ptrArray.Length; i++) { + result[i] = new Tensor(ptrArray[i]); + } + return result; + } + } } + } + + /// + /// A script module taking any number of tensors as input + /// + /// The return type of the module. + public class ScriptModule : ScriptModule, torch.nn.IModule + { + internal ScriptModule(IntPtr handle) : base(handle) { } /// - /// Invoke the 'forward' function of the script with two tensors as its argument + /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// The first input tensor - /// The second input tensor /// - public unsafe override Tensor forward(Tensor x, Tensor y) + public TResult forward(params Tensor[] tensor) { - var tensorRefs = stackalloc[] { x.Handle, y.Handle }; - var res = THSJIT_Module_forward(handle, (IntPtr)tensorRefs, 2); - if (res == IntPtr.Zero) - CheckForErrors(); - return new Tensor(res); + return (TResult)base.forward(tensor); } + } + + /// + /// A script module taking a single argument. + /// + /// The argument type. + /// The return type of the module. + public class ScriptModule : ScriptModule, torch.nn.IModule + { + internal ScriptModule(IntPtr handle) : base(handle) { } /// - /// Invoke the 'forward' function of the script with three tensors as its argument + /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// The first input tensor - /// The second input tensor - /// The third input tensor /// - public unsafe override Tensor forward(Tensor x, Tensor y, Tensor z) + public TResult forward(T tensor) { - var tensorRefs = stackalloc[] { x.Handle, y.Handle, z.Handle }; - var res = THSJIT_Module_forward(handle, (IntPtr)tensorRefs, 3); - if (res == IntPtr.Zero) - CheckForErrors(); - return new Tensor(res); + return (TResult)base.forward(tensor); } + } + + /// + /// A script module taking two arguments. + /// + /// The first argument type. + /// The second argument type. + /// The return type of the module. + public class ScriptModule : ScriptModule, torch.nn.IModule + { + internal ScriptModule(IntPtr handle) : base(handle) { } /// - /// Invoke the 'forward' function of the script with four or more tensors as its argument + /// Invoke the 'forward' function of the script with one tensor as its argument /// - /// The first input tensor - /// The second input tensor - /// The third input tensor - /// The remaining tensors. /// - public unsafe Tensor forward(Tensor x, Tensor y, Tensor z, params Tensor[] tensors) + public TResult forward(T1 input1, T2 input2) { - var count = 3 + tensors.Length; - - if (count < 32) { - var tensorRefs = stackalloc IntPtr[count]; - tensorRefs[0] = x.Handle; - tensorRefs[1] = y.Handle; - tensorRefs[2] = z.Handle; - for (var i = 0; i < tensors.Length; i++) tensorRefs[3 + i] = tensors[i].Handle; - - var res = THSJIT_Module_forward(handle, (IntPtr)tensorRefs, count); - if (res == IntPtr.Zero) - CheckForErrors(); - return new Tensor(res); - } else { - // It the unlikely event that there's a great number of arguments, use heap allocation. - var tensorRefs = new IntPtr[count]; - tensorRefs[0] = x.Handle; - tensorRefs[1] = y.Handle; - tensorRefs[2] = z.Handle; - for (var i = 0; i < tensors.Length; i++) tensorRefs[3 + i] = tensors[i].Handle; - - using (var parray = new PinnedArray()) { - var res = THSJIT_Module_forward(handle, parray.CreateArray(tensorRefs), count); - if (res == IntPtr.Zero) - CheckForErrors(); - return new Tensor(res); - } - } + return (TResult)base.forward(input1, input2); } } @@ -362,13 +397,66 @@ public unsafe Tensor forward(Tensor x, Tensor y, Tensor z, params Tensor[] tenso /// /// Load a ScriptModule or ScriptFunction previously saved with torch.jit.save /// - /// + /// The file name of the module. /// A ScriptModule instance, whether the script originated as a module or function. /// /// All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from.If this fails (e.g.because the run time system doesn’t have certain devices), an exception is raised. /// /// Raised if the file is not found. public static ScriptModule load(string filename) + { + return new ScriptModule(_load(filename)); + } + + /// + /// Load a ScriptModule or ScriptFunction previously saved with torch.jit.save + /// + /// The return type of the module. + /// The file name of the module. + /// A ScriptModule instance, whether the script originated as a module or function. + /// + /// All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from.If this fails (e.g.because the run time system doesn’t have certain devices), an exception is raised. + /// + /// Raised if the file is not found. + public static ScriptModule load(string filename) + { + return new ScriptModule(_load(filename)); + } + + /// + /// Load a ScriptModule or ScriptFunction previously saved with torch.jit.save + /// + /// The argument type. + /// The return type of the module. + /// The file name of the module. + /// A ScriptModule instance, whether the script originated as a module or function. + /// + /// All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from.If this fails (e.g.because the run time system doesn’t have certain devices), an exception is raised. + /// + /// Raised if the file is not found. + public static ScriptModule load(string filename) + { + return new ScriptModule(_load(filename)); + } + + /// + /// Load a ScriptModule or ScriptFunction previously saved with torch.jit.save + /// + /// The first argument type. + /// The second argument type. + /// The return type of the module. + /// The file name of the module. + /// A ScriptModule instance, whether the script originated as a module or function. + /// + /// All previously saved modules, no matter their device, are first loaded onto CPU, and then are moved to the devices they were saved from.If this fails (e.g.because the run time system doesn’t have certain devices), an exception is raised. + /// + /// Raised if the file is not found. + public static ScriptModule load(string filename) + { + return new ScriptModule(_load(filename)); + } + + private static IntPtr _load(string filename) { if (!System.IO.File.Exists(filename)) throw new System.IO.FileNotFoundException(filename); @@ -376,7 +464,7 @@ public static ScriptModule load(string filename) var result = THSJIT_load(filename); if (result == IntPtr.Zero) CheckForErrors(); - return new ScriptModule(result); + return result; } [DllImport("LibTorchSharp")] @@ -388,8 +476,8 @@ public static ScriptModule load(string filename) /// The saved module serializes all of the methods, submodules, parameters, and attributes of this module. /// It can be loaded into the C++ API using torch::jit::load(filename) or into the .NET API with torch.jit.load(). /// - /// - /// + /// The script module to save. + /// The file name of the module. public static void save(ScriptModule module, string filename) { THSJIT_save(module.handle, filename); diff --git a/src/TorchSharp/NN/Activation/CELU.cs b/src/TorchSharp/NN/Activation/CELU.cs index 008ac8fed..f521244be 100644 --- a/src/TorchSharp/NN/Activation/CELU.cs +++ b/src/TorchSharp/NN/Activation/CELU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a CELU module. /// - public class CELU : torch.nn.Module + public class CELU : torch.nn.Module { internal CELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/ELU.cs b/src/TorchSharp/NN/Activation/ELU.cs index cab8136f4..d7e77c708 100644 --- a/src/TorchSharp/NN/Activation/ELU.cs +++ b/src/TorchSharp/NN/Activation/ELU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ELU module. /// - public class ELU : torch.nn.Module + public class ELU : torch.nn.Module { internal ELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/GELU.cs b/src/TorchSharp/NN/Activation/GELU.cs index 9d5925b22..d417422d8 100644 --- a/src/TorchSharp/NN/Activation/GELU.cs +++ b/src/TorchSharp/NN/Activation/GELU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a GELU module. /// - public class GELU : torch.nn.Module + public class GELU : torch.nn.Module { internal GELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/GLU.cs b/src/TorchSharp/NN/Activation/GLU.cs index 13c8cf1ac..96633ad63 100644 --- a/src/TorchSharp/NN/Activation/GLU.cs +++ b/src/TorchSharp/NN/Activation/GLU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a GLU (gated linear unit) module. /// - public class GLU : torch.nn.Module + public class GLU : torch.nn.Module { internal GLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Hardshrink.cs b/src/TorchSharp/NN/Activation/Hardshrink.cs index 42c0cd8a4..0ae30ac08 100644 --- a/src/TorchSharp/NN/Activation/Hardshrink.cs +++ b/src/TorchSharp/NN/Activation/Hardshrink.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Hardshrink module. /// - public class Hardshrink : torch.nn.Module + public class Hardshrink : torch.nn.Module { internal Hardshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Hardsigmoid.cs b/src/TorchSharp/NN/Activation/Hardsigmoid.cs index 406099c64..758106868 100644 --- a/src/TorchSharp/NN/Activation/Hardsigmoid.cs +++ b/src/TorchSharp/NN/Activation/Hardsigmoid.cs @@ -11,7 +11,7 @@ namespace Modules /// /// This class is used to represent a Hardsigmoid module. /// - public class Hardsigmoid : torch.nn.Module + public class Hardsigmoid : torch.nn.Module { private readonly bool inplace; diff --git a/src/TorchSharp/NN/Activation/Hardswish.cs b/src/TorchSharp/NN/Activation/Hardswish.cs index 9910e3058..83716784a 100644 --- a/src/TorchSharp/NN/Activation/Hardswish.cs +++ b/src/TorchSharp/NN/Activation/Hardswish.cs @@ -11,7 +11,7 @@ namespace Modules /// /// This class is used to represent a Hardswish module. /// - public class Hardswish : torch.nn.Module + public class Hardswish : torch.nn.Module { private readonly bool inplace; diff --git a/src/TorchSharp/NN/Activation/Hardtanh.cs b/src/TorchSharp/NN/Activation/Hardtanh.cs index 203d351f5..f5e78b573 100644 --- a/src/TorchSharp/NN/Activation/Hardtanh.cs +++ b/src/TorchSharp/NN/Activation/Hardtanh.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Hardtanh module. /// - public class Hardtanh : torch.nn.Module + public class Hardtanh : torch.nn.Module { internal Hardtanh(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/LeakyReLU.cs b/src/TorchSharp/NN/Activation/LeakyReLU.cs index d21086f5f..3485bc1df 100644 --- a/src/TorchSharp/NN/Activation/LeakyReLU.cs +++ b/src/TorchSharp/NN/Activation/LeakyReLU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a LeakyReLU module. /// - public class LeakyReLU : torch.nn.Module + public class LeakyReLU : torch.nn.Module { internal LeakyReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/LogSoftMax.cs b/src/TorchSharp/NN/Activation/LogSoftMax.cs index 25eb3ddba..ed7fb8382 100644 --- a/src/TorchSharp/NN/Activation/LogSoftMax.cs +++ b/src/TorchSharp/NN/Activation/LogSoftMax.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a log softmax module. /// - public class LogSoftmax : torch.nn.Module + public class LogSoftmax : torch.nn.Module { internal LogSoftmax(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Activation/Mish.cs b/src/TorchSharp/NN/Activation/Mish.cs index 93b4d9f03..d79da4611 100644 --- a/src/TorchSharp/NN/Activation/Mish.cs +++ b/src/TorchSharp/NN/Activation/Mish.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Mish module. /// - public class Mish : torch.nn.Module + public class Mish : torch.nn.Module { internal Mish(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/RReLU.cs b/src/TorchSharp/NN/Activation/RReLU.cs index 7671d688f..4bf70531f 100644 --- a/src/TorchSharp/NN/Activation/RReLU.cs +++ b/src/TorchSharp/NN/Activation/RReLU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a RReLU module. /// - public class RReLU : torch.nn.Module + public class RReLU : torch.nn.Module { internal RReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/ReLU6.cs b/src/TorchSharp/NN/Activation/ReLU6.cs index 60278fca6..48064dd39 100644 --- a/src/TorchSharp/NN/Activation/ReLU6.cs +++ b/src/TorchSharp/NN/Activation/ReLU6.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReLU6 module. /// - public class ReLU6 : torch.nn.Module + public class ReLU6 : torch.nn.Module { internal ReLU6(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/ReLu.cs b/src/TorchSharp/NN/Activation/ReLu.cs index 2d6e44a8c..ce2ede482 100644 --- a/src/TorchSharp/NN/Activation/ReLu.cs +++ b/src/TorchSharp/NN/Activation/ReLu.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReLU module. /// - public class ReLU : torch.nn.Module + public class ReLU : torch.nn.Module { internal ReLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/SELU.cs b/src/TorchSharp/NN/Activation/SELU.cs index 07edb5c22..8659a5c78 100644 --- a/src/TorchSharp/NN/Activation/SELU.cs +++ b/src/TorchSharp/NN/Activation/SELU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a SELU module. /// - public class SELU : torch.nn.Module + public class SELU : torch.nn.Module { internal SELU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/SiLU.cs b/src/TorchSharp/NN/Activation/SiLU.cs index 04f720837..0e79b74db 100644 --- a/src/TorchSharp/NN/Activation/SiLU.cs +++ b/src/TorchSharp/NN/Activation/SiLU.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a SiLU module. /// - public class SiLU : torch.nn.Module + public class SiLU : torch.nn.Module { internal SiLU(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Sigmoid.cs b/src/TorchSharp/NN/Activation/Sigmoid.cs index 4b9287b8e..2e4781a6f 100644 --- a/src/TorchSharp/NN/Activation/Sigmoid.cs +++ b/src/TorchSharp/NN/Activation/Sigmoid.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Sigmoid module. /// - public class Sigmoid : torch.nn.Module + public class Sigmoid : torch.nn.Module { internal Sigmoid(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softmax.cs b/src/TorchSharp/NN/Activation/Softmax.cs index 51c73bd86..026ba08b4 100644 --- a/src/TorchSharp/NN/Activation/Softmax.cs +++ b/src/TorchSharp/NN/Activation/Softmax.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softmax module. /// - public class Softmax : torch.nn.Module + public class Softmax : torch.nn.Module { internal Softmax(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softmax2d.cs b/src/TorchSharp/NN/Activation/Softmax2d.cs index 25e7a480b..65aecc39e 100644 --- a/src/TorchSharp/NN/Activation/Softmax2d.cs +++ b/src/TorchSharp/NN/Activation/Softmax2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softmax2d module. /// - public class Softmax2d : torch.nn.Module + public class Softmax2d : torch.nn.Module { internal Softmax2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softmin.cs b/src/TorchSharp/NN/Activation/Softmin.cs index 260eec624..90cafb8e7 100644 --- a/src/TorchSharp/NN/Activation/Softmin.cs +++ b/src/TorchSharp/NN/Activation/Softmin.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softmin module. /// - public class Softmin : torch.nn.Module + public class Softmin : torch.nn.Module { internal Softmin(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softplus.cs b/src/TorchSharp/NN/Activation/Softplus.cs index 8ae453579..61c398bf1 100644 --- a/src/TorchSharp/NN/Activation/Softplus.cs +++ b/src/TorchSharp/NN/Activation/Softplus.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softplus module. /// - public class Softplus : torch.nn.Module + public class Softplus : torch.nn.Module { internal Softplus(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softshrink.cs b/src/TorchSharp/NN/Activation/Softshrink.cs index 3d5dbb0b3..e6143f109 100644 --- a/src/TorchSharp/NN/Activation/Softshrink.cs +++ b/src/TorchSharp/NN/Activation/Softshrink.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softshrink module. /// - public class Softshrink : torch.nn.Module + public class Softshrink : torch.nn.Module { internal Softshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Softsign.cs b/src/TorchSharp/NN/Activation/Softsign.cs index bd3fbfdcc..3a7e47f0d 100644 --- a/src/TorchSharp/NN/Activation/Softsign.cs +++ b/src/TorchSharp/NN/Activation/Softsign.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Softsign module. /// - public class Softsign : torch.nn.Module + public class Softsign : torch.nn.Module { internal Softsign(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Tanh.cs b/src/TorchSharp/NN/Activation/Tanh.cs index 490e0aa92..f874de3a1 100644 --- a/src/TorchSharp/NN/Activation/Tanh.cs +++ b/src/TorchSharp/NN/Activation/Tanh.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Tanh module. /// - public class Tanh : torch.nn.Module + public class Tanh : torch.nn.Module { internal Tanh(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Tanhshrink.cs b/src/TorchSharp/NN/Activation/Tanhshrink.cs index 286570df8..6e26c78ec 100644 --- a/src/TorchSharp/NN/Activation/Tanhshrink.cs +++ b/src/TorchSharp/NN/Activation/Tanhshrink.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Tanhshrink module. /// - public class Tanhshrink : torch.nn.Module + public class Tanhshrink : torch.nn.Module { internal Tanhshrink(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Activation/Threshold.cs b/src/TorchSharp/NN/Activation/Threshold.cs index 76de3757e..2ddefbae9 100644 --- a/src/TorchSharp/NN/Activation/Threshold.cs +++ b/src/TorchSharp/NN/Activation/Threshold.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Threshold module. /// - public class Threshold : torch.nn.Module + public class Threshold : torch.nn.Module { internal Threshold(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/AlphaDropout.cs b/src/TorchSharp/NN/AlphaDropout.cs index 9300015ed..687437c7c 100644 --- a/src/TorchSharp/NN/AlphaDropout.cs +++ b/src/TorchSharp/NN/AlphaDropout.cs @@ -17,7 +17,7 @@ namespace Modules /// The elements to masked are randomized on every forward call, and scaled and shifted to maintain zero mean and unit standard deviation. /// During evaluation the module simply computes an identity function. /// - public class AlphaDropout : torch.nn.Module + public class AlphaDropout : torch.nn.Module { internal AlphaDropout(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Bilinear.cs b/src/TorchSharp/NN/Bilinear.cs index a817f8872..51f8066d8 100644 --- a/src/TorchSharp/NN/Bilinear.cs +++ b/src/TorchSharp/NN/Bilinear.cs @@ -11,13 +11,13 @@ namespace TorchSharp namespace Modules { - public class Bilinear : torch.nn.Module + public class Bilinear : torch.nn.Module { internal Bilinear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } public new static Bilinear Load(String modelPath) { - var res = Module.Load(modelPath); + var res = Module.Load(modelPath); return new Bilinear(res.handle.DangerousGetHandle(), IntPtr.Zero); } diff --git a/src/TorchSharp/NN/Convolution/Conv1D.cs b/src/TorchSharp/NN/Convolution/Conv1D.cs index 5fa57f434..a1903b371 100644 --- a/src/TorchSharp/NN/Convolution/Conv1D.cs +++ b/src/TorchSharp/NN/Convolution/Conv1D.cs @@ -25,7 +25,7 @@ public enum Padding namespace Modules { - public class Conv1d : torch.nn.Module + public class Conv1d : torch.nn.Module { internal Conv1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Convolution/Conv2D.cs b/src/TorchSharp/NN/Convolution/Conv2D.cs index 726dda35f..34599da46 100644 --- a/src/TorchSharp/NN/Convolution/Conv2D.cs +++ b/src/TorchSharp/NN/Convolution/Conv2D.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class Conv2d : torch.nn.Module + public class Conv2d : torch.nn.Module { internal Conv2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Convolution/Conv3D.cs b/src/TorchSharp/NN/Convolution/Conv3D.cs index 8ca8e6551..bb82ccbe9 100644 --- a/src/TorchSharp/NN/Convolution/Conv3D.cs +++ b/src/TorchSharp/NN/Convolution/Conv3D.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class Conv3d : torch.nn.Module + public class Conv3d : torch.nn.Module { internal Conv3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs index 90586092e..af17e690e 100644 --- a/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs +++ b/src/TorchSharp/NN/Convolution/ConvTranspose1D.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class ConvTranspose1d : torch.nn.Module + public class ConvTranspose1d : torch.nn.Module { internal ConvTranspose1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs index f013f5d34..fee90fe51 100644 --- a/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs +++ b/src/TorchSharp/NN/Convolution/ConvTranspose2D.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class ConvTranspose2d : torch.nn.Module + public class ConvTranspose2d : torch.nn.Module { internal ConvTranspose2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs b/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs index 512ca107f..0b9868d5f 100644 --- a/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs +++ b/src/TorchSharp/NN/Convolution/ConvTranspose3D.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class ConvTranspose3d : torch.nn.Module + public class ConvTranspose3d : torch.nn.Module { internal ConvTranspose3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/CosineSimilarity.cs b/src/TorchSharp/NN/CosineSimilarity.cs index f6f442570..0ad28970f 100644 --- a/src/TorchSharp/NN/CosineSimilarity.cs +++ b/src/TorchSharp/NN/CosineSimilarity.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module for 2d/3d convolutational layers. /// - public class CosineSimilarity : torch.nn.Module + public class CosineSimilarity : torch.nn.Module { internal CosineSimilarity(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Dropout.cs b/src/TorchSharp/NN/Dropout.cs index 916ece1af..f0080d3bf 100644 --- a/src/TorchSharp/NN/Dropout.cs +++ b/src/TorchSharp/NN/Dropout.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public class Dropout : torch.nn.Module + public class Dropout : torch.nn.Module { internal Dropout(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Dropout2d.cs b/src/TorchSharp/NN/Dropout2d.cs index 35c698b33..6d3eaa366 100644 --- a/src/TorchSharp/NN/Dropout2d.cs +++ b/src/TorchSharp/NN/Dropout2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Dropout2d module. /// - public class Dropout2d : torch.nn.Module + public class Dropout2d : torch.nn.Module { internal Dropout2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Dropout3d.cs b/src/TorchSharp/NN/Dropout3d.cs index 7107da8ea..b00c44e92 100644 --- a/src/TorchSharp/NN/Dropout3d.cs +++ b/src/TorchSharp/NN/Dropout3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a Dropout3d module. /// - public class Dropout3d : torch.nn.Module + public class Dropout3d : torch.nn.Module { internal Dropout3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Embedding.cs b/src/TorchSharp/NN/Embedding.cs index 262e9f551..e91bad9d0 100644 --- a/src/TorchSharp/NN/Embedding.cs +++ b/src/TorchSharp/NN/Embedding.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class Embedding : torch.nn.Module + public class Embedding : torch.nn.Module { internal Embedding(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/EmbeddingBag.cs b/src/TorchSharp/NN/EmbeddingBag.cs index 32ffd9acf..987a3114d 100644 --- a/src/TorchSharp/NN/EmbeddingBag.cs +++ b/src/TorchSharp/NN/EmbeddingBag.cs @@ -17,7 +17,7 @@ public enum EmbeddingBagMode namespace Modules { - public class EmbeddingBag : torch.nn.Module + public class EmbeddingBag : torch.nn.Module { internal EmbeddingBag(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } @@ -33,7 +33,7 @@ internal EmbeddingBag(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHan /// If specified, per_sample_weights must have exactly the same shape as input and is treated as having the same offsets, if those are not None. /// Only supported for mode='sum'. /// - public override Tensor forward(Tensor input, Tensor offsets, Tensor perSampleWeights) + public Tensor forward(Tensor input, Tensor offsets, Tensor perSampleWeights) { if (!input.IsIntegral()) throw new ArgumentException("Embedding input must be an integral tensor."); if (!(offsets is null) && input.dtype != offsets.dtype) throw new ArgumentException("input and offsets must have the same element type."); @@ -53,7 +53,7 @@ public override Tensor forward(Tensor input, Tensor offsets, Tensor perSampleWei /// Tensor containing bags of indices into the embedding matrix. /// Only used when input is 1D. offsets determines the starting index position of each bag (sequence) in input. /// - public override Tensor forward(Tensor input, Tensor offsets) + public Tensor forward(Tensor input, Tensor offsets) { if (!input.IsIntegral()) throw new ArgumentException("Embedding input must be an integral tensor."); if (!(offsets is null) && input.dtype != offsets.dtype) throw new ArgumentException("input and offsets must have the same element type."); diff --git a/src/TorchSharp/NN/FeatureDropout.cs b/src/TorchSharp/NN/FeatureDropout.cs index fbc4bf5ea..7d4fd62f8 100644 --- a/src/TorchSharp/NN/FeatureDropout.cs +++ b/src/TorchSharp/NN/FeatureDropout.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module for 2d/3d convolutational layers. /// - public class FeatureAlphaDropout : torch.nn.Module + public class FeatureAlphaDropout : torch.nn.Module { internal FeatureAlphaDropout(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Flatten.cs b/src/TorchSharp/NN/Flatten.cs index ba2ab5a6e..dfe20821c 100644 --- a/src/TorchSharp/NN/Flatten.cs +++ b/src/TorchSharp/NN/Flatten.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module for 2d/3d convolutational layers. /// - public class Flatten : torch.nn.Module + public class Flatten : torch.nn.Module { internal Flatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Identity.cs b/src/TorchSharp/NN/Identity.cs index 4242adeca..91ed3e074 100644 --- a/src/TorchSharp/NN/Identity.cs +++ b/src/TorchSharp/NN/Identity.cs @@ -10,7 +10,7 @@ namespace TorchSharp namespace Modules { - public class Identity : torch.nn.Module + public class Identity : torch.nn.Module { internal Identity(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Linear.cs b/src/TorchSharp/NN/Linear.cs index b1e0358c9..670c2623c 100644 --- a/src/TorchSharp/NN/Linear.cs +++ b/src/TorchSharp/NN/Linear.cs @@ -12,7 +12,7 @@ namespace TorchSharp namespace Modules { - public class Linear : torch.nn.Module + public class Linear : torch.nn.Module { internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { @@ -20,7 +20,7 @@ internal Linear(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) public new static Linear Load(String modelPath) { - var res = Module.Load(modelPath); + var res = Module.Load(modelPath); return new Linear(res.handle.DangerousGetHandle(), IntPtr.Zero); } diff --git a/src/TorchSharp/NN/Losses.cs b/src/TorchSharp/NN/Losses.cs index 98aaedca4..b368cc029 100644 --- a/src/TorchSharp/NN/Losses.cs +++ b/src/TorchSharp/NN/Losses.cs @@ -12,16 +12,56 @@ namespace TorchSharp { using Modules; - public class Loss : nn.Module + public abstract class Loss : nn.Module { - public Loss(torch.nn.Reduction reduction = nn.Reduction.Mean) : base(nameof(Loss)) + public Loss(torch.nn.Reduction reduction = nn.Reduction.Mean) : base(nameof(Loss)) { this.reduction = reduction; } public torch.nn.Reduction reduction { get; } } - public class WeightedLoss : Loss + + public abstract class Loss : nn.Module + { + public Loss(torch.nn.Reduction reduction = nn.Reduction.Mean) : base(nameof(Loss)) + { + this.reduction = reduction; + } + + public torch.nn.Reduction reduction { get; } + } + public abstract class Loss : nn.Module + { + public Loss(torch.nn.Reduction reduction = nn.Reduction.Mean) : base(nameof(Loss)) + { + this.reduction = reduction; + } + + public torch.nn.Reduction reduction { get; } + } + + public abstract class WeightedLoss : Loss + { + public WeightedLoss(Tensor? weight = null, torch.nn.Reduction reduction = nn.Reduction.Mean) : base(reduction) + { + this.weight = weight; + } + + public Tensor? weight { get; } + } + + public abstract class WeightedLoss : Loss + { + public WeightedLoss(Tensor? weight = null, torch.nn.Reduction reduction = nn.Reduction.Mean) : base(reduction) + { + this.weight = weight; + } + + public Tensor? weight { get; } + } + + public abstract class WeightedLoss : Loss { public WeightedLoss(Tensor? weight = null, torch.nn.Reduction reduction = nn.Reduction.Mean) : base(reduction) { @@ -758,7 +798,7 @@ namespace Modules { using static torch.nn.functional; - public sealed class CrossEntropyLoss : WeightedLoss + public sealed class CrossEntropyLoss : WeightedLoss { public CrossEntropyLoss(Tensor? weight = null, long? ignore_index = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) { @@ -776,7 +816,7 @@ public override Tensor forward(Tensor input, Tensor target) public long? ignore_index { get; } } - public sealed class BCELoss : WeightedLoss + public sealed class BCELoss : WeightedLoss { public BCELoss(Tensor? weight = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) { @@ -790,7 +830,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class BCEWithLogitsLoss : WeightedLoss + public sealed class BCEWithLogitsLoss : WeightedLoss { public BCEWithLogitsLoss(Tensor? weight = null, Reduction reduction = Reduction.Mean, Tensor? pos_weights = null) : base(weight, reduction) { @@ -807,7 +847,7 @@ public override Tensor forward(Tensor input, Tensor target) public Tensor? pos_weights { get; } } - public sealed class CosineEmbeddingLoss : Loss + public sealed class CosineEmbeddingLoss : Loss { public CosineEmbeddingLoss(double margin = 0.0, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -824,7 +864,7 @@ public override Tensor forward(Tensor input1, Tensor input2, Tensor target) public double margin { get; } } - public sealed class CTCLoss : Loss + public sealed class CTCLoss : Loss { public CTCLoss(long blank = 0, bool zero_infinity = false, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -832,7 +872,7 @@ public CTCLoss(long blank = 0, bool zero_infinity = false, Reduction reduction = this.zero_infinity = zero_infinity; } - public Tensor forward(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths) + public override Tensor forward(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths) { var res = THSNN_ctc_loss(log_probs.Handle, targets.Handle, input_lengths.Handle, target_lengths.Handle, blank, zero_infinity, (long)reduction); if (res == IntPtr.Zero) { torch.CheckForErrors(); } @@ -843,7 +883,7 @@ public Tensor forward(Tensor log_probs, Tensor targets, Tensor input_lengths, Te public bool zero_infinity { get; } } - public sealed class HingeEmbeddingLoss : Loss + public sealed class HingeEmbeddingLoss : Loss { public HingeEmbeddingLoss(double margin = 0.0, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -860,7 +900,7 @@ public override Tensor forward(Tensor input, Tensor target) public double margin { get; } } - public sealed class HuberLoss : Loss + public sealed class HuberLoss : Loss { public HuberLoss(double delta = 1.0, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -877,7 +917,7 @@ public override Tensor forward(Tensor input, Tensor target) public double delta { get; } } - public sealed class MarginRankingLoss : Loss + public sealed class MarginRankingLoss : Loss { public MarginRankingLoss(double margin = 0.0, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -894,7 +934,7 @@ public override Tensor forward(Tensor input1, Tensor input2, Tensor target) public double margin { get; } } - public sealed class MultiLabelMarginLoss : Loss + public sealed class MultiLabelMarginLoss : Loss { public MultiLabelMarginLoss(Reduction reduction = Reduction.Mean) : base(reduction) { @@ -908,7 +948,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class MultiLabelSoftMarginLoss : WeightedLoss + public sealed class MultiLabelSoftMarginLoss : WeightedLoss { public MultiLabelSoftMarginLoss(Tensor? weight = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) { @@ -922,7 +962,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class MultiMarginLoss : WeightedLoss + public sealed class MultiMarginLoss : WeightedLoss { public MultiMarginLoss(int p = 1, double margin = 1.0, Tensor? weight = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) { @@ -943,7 +983,7 @@ public override Tensor forward(Tensor input, Tensor target) public int p { get; } } - public sealed class MSELoss : Loss + public sealed class MSELoss : Loss { public MSELoss(Reduction reduction = Reduction.Mean) : base(reduction) { @@ -957,7 +997,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class L1Loss : Loss + public sealed class L1Loss : Loss { public L1Loss(Reduction reduction = Reduction.Mean) : base(reduction) { @@ -971,7 +1011,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class NLLLoss : WeightedLoss + public sealed class NLLLoss : WeightedLoss { public NLLLoss(Tensor? weight = null, Reduction reduction = Reduction.Mean) : base(weight, reduction) { @@ -985,7 +1025,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class PoissonNLLLoss : Loss + public sealed class PoissonNLLLoss : Loss { public PoissonNLLLoss(bool log_input = true, bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -1007,7 +1047,7 @@ public override Tensor forward(Tensor input, Tensor target) } - public sealed class GaussianNLLLoss : Loss + public sealed class GaussianNLLLoss : Loss { public GaussianNLLLoss(bool full = false, float eps = 1e-8f, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -1043,7 +1083,7 @@ public override Tensor forward(Tensor input, Tensor target, Tensor variance) } - public sealed class KLDivLoss : Loss + public sealed class KLDivLoss : Loss { public KLDivLoss(bool log_target = true, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -1060,7 +1100,7 @@ public override Tensor forward(Tensor input, Tensor target) public bool log_target { get; } } - public sealed class SmoothL1Loss : Loss + public sealed class SmoothL1Loss : Loss { public SmoothL1Loss(Reduction reduction = Reduction.Mean, double beta = 1.0) : base(reduction) { @@ -1077,7 +1117,7 @@ public override Tensor forward(Tensor input, Tensor target) public double beta { get; } } - public sealed class SoftMarginLoss : Loss + public sealed class SoftMarginLoss : Loss { public SoftMarginLoss(Reduction reduction = Reduction.Mean) : base(reduction) { @@ -1091,7 +1131,7 @@ public override Tensor forward(Tensor input, Tensor target) } } - public sealed class TripletMarginLoss : Loss + public sealed class TripletMarginLoss : Loss { public TripletMarginLoss(double margin = 1.0, long p = 2, double eps = 1e-06, bool swap = false, Reduction reduction = Reduction.Mean) : base(reduction) { @@ -1114,7 +1154,7 @@ public override Tensor forward(Tensor anchor, Tensor positive, Tensor negative) bool swap { get; } } - public sealed class TripletMarginWithDistanceLoss : Loss + public sealed class TripletMarginWithDistanceLoss : Loss { public TripletMarginWithDistanceLoss(Func? distance = null, double margin = 1.0, bool swap = false, Reduction reduction = Reduction.Mean) : base(reduction) { diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 47c0b0c0b..161a4cdf8 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -31,7 +31,7 @@ public class Module : IDisposable /// /// Class wrapping PyTorch's module object reference. /// - internal sealed class HType : SafeHandle + internal protected sealed class HType : SafeHandle { public HType(IntPtr preexistingHandle, bool ownsHandle, Action dispose = null) : base(IntPtr.Zero, ownsHandle) @@ -803,15 +803,6 @@ public virtual string GetName() return res; } - public virtual Tensor forward(Tensor t) - => throw new NotImplementedException("forward(t)"); - - public virtual Tensor forward(Tensor x, Tensor y) - => throw new NotImplementedException("forward(x,y)"); - - public virtual Tensor forward(Tensor x, Tensor y, Tensor z) - => throw new NotImplementedException("forward(x,y,z)"); - /// /// Save the parameters and buffers of the module to a disk location. /// @@ -979,7 +970,7 @@ protected Module(string name) : this(IntPtr.Zero, IntPtr.Zero) IntPtr ForwardNative(IntPtr t) { var input = new Tensor(t); - var output = forward(input); + var output = ((nn.Module)this).forward(input); // handles must live on - we don't own them, but // the managed objects should go away. @@ -1120,11 +1111,121 @@ protected void Dispose(bool disposing) } } } + + /// + /// Interface for concrete modules with a forward() that takes a single argument. + /// + /// The argument type of the module's forward() function. + /// The return type of the module's forward() function. + public interface IModule + { + public abstract TResult forward(T input1); + } + + /// + /// Interface for concrete modules with a forward() that takes two arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The return type of the module's forward() function. + public interface IModule + { + public abstract TResult forward(T1 input1, T2 input2); + } + + /// + /// Interface for concrete modules with a forward() that takes three arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The third argument type of the module's forward() function. + /// The return type of the module's forward() function. + public interface IModule + { + public abstract TResult forward(T1 input1, T2 input2, T3 input3); + } + + /// + /// Interface for concrete modules with a forward() that takes four arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The third argument type of the module's forward() function. + /// The fourth argument type of the module's forward() function. + /// The return type of the module's forward() function. + public interface IModule + { + public abstract TResult forward(T1 input1, T2 input2, T3 input3, T4 input4); + } + + + /// + /// Base class for concrete modules with a forward() that takes a single argument. + /// + /// The argument type of the module's forward() function. + /// The return type of the module's forward() function. + public abstract class Module : Module, IModule + { + protected Module(string name) : base(name) { } + protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Module(HType handle, IntPtr? boxedHandle) : base(handle, boxedHandle) { } + public abstract TResult forward(T input1); + } + + /// + /// Base class for concrete modules with a forward() that takes two arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The return type of the module's forward() function. + public abstract class Module : Module, IModule + { + protected Module(string name) : base(name) { } + protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Module(HType handle, IntPtr? boxedHandle) : base(handle, boxedHandle) { } + public abstract TResult forward(T1 input1, T2 input2); + } + + /// + /// Base class for concrete modules with a forward() that takes three arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The third argument type of the module's forward() function. + /// The return type of the module's forward() function. + public abstract class Module : Module, IModule + { + protected Module(string name) : base(name) { } + protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Module(HType handle, IntPtr? boxedHandle) : base(handle, boxedHandle) { } + public abstract TResult forward(T1 input1, T2 input2, T3 input3); + } + + /// + /// Base class for concrete modules with a forward() that takes four arguments. + /// + /// The first argument type of the module's forward() function. + /// The second argument type of the module's forward() function. + /// The third argument type of the module's forward() function. + /// The fourth argument type of the module's forward() function. + /// The return type of the module's forward() function. + public abstract class Module : Module, IModule + { + protected Module(string name) : base(name) { } + protected Module(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } + internal Module(HType handle, IntPtr? boxedHandle) : base(handle, boxedHandle) { } + public abstract TResult forward(T1 input1, T2 input2, T3 input3, T4 input4); + } } } public static class ModuleExtensionMethods { + /// + /// Converts the parameters and buffers. + /// + /// The module to move + /// The target element type. public static T to(this T module, torch.ScalarType type) where T : torch.nn.Module { return (T)module._to(type); diff --git a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs index 71aa0d863..b6a21120f 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs @@ -13,7 +13,7 @@ namespace Modules /// /// This class is used to represent a BatchNorm1D module. /// - public class BatchNorm1d : torch.nn.Module + public class BatchNorm1d : torch.nn.Module { internal BatchNorm1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs index 162b194fc..fff1b50ae 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs @@ -13,7 +13,7 @@ namespace Modules /// /// This class is used to represent a BatchNorm2D module. /// - public class BatchNorm2d : torch.nn.Module + public class BatchNorm2d : torch.nn.Module { internal BatchNorm2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs index 5e6a5cd39..175eee571 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs @@ -13,7 +13,7 @@ namespace Modules /// /// This class is used to represent a BatchNorm3D module. /// - public class BatchNorm3d : torch.nn.Module + public class BatchNorm3d : torch.nn.Module { internal BatchNorm3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/GroupNorm.cs b/src/TorchSharp/NN/Normalization/GroupNorm.cs index 092c194cb..999211755 100644 --- a/src/TorchSharp/NN/Normalization/GroupNorm.cs +++ b/src/TorchSharp/NN/Normalization/GroupNorm.cs @@ -14,7 +14,7 @@ namespace Modules /// /// This class is used to represent a GroupNorm module. /// - public class GroupNorm : torch.nn.Module + public class GroupNorm : torch.nn.Module { internal GroupNorm(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs index 649c43dc4..6a3769909 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm1d.cs @@ -14,7 +14,7 @@ namespace Modules /// /// This class is used to represent a InstanceNorm1D module. /// - public class InstanceNorm1d : torch.nn.Module + public class InstanceNorm1d : torch.nn.Module { internal InstanceNorm1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs index e4ba17a91..4ea747308 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm2d.cs @@ -14,7 +14,7 @@ namespace Modules /// /// This class is used to represent a InstanceNorm2D module. /// - public class InstanceNorm2d : torch.nn.Module + public class InstanceNorm2d : torch.nn.Module { internal InstanceNorm2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs b/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs index 5ca90ef99..2a1eefff8 100644 --- a/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs +++ b/src/TorchSharp/NN/Normalization/InstanceNorm3d.cs @@ -14,7 +14,7 @@ namespace Modules /// /// This class is used to represent a InstanceNorm3D module. /// - public class InstanceNorm3d : torch.nn.Module + public class InstanceNorm3d : torch.nn.Module { internal InstanceNorm3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/LayerNorm.cs b/src/TorchSharp/NN/Normalization/LayerNorm.cs index 1ee65d0c3..fbbd00f2b 100644 --- a/src/TorchSharp/NN/Normalization/LayerNorm.cs +++ b/src/TorchSharp/NN/Normalization/LayerNorm.cs @@ -14,7 +14,7 @@ namespace Modules /// /// This class is used to represent a LayerNorm module. /// - public class LayerNorm : torch.nn.Module + public class LayerNorm : torch.nn.Module { internal LayerNorm(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs b/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs index 6fa124fc4..fde1df852 100644 --- a/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs +++ b/src/TorchSharp/NN/Normalization/LocalResponseNorm.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a LocalResponseNorm module. /// - public class LocalResponseNorm : torch.nn.Module + public class LocalResponseNorm : torch.nn.Module { internal LocalResponseNorm(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Padding/ConstantPad1d.cs b/src/TorchSharp/NN/Padding/ConstantPad1d.cs index 7b2e6afd8..e0f52d41d 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad1d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad1d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ConstantPad1d module. /// - public class ConstantPad1d : torch.nn.Module + public class ConstantPad1d : torch.nn.Module { internal ConstantPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ConstantPad2d.cs b/src/TorchSharp/NN/Padding/ConstantPad2d.cs index bebd39f25..638c5417e 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad2d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ConstantPad2d module. /// - public class ConstantPad2d : torch.nn.Module + public class ConstantPad2d : torch.nn.Module { internal ConstantPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ConstantPad3d.cs b/src/TorchSharp/NN/Padding/ConstantPad3d.cs index 39a555598..bb0505ff0 100644 --- a/src/TorchSharp/NN/Padding/ConstantPad3d.cs +++ b/src/TorchSharp/NN/Padding/ConstantPad3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ConstantPad3d module. /// - public class ConstantPad3d : torch.nn.Module + public class ConstantPad3d : torch.nn.Module { internal ConstantPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs index 7b0a301b0..798f2a75b 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad1d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad1d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReflectionPad1d module. /// - public class ReflectionPad1d : torch.nn.Module + public class ReflectionPad1d : torch.nn.Module { internal ReflectionPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs index 6dc488a97..8351b461c 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad2d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReflectionPad2d module. /// - public class ReflectionPad2d : torch.nn.Module + public class ReflectionPad2d : torch.nn.Module { internal ReflectionPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs index 4ccf9eabf..6dcdb8a3d 100644 --- a/src/TorchSharp/NN/Padding/ReflectionPad3d.cs +++ b/src/TorchSharp/NN/Padding/ReflectionPad3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReflectionPad3d module. /// - public class ReflectionPad3d : torch.nn.Module + public class ReflectionPad3d : torch.nn.Module { internal ReflectionPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs index 01baa2822..e90906b95 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad1d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad1d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReplicationPad1d module. /// - public class ReplicationPad1d : torch.nn.Module + public class ReplicationPad1d : torch.nn.Module { internal ReplicationPad1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs index 1edbe97e8..86a6a5147 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad2d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReplicationPad2d module. /// - public class ReplicationPad2d : torch.nn.Module + public class ReplicationPad2d : torch.nn.Module { internal ReplicationPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs index d508f09e0..e9306e47b 100644 --- a/src/TorchSharp/NN/Padding/ReplicationPad3d.cs +++ b/src/TorchSharp/NN/Padding/ReplicationPad3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ReplicationPad3d module. /// - public class ReplicationPad3d : torch.nn.Module + public class ReplicationPad3d : torch.nn.Module { internal ReplicationPad3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Padding/ZeroPad2d.cs b/src/TorchSharp/NN/Padding/ZeroPad2d.cs index a1e0f69a8..161fff7a9 100644 --- a/src/TorchSharp/NN/Padding/ZeroPad2d.cs +++ b/src/TorchSharp/NN/Padding/ZeroPad2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a ZeroPad2d module. /// - public class ZeroPad2d : torch.nn.Module + public class ZeroPad2d : torch.nn.Module { internal ZeroPad2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/PairwiseDistance.cs b/src/TorchSharp/NN/PairwiseDistance.cs index 518b547f3..c3f65c5a0 100644 --- a/src/TorchSharp/NN/PairwiseDistance.cs +++ b/src/TorchSharp/NN/PairwiseDistance.cs @@ -12,7 +12,7 @@ namespace Modules /// /// Computes the pairwise distance between vectors using the p-norm. /// - public class PairwiseDistance : torch.nn.Module + public class PairwiseDistance : torch.nn.Module { internal PairwiseDistance(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/PixelShuffle.cs b/src/TorchSharp/NN/PixelShuffle.cs index 3191c1631..5befe2798 100644 --- a/src/TorchSharp/NN/PixelShuffle.cs +++ b/src/TorchSharp/NN/PixelShuffle.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public class PixelShuffle : torch.nn.Module + public class PixelShuffle : torch.nn.Module { internal PixelShuffle(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/PixelUnshuffle.cs b/src/TorchSharp/NN/PixelUnshuffle.cs index 95ba402bf..c49926f0b 100644 --- a/src/TorchSharp/NN/PixelUnshuffle.cs +++ b/src/TorchSharp/NN/PixelUnshuffle.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a dropout module. /// - public class PixelUnshuffle : torch.nn.Module + public class PixelUnshuffle : torch.nn.Module { internal PixelUnshuffle(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs index 06b310845..bb4ca9af9 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool1D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool1D module. /// - public class AdaptiveAvgPool1d : torch.nn.Module + public class AdaptiveAvgPool1d : torch.nn.Module { internal AdaptiveAvgPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs index 39d144e92..92287c2ec 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool2D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool2D module. /// - public class AdaptiveAvgPool2d : torch.nn.Module + public class AdaptiveAvgPool2d : torch.nn.Module { internal AdaptiveAvgPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs index 658fb4375..018a6ccfc 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveAvgPool3D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveAvgPool3D module. /// - public class AdaptiveAvgPool3d : torch.nn.Module + public class AdaptiveAvgPool3d : torch.nn.Module { internal AdaptiveAvgPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs index 89e113424..926a27102 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool1D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool1D module. /// - public class AdaptiveMaxPool1d : torch.nn.Module + public class AdaptiveMaxPool1d : torch.nn.Module { internal AdaptiveMaxPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs index 37afd3ebf..5b0642844 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool2D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool2D module. /// - public class AdaptiveMaxPool2d : torch.nn.Module + public class AdaptiveMaxPool2d : torch.nn.Module { internal AdaptiveMaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs index c7c994503..87f885fb1 100644 --- a/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AdaptiveMaxPool3D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AdaptiveMaxPool3D module. /// - public class AdaptiveMaxPool3d : torch.nn.Module + public class AdaptiveMaxPool3d : torch.nn.Module { internal AdaptiveMaxPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AvgPool1D.cs b/src/TorchSharp/NN/Pooling/AvgPool1D.cs index 552920a11..d252e6f9a 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool1D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool1D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AvgPool1D module. /// - public class AvgPool1d : torch.nn.Module + public class AvgPool1d : torch.nn.Module { internal AvgPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AvgPool2D.cs b/src/TorchSharp/NN/Pooling/AvgPool2D.cs index 562d6c08e..6738ebaa8 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool2D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool2D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AvgPool2D module. /// - public class AvgPool2d : torch.nn.Module + public class AvgPool2d : torch.nn.Module { internal AvgPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/AvgPool3D.cs b/src/TorchSharp/NN/Pooling/AvgPool3D.cs index da28f325d..65ce456d9 100644 --- a/src/TorchSharp/NN/Pooling/AvgPool3D.cs +++ b/src/TorchSharp/NN/Pooling/AvgPool3D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a AvgPool3D module. /// - public class AvgPool3d : torch.nn.Module + public class AvgPool3d : torch.nn.Module { internal AvgPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs index daa5fa340..17d45253f 100644 --- a/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs +++ b/src/TorchSharp/NN/Pooling/FractionalMaxPool2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a FractionalMaxPool2D module. /// - public class FractionalMaxPool2d : torch.nn.Module + public class FractionalMaxPool2d : torch.nn.Module { internal FractionalMaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs b/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs index 1a24e6050..75f7febee 100644 --- a/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs +++ b/src/TorchSharp/NN/Pooling/FractionalMaxPool3d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a FractionalMaxPool3d module. /// - public class FractionalMaxPool3d : torch.nn.Module + public class FractionalMaxPool3d : torch.nn.Module { internal FractionalMaxPool3d(IntPtr handle, IntPtr boxedHandle, bool ratio) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/LPPool1d.cs b/src/TorchSharp/NN/Pooling/LPPool1d.cs index cef236887..b2914149f 100644 --- a/src/TorchSharp/NN/Pooling/LPPool1d.cs +++ b/src/TorchSharp/NN/Pooling/LPPool1d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a LPPool1D module. /// - public class LPPool1d : torch.nn.Module + public class LPPool1d : torch.nn.Module { internal LPPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/LPPool2d.cs b/src/TorchSharp/NN/Pooling/LPPool2d.cs index 70e39b898..a220b4a13 100644 --- a/src/TorchSharp/NN/Pooling/LPPool2d.cs +++ b/src/TorchSharp/NN/Pooling/LPPool2d.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a LPPool2D module. /// - public class LPPool2d : torch.nn.Module + public class LPPool2d : torch.nn.Module { internal LPPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/MaxPool1D.cs b/src/TorchSharp/NN/Pooling/MaxPool1D.cs index f2a1240aa..f594eba64 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool1D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool1D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a MaxPool1D module. /// - public class MaxPool1d : torch.nn.Module + public class MaxPool1d : torch.nn.Module { internal MaxPool1d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/MaxPool2D.cs b/src/TorchSharp/NN/Pooling/MaxPool2D.cs index 07c6be346..8c77c231f 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool2D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool2D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a MaxPool2D module. /// - public class MaxPool2d : torch.nn.Module + public class MaxPool2d : torch.nn.Module { internal MaxPool2d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Pooling/MaxPool3D.cs b/src/TorchSharp/NN/Pooling/MaxPool3D.cs index 8d7ed357d..9a65c6196 100644 --- a/src/TorchSharp/NN/Pooling/MaxPool3D.cs +++ b/src/TorchSharp/NN/Pooling/MaxPool3D.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent a MaxPool3D module. /// - public class MaxPool3d : torch.nn.Module + public class MaxPool3d : torch.nn.Module { internal MaxPool3d(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Recurrent/GRU.cs b/src/TorchSharp/NN/Recurrent/GRU.cs index 1d8e9096b..0d06e700f 100644 --- a/src/TorchSharp/NN/Recurrent/GRU.cs +++ b/src/TorchSharp/NN/Recurrent/GRU.cs @@ -35,7 +35,7 @@ internal GRU(IntPtr handle, IntPtr boxedHandle, long hiddenSize, long numLayers, /// Defaults to 0 if not provided. If the GRU is bidirectional, num_directions should be 2, else it should be 1. /// /// - public new (Tensor, Tensor) forward(Tensor input, Tensor h0 = null) + public (Tensor, Tensor) forward(Tensor input, Tensor h0 = null) { if (h0 is null) { var N = _batch_first ? input.shape[0] : input.shape[1]; diff --git a/src/TorchSharp/NN/Recurrent/GRUCell.cs b/src/TorchSharp/NN/Recurrent/GRUCell.cs index ef4af69ce..86d726c21 100644 --- a/src/TorchSharp/NN/Recurrent/GRUCell.cs +++ b/src/TorchSharp/NN/Recurrent/GRUCell.cs @@ -11,13 +11,13 @@ namespace TorchSharp namespace Modules { - public class GRUCell : torch.nn.Module + public class GRUCell : torch.nn.Module { internal GRUCell(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } public new static GRUCell Load(String modelPath) { - var res = Module.Load(modelPath); + var res = Module.Load(modelPath); return new GRUCell(res.handle.DangerousGetHandle(), IntPtr.Zero); } diff --git a/src/TorchSharp/NN/Recurrent/LSTMCell.cs b/src/TorchSharp/NN/Recurrent/LSTMCell.cs index a645aa708..43e48e6f8 100644 --- a/src/TorchSharp/NN/Recurrent/LSTMCell.cs +++ b/src/TorchSharp/NN/Recurrent/LSTMCell.cs @@ -18,7 +18,7 @@ internal LSTMCell(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) public new static LSTMCell Load(String modelPath) { - var res = Module.Load(modelPath); + var res = Module.Load(modelPath); return new LSTMCell(res.handle.DangerousGetHandle(), IntPtr.Zero); } diff --git a/src/TorchSharp/NN/Recurrent/RNN.cs b/src/TorchSharp/NN/Recurrent/RNN.cs index 15bbb8538..21772ec81 100644 --- a/src/TorchSharp/NN/Recurrent/RNN.cs +++ b/src/TorchSharp/NN/Recurrent/RNN.cs @@ -35,7 +35,7 @@ internal RNN(IntPtr handle, IntPtr boxedHandle, long hiddenSize, long numLayers, /// Tensor of shape (num_layers * num_directions, batch, hidden_size)containing the initial hidden state for each element in the batch. /// Defaults to 0 if not provided. If the RNN is bidirectional, num_directions should be 2, else it should be 1. /// - public new (Tensor, Tensor) forward(Tensor input, Tensor? h0 = null) + public (Tensor, Tensor) forward(Tensor input, Tensor? h0 = null) { if (h0 is null) { var N = _batch_first ? input.shape[0] : input.shape[1]; diff --git a/src/TorchSharp/NN/Recurrent/RNNCell.cs b/src/TorchSharp/NN/Recurrent/RNNCell.cs index 6893ca28d..30ef00592 100644 --- a/src/TorchSharp/NN/Recurrent/RNNCell.cs +++ b/src/TorchSharp/NN/Recurrent/RNNCell.cs @@ -12,7 +12,7 @@ namespace TorchSharp namespace Modules { - public class RNNCell : torch.nn.Module + public class RNNCell : torch.nn.Module { internal RNNCell(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { @@ -20,7 +20,7 @@ internal RNNCell(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) public new static RNNCell Load(String modelPath) { - var res = Module.Load(modelPath); + var res = Module.Load(modelPath); return new RNNCell(res.handle.DangerousGetHandle(), IntPtr.Zero); } diff --git a/src/TorchSharp/NN/Sequential.cs b/src/TorchSharp/NN/Sequential.cs index 1fdbc99bc..bced378e9 100644 --- a/src/TorchSharp/NN/Sequential.cs +++ b/src/TorchSharp/NN/Sequential.cs @@ -16,30 +16,32 @@ namespace Modules /// /// This class is used to represent a Sequential module. /// - public class Sequential : torch.nn.Module + public class Sequential : torch.nn.Module { - public void append(string name, torch.nn.Module module) + public void append(string name, torch.nn.IModule module) { Add(name, module); } - internal void Add(string name, torch.nn.Module submodule) + internal void Add(string name, torch.nn.IModule sm) { + var submodule = (torch.nn.Module)sm; + Debug.Assert(!handle.IsInvalid); Debug.Assert(!submodule.handle.IsInvalid); // Keep the sub-module alive for at least as long as the Sequential object is alive. - _modules.Add(submodule); + _modules.Add(sm); _names.Add(name); } - public void append(torch.nn.Module module) + public void append(torch.nn.IModule module) { var name = _modules.Count.ToString(); Add(name, module); } - internal void Add(torch.nn.Module module) + internal void Add(torch.nn.IModule module) { var name = _modules.Count.ToString(); Add(name, module); @@ -52,7 +54,7 @@ internal void Add(torch.nn.Module module) var seen = new HashSet(); for (var i = 0; i < _names.Count; i++) { - foreach (var (n, p) in _modules[i].named_parameters(true)) { + foreach (var (n, p) in ((torch.nn.Module)_modules[i]).named_parameters(true)) { if (seen.Contains(p.Handle)) continue; seen.Add(p.Handle); yield return ($"{_names[i]}.{n}", p); @@ -65,7 +67,7 @@ internal void Add(torch.nn.Module module) if (!recurse) yield break; for (var i = 0; i < _names.Count; i++) { - foreach (var (n, p) in _modules[i].named_buffers(true)) { + foreach (var (n, p) in ((torch.nn.Module)_modules[i]).named_buffers(true)) { yield return ($"{_names[i]}.{n}", p); } } @@ -74,18 +76,18 @@ internal void Add(torch.nn.Module module) public override IEnumerable<(string name, torch.nn.Module module)> named_children() { for (var i = 0; i < _names.Count; i++) { - yield return ($"{_names[i]}", _modules[i]); + yield return ($"{_names[i]}", ((torch.nn.Module)_modules[i])); } } public override IEnumerable<(string name, torch.nn.Module module)> named_modules() { for (var i = 0; i < _names.Count; i++) { - yield return ($"{_names[i]}", _modules[i]); + yield return ($"{_names[i]}", ((torch.nn.Module)_modules[i])); } for (var i = 0; i < _names.Count; i++) { - var sm = _modules[i]; + var sm = (torch.nn.Module)_modules[i]; var name = _names[i]; foreach (var (n, p) in sm.named_modules()) { yield return ($"{name}.{n}", p); @@ -127,7 +129,7 @@ public override Tensor forward(Tensor tensor) public override nn.Module apply(Action fn) { // More efficient than asking C++ for the children. We already have the list, after all. - foreach (var m in _modules) m.apply(fn); + foreach (var m in _modules) ((torch.nn.Module)m).apply(fn); fn(this); return this; @@ -136,7 +138,7 @@ public override nn.Module apply(Action fn) protected override void Dispose(bool disposing) { if (disposing) { - foreach (var m in _modules) { m.Dispose(); } + foreach (var m in _modules) { ((torch.nn.Module)m).Dispose(); } } base.Dispose(disposing); } @@ -149,7 +151,7 @@ protected override void Dispose(bool disposing) /// public override void train(bool on = true) { - foreach (var m in _modules) { m.train(on); } + foreach (var m in _modules) { ((torch.nn.Module)m).train(on); } } /// @@ -160,24 +162,24 @@ public override void train(bool on = true) /// public override void eval() { - foreach (var m in _modules) { m.eval(); } + foreach (var m in _modules) { ((torch.nn.Module)m).eval(); } } internal protected override nn.Module _to(ScalarType dtype) { - foreach (var m in _modules) { m._to(dtype); } + foreach (var m in _modules) { ((torch.nn.Module)m)._to(dtype); } return this; } internal protected override nn.Module _to(Device device, ScalarType dtype) { - foreach (var m in _modules) { m._to(device, dtype); } + foreach (var m in _modules) { ((torch.nn.Module)m)._to(device, dtype); } return this; } internal protected override nn.Module _to(DeviceType deviceType, int deviceIndex = -1) { - foreach (var m in _modules) { m._to(deviceType, deviceIndex); } + foreach (var m in _modules) { ((torch.nn.Module)m)._to(deviceType, deviceIndex); } return this; } @@ -186,7 +188,7 @@ internal protected override nn.Module _to(DeviceType deviceType, int deviceIndex // The module handles are held in the native runtime, which calls back into managed code, // the .NET module instances need to stay alive, and keeping a list of them will do that. - private List _modules = new List(); + private List> _modules = new List>(); private List _names = new List(); } } @@ -221,7 +223,7 @@ static public Sequential Sequential() /// An ordered list of the contained modules. /// /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(params (string name, torch.nn.Module submodule)[] modules) + static public Sequential Sequential(params (string name, torch.nn.Module submodule)[] modules) { var res = Sequential(); foreach (var module in modules) @@ -238,7 +240,7 @@ static public Sequential Sequential(params (string name, torch.nn.Module submodu /// An ordered list of the contained modules. /// /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(params torch.nn.Module[] modules) + static public Sequential Sequential(params torch.nn.Module[] modules) { var res = Sequential(); foreach (var m in modules) @@ -255,7 +257,7 @@ static public Sequential Sequential(params torch.nn.Module[] modules) /// /// An ordered list of the contained modules. /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(params System.Tuple[] modules) + static public Sequential Sequential(params System.Tuple>[] modules) { var res = Sequential(); foreach (var module in modules) @@ -272,7 +274,7 @@ static public Sequential Sequential(params System.Tuple /// /// An ordered list of the contained modules. /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(IEnumerable<(string name, torch.nn.Module submodule)> modules) + static public Sequential Sequential(IEnumerable<(string name, torch.nn.Module submodule)> modules) { var res = Sequential(); foreach (var module in modules) @@ -289,7 +291,7 @@ static public Sequential Sequential(IEnumerable<(string name, torch.nn.Module su /// /// An ordered list of the contained modules. /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(IEnumerable> modules) + static public Sequential Sequential(IEnumerable>> modules) { var res = Sequential(); foreach (var module in modules) @@ -306,7 +308,7 @@ static public Sequential Sequential(IEnumerableAn ordered list of the contained modules. /// /// Sequential will take ownership of the modules and dispose of them when disposed. - static public Sequential Sequential(IEnumerable modules) + static public Sequential Sequential(IEnumerable> modules) { var res = Sequential(); foreach (var module in modules) diff --git a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs index 3a162322f..9dab39476 100644 --- a/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs +++ b/src/TorchSharp/NN/Shuffle/ChannelShuffle.cs @@ -11,7 +11,7 @@ namespace Modules /// /// This class is used to represent a ChannelShuffle module. /// - public class ChannelShuffle : torch.nn.Module + public class ChannelShuffle : torch.nn.Module { internal ChannelShuffle(long groups) : base(nameof(ChannelShuffle)) { diff --git a/src/TorchSharp/NN/Transformer.cs b/src/TorchSharp/NN/Transformer.cs index 73c777a93..7faad8374 100644 --- a/src/TorchSharp/NN/Transformer.cs +++ b/src/TorchSharp/NN/Transformer.cs @@ -9,7 +9,7 @@ namespace TorchSharp namespace Modules { - public class Transformer : torch.nn.Module + public class Transformer : torch.nn.Module { internal Transformer(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/TransformerDecoder.cs b/src/TorchSharp/NN/TransformerDecoder.cs index 53b6d91a2..3ff23e95e 100644 --- a/src/TorchSharp/NN/TransformerDecoder.cs +++ b/src/TorchSharp/NN/TransformerDecoder.cs @@ -9,7 +9,7 @@ namespace TorchSharp namespace Modules { - public class TransformerDecoder : torch.nn.Module + public class TransformerDecoder : torch.nn.Module { internal TransformerDecoder(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/TransformerDecoderLayer.cs b/src/TorchSharp/NN/TransformerDecoderLayer.cs index eb2823eb7..769909d6f 100644 --- a/src/TorchSharp/NN/TransformerDecoderLayer.cs +++ b/src/TorchSharp/NN/TransformerDecoderLayer.cs @@ -9,7 +9,7 @@ namespace TorchSharp namespace Modules { - public class TransformerDecoderLayer : torch.nn.Module + public class TransformerDecoderLayer : torch.nn.Module { internal TransformerDecoderLayer(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/NN/TransformerEncoder.cs b/src/TorchSharp/NN/TransformerEncoder.cs index 84b1f05a7..e64007e88 100644 --- a/src/TorchSharp/NN/TransformerEncoder.cs +++ b/src/TorchSharp/NN/TransformerEncoder.cs @@ -9,7 +9,7 @@ namespace TorchSharp namespace Modules { - public class TransformerEncoder : torch.nn.Module + public class TransformerEncoder : torch.nn.Module { public enum Activations { @@ -29,7 +29,7 @@ internal TransformerEncoder(IntPtr handle, IntPtr boxedHandle) : base(handle, bo /// The additive mask for the src sequence (optional). /// The ByteTensor mask for src keys per batch (optional). /// - public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) + public Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) { var res = THSNN_TransformerEncoder_forward(handle, src.Handle, @@ -45,7 +45,7 @@ public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_paddi /// The sequence to the encoder (required). /// The additive mask for the src sequence (optional). /// - public override Tensor forward(Tensor src, Tensor src_mask) + public Tensor forward(Tensor src, Tensor src_mask) { var res = THSNN_TransformerEncoder_forward(handle, src.Handle, diff --git a/src/TorchSharp/NN/TransformerEncoderLayer.cs b/src/TorchSharp/NN/TransformerEncoderLayer.cs index ae40d83d4..98455e1b6 100644 --- a/src/TorchSharp/NN/TransformerEncoderLayer.cs +++ b/src/TorchSharp/NN/TransformerEncoderLayer.cs @@ -5,6 +5,7 @@ namespace TorchSharp { + using System.Dynamic; using Modules; namespace Modules @@ -23,7 +24,7 @@ internal TransformerEncoderLayer(IntPtr handle, IntPtr boxedHandle) : base(handl /// The additive mask for the src sequence (optional). /// The ByteTensor mask for src keys per batch (optional). /// - public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) + public Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_padding_mask) { var res = THSNN_TransformerEncoderLayer_forward(handle, src.Handle, @@ -38,7 +39,7 @@ public override Tensor forward(Tensor src, Tensor src_mask, Tensor src_key_paddi /// /// The sequence to the encoder (required). /// The additive mask for the src sequence (optional). - public override Tensor forward(Tensor src, Tensor src_mask) + public Tensor forward(Tensor src, Tensor src_mask) { var res = THSNN_TransformerEncoderLayer_forward(handle, src.Handle, @@ -52,7 +53,7 @@ public override Tensor forward(Tensor src, Tensor src_mask) /// Pass the input through the encoder layer. /// /// The sequence to the encoder (required). - public override Tensor forward(Tensor src) + public Tensor forward(Tensor src) { var res = THSNN_TransformerEncoderLayer_forward(handle, src.Handle, diff --git a/src/TorchSharp/NN/Unflatten.cs b/src/TorchSharp/NN/Unflatten.cs index be4466b63..e34dabed0 100644 --- a/src/TorchSharp/NN/Unflatten.cs +++ b/src/TorchSharp/NN/Unflatten.cs @@ -12,7 +12,7 @@ namespace Modules /// /// This class is used to represent an unflattening operation. /// - public class Unflatten : torch.nn.Module + public class Unflatten : torch.nn.Module { internal Unflatten(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { diff --git a/src/TorchSharp/NN/Upsample.cs b/src/TorchSharp/NN/Upsample.cs index f2c9de768..e96de5f29 100644 --- a/src/TorchSharp/NN/Upsample.cs +++ b/src/TorchSharp/NN/Upsample.cs @@ -70,7 +70,7 @@ namespace Modules /// /// This class is used to represent an Upsample module. /// - public class Upsample : torch.nn.Module + public class Upsample : torch.nn.Module { internal Upsample(IntPtr handle, IntPtr boxedHandle) : base(handle, boxedHandle) { } diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs index 98505b840..c68b15880 100644 --- a/src/TorchSharp/Tensor/Tensor.cs +++ b/src/TorchSharp/Tensor/Tensor.cs @@ -6664,7 +6664,7 @@ public static implicit operator Tensor(Scalar scalar) } // Specifically added to make F# look good. - public static Tensor op_MinusMinusGreater(Tensor t, torch.nn.Module m) => m.forward(t); + public static Tensor op_MinusMinusGreater(Tensor t, torch.nn.Module m) => m.forward(t); public override string ToString() => ToMetadataString(); diff --git a/src/TorchSharp/TorchAudio/Modules/HuBERTPretrainModel.cs b/src/TorchSharp/TorchAudio/Modules/HuBERTPretrainModel.cs index e7e2b7410..e6943baec 100644 --- a/src/TorchSharp/TorchAudio/Modules/HuBERTPretrainModel.cs +++ b/src/TorchSharp/TorchAudio/Modules/HuBERTPretrainModel.cs @@ -79,7 +79,7 @@ internal HuBERTPretrainModel( /// The feature mean value for additional penalty loss. /// Shape: `(1,)`. /// - public new (Tensor?, Tensor?, Tensor) forward( + public (Tensor?, Tensor?, Tensor) forward( Tensor waveforms, Tensor labels, Tensor? audio_lengths = null) diff --git a/src/TorchSharp/TorchAudio/Modules/Tacotron2.cs b/src/TorchSharp/TorchAudio/Modules/Tacotron2.cs index 0dd640932..00d37222d 100644 --- a/src/TorchSharp/TorchAudio/Modules/Tacotron2.cs +++ b/src/TorchSharp/TorchAudio/Modules/Tacotron2.cs @@ -218,7 +218,7 @@ private static Tensor _get_mask_from_lengths(Tensor lengths) return mask; } - private class LocationLayer : nn.Module + private class LocationLayer : nn.Module { private readonly Modules.Conv1d location_conv; private readonly Modules.Linear location_dense; @@ -315,7 +315,7 @@ private Tensor _get_alignment_energies(Tensor query, Tensor processed_memory, Te } } - public class Prenet : nn.Module + public class Prenet : nn.Module { private readonly Modules.ModuleList layers; @@ -336,13 +336,13 @@ public Prenet(string name, int in_dim, long[] out_sizes) : base(name) public override Tensor forward(Tensor x) { foreach (var linear in this.layers) { - x = F.dropout(F.relu(linear.forward(x)), p: 0.5, training: true); + x = F.dropout(F.relu(((nn.Module)linear).forward(x)), p: 0.5, training: true); } return x; } } - private class Postnet : nn.Module + private class Postnet : nn.Module { private readonly Modules.ModuleList convolutions; public readonly int n_convs; @@ -383,16 +383,16 @@ public override Tensor forward(Tensor x) for (int i = 0; i < this.convolutions.Count; i++) { var conv = this.convolutions[i]; if (i < this.n_convs - 1) { - x = F.dropout(torch.tanh(conv.forward(x)), 0.5, training: this.training); + x = F.dropout(torch.tanh(((nn.Module)conv).forward(x)), 0.5, training: this.training); } else { - x = F.dropout(conv.forward(x), 0.5, training: this.training); + x = F.dropout(((nn.Module)conv).forward(x), 0.5, training: this.training); } } return x; } } - private class Encoder : nn.Module + private class Encoder : nn.Module { private readonly Modules.ModuleList convolutions; private readonly Modules.LSTM lstm; @@ -434,7 +434,7 @@ public Encoder( public override Tensor forward(Tensor x, Tensor input_lengths) { foreach (var conv in this.convolutions) { - x = F.dropout(F.relu(conv.forward(x)), 0.5, training: this.training); + x = F.dropout(F.relu(((nn.Module)conv).forward(x)), 0.5, training: this.training); } x = x.transpose(1, 2); @@ -638,7 +638,7 @@ private Tensor _parse_decoder_inputs(Tensor decoder_inputs) } // Decoder forward pass for training. - public new (Tensor, Tensor, Tensor) forward(Tensor memory, Tensor mel_specgram_truth, Tensor memory_lengths) + public (Tensor, Tensor, Tensor) forward(Tensor memory, Tensor mel_specgram_truth, Tensor memory_lengths) { var decoder_input = this._get_initial_frame(memory).unsqueeze(0); var decoder_inputs = this._parse_decoder_inputs(mel_specgram_truth); diff --git a/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Components.cs b/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Components.cs index 73f87f490..004faeb6b 100644 --- a/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Components.cs +++ b/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Components.cs @@ -21,12 +21,12 @@ #nullable enable namespace TorchSharp.Modules { - public partial class Wav2Vec2Model : Module + public partial class Wav2Vec2Model : nn.Module { /// /// Layer norm with transpose /// - private class LayerNorm : Module + private class LayerNorm : Module { public readonly long[] normalized_shape; public readonly Parameter weight; @@ -60,9 +60,9 @@ public override Tensor forward(Tensor input) /// private class ConvLayerBlock : Module { - public readonly Module conv; + public readonly Module conv; public readonly long kernel_size; - public readonly Module? layer_norm; + public readonly Module? layer_norm; public readonly long stride; public ConvLayerBlock( @@ -72,7 +72,7 @@ public ConvLayerBlock( long kernel_size, long stride, bool bias, - Module? layer_norm) : base(name) + Module? layer_norm) : base(name) { this.kernel_size = kernel_size; this.stride = stride; @@ -92,7 +92,7 @@ public ConvLayerBlock( /// Shape ``[batch, out_channels, out_frames]``. /// Shape ``[batch, ]``. /// - public new (Tensor, Tensor?) forward( + public (Tensor, Tensor?) forward( Tensor x, Tensor? length) { @@ -135,7 +135,7 @@ public FeatureExtractor( /// Valid length of each output sample. shape: ``[batch, ]``. /// /// - public new (Tensor, Tensor?) forward(Tensor x, Tensor? length) + public (Tensor, Tensor?) forward(Tensor x, Tensor? length) { if (x.ndim != 2) { throw new ArgumentException("Expected the input Tensor to be 2D (batch, time), but received {list(x.shape)}"); @@ -154,11 +154,11 @@ public FeatureExtractor( /// /// Layer that connects FeatureExtractor and Encoder /// - private class FeatureProjection : Module + private class FeatureProjection : Module { - public readonly Module dropout; - public readonly Module layer_norm; - public readonly Module projection; + public readonly Module dropout; + public readonly Module layer_norm; + public readonly Module projection; /// /// Projects features to encoder dimension. @@ -195,9 +195,9 @@ public override Tensor forward(Tensor x) /// /// Positional embedding which is placed at the beginning of Transformer. /// - internal class ConvolutionalPositionalEmbedding : Module + internal class ConvolutionalPositionalEmbedding : Module { - public readonly Module conv; + public readonly Module conv; public readonly long embed_dim; public readonly long num_remove; @@ -241,7 +241,7 @@ public override Tensor forward(Tensor x) return x; } - private class WeightNormConv1d : Module + private class WeightNormConv1d : Module { private readonly Parameter weight_g; private readonly Parameter weight_v; @@ -293,15 +293,15 @@ public override Tensor forward(Tensor input) /// private class SelfAttention : Module { - public readonly Module dropout; + public readonly Module dropout; public readonly long embed_dim; public readonly long head_dim; - public readonly Module k_proj; + public readonly Module k_proj; public readonly long num_heads; - public readonly Module out_proj; - public readonly Module q_proj; + public readonly Module out_proj; + public readonly Module q_proj; public readonly double scaling; - public readonly Module v_proj; + public readonly Module v_proj; /// /// Total dimension of the model. @@ -336,7 +336,7 @@ public SelfAttention( /// shape: ``[batch_size, 1, sequence_length, sequence_length]`` /// The resulting tensor. shape: ``[batch, sequence_length, embed_dim]`` /// - public new Tensor forward(Tensor x, Tensor? attention_mask = null) + public Tensor forward(Tensor x, Tensor? attention_mask = null) { if (x.ndim != 3 || x.shape[2] != this.embed_dim) { throw new ArgumentException("The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). Found {x.shape}."); @@ -379,12 +379,12 @@ public SelfAttention( /// /// Layer that follows attention layer in encoder layer. /// - private class FeedForward : Module + private class FeedForward : Module { - public readonly Module intermediate_dense; - public readonly Module intermediate_dropout; - public readonly Module output_dense; - public readonly Module output_dropout; + public readonly Module intermediate_dense; + public readonly Module intermediate_dropout; + public readonly Module output_dense; + public readonly Module output_dropout; public FeedForward( string name, @@ -417,13 +417,13 @@ public override Tensor forward(Tensor x) /// /// A layer unit in encoder. Combines multihead self attention and feed forward. /// - private class EncoderLayer : Module + private class EncoderLayer : Module { public readonly SelfAttention attention; - public readonly Module dropout; - public readonly Module feed_forward; - public readonly Module final_layer_norm; - public readonly Module layer_norm; + public readonly Module dropout; + public readonly Module feed_forward; + public readonly Module final_layer_norm; + public readonly Module layer_norm; public bool layer_norm_first; public EncoderLayer( @@ -431,7 +431,7 @@ public EncoderLayer( SelfAttention attention, double dropout, bool layer_norm_first, - Module feed_forward) : base(name) + Module feed_forward) : base(name) { this.attention = attention; this.dropout = nn.Dropout(dropout); @@ -469,11 +469,11 @@ public override Tensor forward( } } - internal class Transformer : Module + internal class Transformer : Module { - public readonly Module dropout; + public readonly Module dropout; public readonly double layer_drop; - public readonly Module layer_norm; + public readonly Module layer_norm; public readonly bool layer_norm_first; public readonly ModuleList layers; @@ -515,7 +515,7 @@ public override Tensor forward( x = this._preprocess(x); foreach (var layer in this.layers) { if (!(this.training && torch.rand(1).item() <= this.layer_drop)) { - x = layer.forward(x, attention_mask); + x = ((nn.Module)layer).forward(x, attention_mask); } } @@ -539,7 +539,7 @@ public Tensor[] get_intermediate_outputs( var ret = new List(); x = this._preprocess(x); foreach (var layer in this.layers) { - x = layer.forward(x, attention_mask); + x = ((nn.Module)layer).forward(x, attention_mask); ret.Add(x); if (num_layers != null && ret.Count >= num_layers) { return ret.ToArray(); @@ -549,14 +549,14 @@ public Tensor[] get_intermediate_outputs( } } - internal class Encoder : Module + internal class Encoder : Module { - public readonly Module feature_projection; + public readonly Module feature_projection; public readonly Transformer transformer; public Encoder( string name, - Module feature_projection, + Module feature_projection, Transformer transformer) : base(name) { this.feature_projection = feature_projection; @@ -649,7 +649,7 @@ internal static FeatureExtractor _get_feature_extractor(FeatureExtractorNormMode var out_channels = shape[0]; var kernel_size = shape[1]; var stride = shape[2]; - Module? normalization = null; + Module? normalization = null; if (norm_mode == FeatureExtractorNormMode.group_norm && i == 0) { normalization = nn.GroupNorm( num_groups: out_channels, @@ -1080,7 +1080,7 @@ public MaskGenerator( /// The feature representations after masking. /// The generated mask indices. /// - public new (Tensor, Tensor?) forward(Tensor x, Tensor? padding_mask) + public (Tensor, Tensor?) forward(Tensor x, Tensor? padding_mask) { Tensor? mask_indices; var B = x.size(0); @@ -1154,7 +1154,7 @@ private static Tensor _compute_logits( /// internal class LogitGenerator : Module { - public readonly Module final_proj; + public readonly Module final_proj; public readonly Tensor label_embeddings; public readonly bool skip_masked; public readonly bool skip_nomask; diff --git a/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Model.cs b/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Model.cs index 6dff0ee9b..cf5f1ca38 100644 --- a/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Model.cs +++ b/src/TorchSharp/TorchAudio/Modules/Wav2Vec2Model.cs @@ -24,7 +24,7 @@ public partial class Wav2Vec2Model : nn.Module { internal readonly FeatureExtractor feature_extractor; internal readonly Encoder encoder; - private readonly nn.Module? aux; + private readonly nn.Module? aux; /// /// Feature extractor that extracts feature vectors from raw audio Tensor. @@ -35,7 +35,7 @@ internal Wav2Vec2Model( string name, FeatureExtractor feature_extractor, Encoder encoder, - nn.Module? aux = null) : base(name) + nn.Module? aux = null) : base(name) { this.feature_extractor = feature_extractor; this.encoder = encoder; @@ -102,7 +102,7 @@ internal Wav2Vec2Model( /// is returned. /// It indicates the valid length in time axis of the output Tensor. /// - public new (Tensor, Tensor?) forward( + public (Tensor, Tensor?) forward( Tensor waveforms, Tensor? lengths = null) { diff --git a/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs b/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs index 89cf4a70e..78770356f 100644 --- a/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs +++ b/src/TorchSharp/TorchAudio/Modules/WaveRNN.cs @@ -26,21 +26,21 @@ namespace TorchSharp.Modules /// /// This class is used to represent a WaveRNN module. /// - public class WaveRNN : nn.Module + public class WaveRNN : nn.Module { private readonly int _pad; - public readonly nn.Module fc; - public readonly nn.Module fc1; - public readonly nn.Module fc2; - public readonly nn.Module fc3; + public readonly nn.Module fc; + public readonly nn.Module fc1; + public readonly nn.Module fc2; + public readonly nn.Module fc3; public readonly int hop_length; public readonly int kernel_size; public readonly int n_aux; public readonly int n_bits; public readonly int n_classes; public readonly int n_rnn; - public readonly nn.Module relu1; - public readonly nn.Module relu2; + public readonly nn.Module relu1; + public readonly nn.Module relu2; public readonly GRU rnn1; public readonly GRU rnn2; internal readonly UpsampleNetwork upsample; @@ -222,9 +222,9 @@ public virtual (Tensor, Tensor?) infer(Tensor specgram, Tensor? lengths = null) return (torch.stack(output).permute(1, 2, 0), lengths); } - private class ResBlock : nn.Module + private class ResBlock : nn.Module { - public nn.Module resblock_model; + public nn.Module resblock_model; public ResBlock(string name, int n_freq = 128) : base(name) { @@ -243,9 +243,9 @@ public override Tensor forward(Tensor specgram) } } - internal class MelResNet : nn.Module + internal class MelResNet : nn.Module { - public readonly nn.Module melresnet_model; + public readonly nn.Module melresnet_model; public MelResNet( string name, @@ -255,7 +255,7 @@ public MelResNet( int n_output = 128, int kernel_size = 5) : base(name) { - var modules = new List(); + var modules = new List>(); modules.Add(nn.Conv1d(inputChannel: n_freq, outputChannel: n_hidden, kernelSize: kernel_size, bias: false)); modules.Add(nn.BatchNorm1d(n_hidden)); modules.Add(nn.ReLU(inplace: true)); @@ -273,7 +273,7 @@ public override Tensor forward(Tensor specgram) } } - public class Stretch2d : nn.Module + public class Stretch2d : nn.Module { public long freq_scale; public long time_scale; @@ -297,7 +297,7 @@ internal class UpsampleNetwork : nn.Module public readonly MelResNet resnet; public readonly Stretch2d resnet_stretch; public readonly long total_scale; - public readonly nn.Module upsample_layers; + public readonly nn.Module upsample_layers; public UpsampleNetwork( string name, @@ -318,7 +318,7 @@ public UpsampleNetwork( this.resnet = new MelResNet("melresnet", n_res_block, n_freq, n_hidden, n_output, kernel_size); this.resnet_stretch = new Stretch2d("stretch2d", total_scale, 1); - var up_layers = new List(); + var up_layers = new List>(); foreach (var scale in upsample_scales) { var stretch = new Stretch2d("stretch2d", scale, 1); var conv = nn.Conv2d(inputChannel: 1, outputChannel: 1, kernelSize: (1, scale * 2 + 1), padding: (0, scale), bias: false); @@ -330,7 +330,7 @@ public UpsampleNetwork( this.RegisterComponents(); } - public new (Tensor, Tensor) forward(Tensor specgram) + public (Tensor, Tensor) forward(Tensor specgram) { var resnet_output = this.resnet.forward(specgram).unsqueeze(1); resnet_output = this.resnet_stretch.forward(resnet_output); diff --git a/src/TorchSharp/TorchAudio/Wav2Vec2Models.cs b/src/TorchSharp/TorchAudio/Wav2Vec2Models.cs index 7f68f64c9..647e36204 100644 --- a/src/TorchSharp/TorchAudio/Wav2Vec2Models.cs +++ b/src/TorchSharp/TorchAudio/Wav2Vec2Models.cs @@ -191,7 +191,7 @@ public static Wav2Vec2Model wav2vec2_model( dropout: encoder_dropout, layer_norm_first: encoder_layer_norm_first, layer_drop: encoder_layer_drop); - Module? aux = null; + Module? aux = null; if (aux_num_out != null) { aux = torch.nn.Linear(inputSize: encoder_embed_dim, outputSize: aux_num_out.Value); } diff --git a/src/TorchSharp/TorchVision/Ops/Misc.cs b/src/TorchSharp/TorchVision/Ops/Misc.cs index ad6f16c50..3ce9e28f2 100644 --- a/src/TorchSharp/TorchVision/Ops/Misc.cs +++ b/src/TorchSharp/TorchVision/Ops/Misc.cs @@ -21,15 +21,15 @@ public static partial class torchvision { public static partial class ops { - private static nn.Module ConvNormActivation( + private static nn.Module ConvNormActivation( long in_channels, long out_channels, long kernel_size = 3, long stride = 1, long? padding = null, long groups = 1, - Func? norm_layer = null, - Func? activation_layer = null, + Func>? norm_layer = null, + Func>? activation_layer = null, long dilation = 1, bool inplace = true, bool? bias = null, @@ -43,7 +43,7 @@ private static nn.Module ConvNormActivation( bias = norm_layer == null; } - var layers = new List(); + var layers = new List>(); if (rank == 2) { layers.Add( nn.Conv2d( @@ -94,15 +94,15 @@ private static nn.Module ConvNormActivation( /// Spacing between kernel elements. /// Parameter for the activation layer, which can optionally do the operation in-place. /// Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is null``. - public static nn.Module Conv2dNormActivation( + public static nn.Module Conv2dNormActivation( long in_channels, long out_channels, long kernel_size = 3, long stride = 1, long? padding = null, long groups = 1, - Func? norm_layer = null, - Func? activation_layer = null, + Func>? norm_layer = null, + Func>? activation_layer = null, long dilation = 1, bool inplace = true, bool? bias = null) @@ -136,15 +136,15 @@ public static nn.Module Conv2dNormActivation( /// Spacing between kernel elements. /// Parameter for the activation layer, which can optionally do the operation in-place. /// Whether to use bias in the convolution layer. By default, biases are included if ``norm_layer is null``. - public static nn.Module Conv3dNormActivation( + public static nn.Module Conv3dNormActivation( long in_channels, long out_channels, long kernel_size = 3, long stride = 1, long? padding = null, long groups = 1, - Func? norm_layer = null, - Func? activation_layer = null, + Func>? norm_layer = null, + Func>? activation_layer = null, long dilation = 1, bool inplace = true, bool? bias = null) @@ -164,13 +164,13 @@ public static nn.Module Conv3dNormActivation( rank: 3); } - internal class SqueezeExcitation : torch.nn.Module + internal class SqueezeExcitation : torch.nn.Module { - private readonly nn.Module avgpool; - private readonly nn.Module fc1; - private readonly nn.Module fc2; - private readonly nn.Module activation; - private readonly nn.Module scale_activation; + private readonly nn.Module avgpool; + private readonly nn.Module fc1; + private readonly nn.Module fc2; + private readonly nn.Module activation; + private readonly nn.Module scale_activation; /// /// This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1). @@ -185,8 +185,8 @@ public SqueezeExcitation( string name, long input_channels, long squeeze_channels, - Func activation, - Func scale_activation) : base(name) + Func> activation, + Func> scale_activation) : base(name) { this.avgpool = torch.nn.AdaptiveAvgPool2d(1); this.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1); diff --git a/src/TorchSharp/TorchVision/models/AlexNet.cs b/src/TorchSharp/TorchVision/models/AlexNet.cs index 543576ad9..48fe409a2 100644 --- a/src/TorchSharp/TorchVision/models/AlexNet.cs +++ b/src/TorchSharp/TorchVision/models/AlexNet.cs @@ -60,11 +60,11 @@ namespace Modules // https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py // Licence and copypright notice at: https://github.com/pytorch/vision/blob/main/LICENSE - public class AlexNet : Module + public class AlexNet : Module { - private readonly Module features; - private readonly Module avgpool; - private readonly Module classifier; + private readonly Module features; + private readonly Module avgpool; + private readonly Module classifier; public AlexNet(int numClasses, float dropout = 0.5f, string weights_file = null, bool skipfc = true, Device device = null) : base(nameof(AlexNet)) { diff --git a/src/TorchSharp/TorchVision/models/GoogleNet.cs b/src/TorchSharp/TorchVision/models/GoogleNet.cs index 522b63ce8..eb65d13bf 100644 --- a/src/TorchSharp/TorchVision/models/GoogleNet.cs +++ b/src/TorchSharp/TorchVision/models/GoogleNet.cs @@ -64,28 +64,28 @@ public static Modules.GoogleNet googlenet( namespace Modules { - public class GoogleNet : Module + public class GoogleNet : Module { // The code here is based on // https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py // Licence and copypright notice at: https://github.com/pytorch/vision/blob/main/LICENSE - private readonly Module conv1; - private readonly Module maxpool1; - private readonly Module conv2; - private readonly Module conv3; - private readonly Module maxpool2; - private readonly Module inception3a; - private readonly Module inception3b; - private readonly Module maxpool3; - private readonly Module inception4a; - private readonly Module inception4b; - private readonly Module inception4c; - private readonly Module inception4d; - private readonly Module inception4e; - private readonly Module maxpool4; - private readonly Module inception5a; - private readonly Module inception5b; + private readonly Module conv1; + private readonly Module maxpool1; + private readonly Module conv2; + private readonly Module conv3; + private readonly Module maxpool2; + private readonly Module inception3a; + private readonly Module inception3b; + private readonly Module maxpool3; + private readonly Module inception4a; + private readonly Module inception4b; + private readonly Module inception4c; + private readonly Module inception4d; + private readonly Module inception4e; + private readonly Module maxpool4; + private readonly Module inception5a; + private readonly Module inception5b; //private readonly Module aux1; //private readonly Module aux2; @@ -165,7 +165,7 @@ public GoogleNet(int numClasses = 1000, } - private static Module conv_block(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) + private static Module conv_block(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) { return Sequential( ("conv", Conv2d(in_channels, out_channels, bias: false, kernelSize: kernel_size, stride: stride, padding: padding)), @@ -174,7 +174,7 @@ private static Module conv_block(int in_channels, int out_channels, int kernel_s ); } - private static Module conv_block(int in_channels, int out_channels, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null) + private static Module conv_block(int in_channels, int out_channels, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null) { return Sequential( ("conv", Conv2d(in_channels, out_channels, bias: false, kernelSize: kernel_size, stride: stride, padding: padding)), @@ -183,8 +183,8 @@ private static Module conv_block(int in_channels, int out_channels, (long, long) ); } - private Module inception_block(int in_channels, int ch1x1, int ch3x3red, int ch3x3, int ch5x5red, int ch5x5, int pool_proj) => new Inception(in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj); - private Module inception_aux_block(int in_channels, int num_classes, float dropout) => new InceptionAux(in_channels, num_classes, dropout); + private Module inception_block(int in_channels, int ch1x1, int ch3x3red, int ch3x3, int ch5x5red, int ch5x5, int pool_proj) => new Inception(in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj); + private Module inception_aux_block(int in_channels, int num_classes, float dropout) => new InceptionAux(in_channels, num_classes, dropout); public override Tensor forward(Tensor x) { @@ -253,7 +253,7 @@ public override Tensor forward(Tensor x) } } - class Inception : Module + class Inception : Module { public Inception(int in_channels, int ch1x1, int ch3x3red, int ch3x3, int ch5x5red, int ch5x5, int pool_proj) : base("Inception") { @@ -284,18 +284,18 @@ public override Tensor forward(Tensor x) return torch.cat(outputs, 1); } - private readonly Module branch1; - private readonly Module branch2; - private readonly Module branch3; - private readonly Module branch4; + private readonly Module branch1; + private readonly Module branch2; + private readonly Module branch3; + private readonly Module branch4; } - class InceptionAux : Module + class InceptionAux : Module { - private readonly Module conv; - private readonly Module fc1; - private readonly Module fc2; - private readonly Module dropout; + private readonly Module conv; + private readonly Module fc1; + private readonly Module fc2; + private readonly Module dropout; public InceptionAux(int in_channels, int num_classes, float dropout = 0.7f) : base("InceptionAux") { diff --git a/src/TorchSharp/TorchVision/models/InceptionV3.cs b/src/TorchSharp/TorchVision/models/InceptionV3.cs index c2e196209..22b43cf8a 100644 --- a/src/TorchSharp/TorchVision/models/InceptionV3.cs +++ b/src/TorchSharp/TorchVision/models/InceptionV3.cs @@ -62,32 +62,32 @@ public static Modules.InceptionV3 inception_v3( namespace Modules { - public class InceptionV3 : Module + public class InceptionV3 : Module { // The code here is is loosely based on // https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py // Licence and copypright notice at: https://github.com/pytorch/vision/blob/main/LICENSE - private readonly Module Conv2d_1a_3x3; - private readonly Module Conv2d_2a_3x3; - private readonly Module Conv2d_2b_3x3; - private readonly Module maxpool1; - private readonly Module Conv2d_3b_1x1; - private readonly Module Conv2d_4a_3x3; - private readonly Module maxpool2; - - private readonly Module Mixed_5b; - private readonly Module Mixed_5c; - private readonly Module Mixed_5d; - private readonly Module Mixed_6a; - private readonly Module Mixed_6b; - private readonly Module Mixed_6c; - private readonly Module Mixed_6d; - private readonly Module Mixed_6e; - private readonly Module AuxLogits; - private readonly Module Mixed_7a; - private readonly Module Mixed_7b; - private readonly Module Mixed_7c; + private readonly Module Conv2d_1a_3x3; + private readonly Module Conv2d_2a_3x3; + private readonly Module Conv2d_2b_3x3; + private readonly Module maxpool1; + private readonly Module Conv2d_3b_1x1; + private readonly Module Conv2d_4a_3x3; + private readonly Module maxpool2; + + private readonly Module Mixed_5b; + private readonly Module Mixed_5c; + private readonly Module Mixed_5d; + private readonly Module Mixed_6a; + private readonly Module Mixed_6b; + private readonly Module Mixed_6c; + private readonly Module Mixed_6d; + private readonly Module Mixed_6e; + private readonly Module AuxLogits; + private readonly Module Mixed_7a; + private readonly Module Mixed_7b; + private readonly Module Mixed_7c; private readonly AdaptiveAvgPool2d avgpool; private Dropout dropout; private readonly Linear fc; @@ -165,7 +165,7 @@ public InceptionV3(int numClasses = 1000, } - private static Module conv_block(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) + private static Module conv_block(int in_channels, int out_channels, int kernel_size, int stride = 1, int padding = 0) { return Sequential( ("conv", Conv2d(in_channels, out_channels, bias: false, kernelSize: kernel_size, stride: stride, padding: padding)), @@ -174,7 +174,7 @@ private static Module conv_block(int in_channels, int out_channels, int kernel_s ); } - private static Module conv_block(int in_channels, int out_channels, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null) + private static Module conv_block(int in_channels, int out_channels, (long, long) kernel_size, (long, long)? stride = null, (long, long)? padding = null) { return Sequential( ("conv", Conv2d(in_channels, out_channels, bias: false, kernelSize: kernel_size, stride: stride, padding: padding)), @@ -183,12 +183,12 @@ private static Module conv_block(int in_channels, int out_channels, (long, long) ); } - private Module inception_a(int in_channels, int pool_features) => new InceptionA(in_channels, pool_features); - private Module inception_b(int in_channels) => new InceptionB(in_channels); - private Module inception_c(int in_channels, int channels_7x7) => new InceptionC(in_channels, channels_7x7); - private Module inception_d(int in_channels) => new InceptionD(in_channels); - private Module inception_e(int in_channels) => new InceptionE(in_channels); - private Module inception_aux(int in_channels, int num_classes) => new InceptionAux(in_channels, num_classes); + private Module inception_a(int in_channels, int pool_features) => new InceptionA(in_channels, pool_features); + private Module inception_b(int in_channels) => new InceptionB(in_channels); + private Module inception_c(int in_channels, int channels_7x7) => new InceptionC(in_channels, channels_7x7); + private Module inception_d(int in_channels) => new InceptionD(in_channels); + private Module inception_e(int in_channels) => new InceptionE(in_channels); + private Module inception_aux(int in_channels, int num_classes) => new InceptionAux(in_channels, num_classes); public override Tensor forward(Tensor x) { @@ -253,7 +253,7 @@ public override Tensor forward(Tensor x) } } - class InceptionA : Module + class InceptionA : Module { public InceptionA(int in_channels, int pool_features) : base("InceptionA") { @@ -286,16 +286,16 @@ public override Tensor forward(Tensor x) return torch.cat(outputs, 1); } - private readonly Module branch1x1; - private readonly Module branch5x5_1; - private readonly Module branch5x5_2; - private readonly Module branch3x3dbl_1; - private readonly Module branch3x3dbl_2; - private readonly Module branch3x3dbl_3; - private readonly Module branch_pool; + private readonly Module branch1x1; + private readonly Module branch5x5_1; + private readonly Module branch5x5_2; + private readonly Module branch3x3dbl_1; + private readonly Module branch3x3dbl_2; + private readonly Module branch3x3dbl_3; + private readonly Module branch_pool; } - class InceptionB : Module + class InceptionB : Module { public InceptionB(int in_channels) : base("InceptionB") { @@ -324,24 +324,24 @@ public override Tensor forward(Tensor x) return torch.cat(outputs, 1); } - private readonly Module branch3x3; - private readonly Module branch3x3dbl_1; - private readonly Module branch3x3dbl_2; - private readonly Module branch3x3dbl_3; + private readonly Module branch3x3; + private readonly Module branch3x3dbl_1; + private readonly Module branch3x3dbl_2; + private readonly Module branch3x3dbl_3; } - class InceptionC : Module + class InceptionC : Module { - private readonly Module branch1x1; - private readonly Module branch7x7_1; - private readonly Module branch7x7_2; - private readonly Module branch7x7_3; - private readonly Module branch7x7dbl_1; - private readonly Module branch7x7dbl_2; - private readonly Module branch7x7dbl_3; - private readonly Module branch7x7dbl_4; - private readonly Module branch7x7dbl_5; - private readonly Module branch_pool; + private readonly Module branch1x1; + private readonly Module branch7x7_1; + private readonly Module branch7x7_2; + private readonly Module branch7x7_3; + private readonly Module branch7x7dbl_1; + private readonly Module branch7x7dbl_2; + private readonly Module branch7x7dbl_3; + private readonly Module branch7x7dbl_4; + private readonly Module branch7x7dbl_5; + private readonly Module branch_pool; public InceptionC(int in_channels, int channels_7x7) : base("InceptionC") { @@ -386,14 +386,14 @@ public override Tensor forward(Tensor x) } } - class InceptionD : Module + class InceptionD : Module { - private readonly Module branch3x3_1; - private readonly Module branch3x3_2; - private readonly Module branch7x7x3_1; - private readonly Module branch7x7x3_2; - private readonly Module branch7x7x3_3; - private readonly Module branch7x7x3_4; + private readonly Module branch3x3_1; + private readonly Module branch3x3_2; + private readonly Module branch7x7x3_1; + private readonly Module branch7x7x3_2; + private readonly Module branch7x7x3_3; + private readonly Module branch7x7x3_4; public InceptionD(int in_channels) : base("InceptionD") { @@ -426,17 +426,17 @@ public override Tensor forward(Tensor x) } } - class InceptionE : Module + class InceptionE : Module { - private readonly Module branch1x1; - private readonly Module branch3x3_1; - private readonly Module branch3x3_2a; - private readonly Module branch3x3_2b; - private readonly Module branch3x3dbl_1; - private readonly Module branch3x3dbl_2; - private readonly Module branch3x3dbl_3a; - private readonly Module branch3x3dbl_3b; - private readonly Module branch_pool; + private readonly Module branch1x1; + private readonly Module branch3x3_1; + private readonly Module branch3x3_2a; + private readonly Module branch3x3_2b; + private readonly Module branch3x3dbl_1; + private readonly Module branch3x3dbl_2; + private readonly Module branch3x3dbl_3a; + private readonly Module branch3x3dbl_3b; + private readonly Module branch_pool; public InceptionE(int in_channels) : base("InceptionE") { @@ -478,11 +478,11 @@ public override Tensor forward(Tensor x) } } - class InceptionAux : Module + class InceptionAux : Module { - private readonly Module conv0; - private readonly Module conv1; - private readonly Module fc; + private readonly Module conv0; + private readonly Module conv1; + private readonly Module fc; public InceptionAux(int in_channels, int num_classes) : base("InceptionAux") { diff --git a/src/TorchSharp/TorchVision/models/MobileNetV2.cs b/src/TorchSharp/TorchVision/models/MobileNetV2.cs index b56c8d1f8..dd1034eca 100644 --- a/src/TorchSharp/TorchVision/models/MobileNetV2.cs +++ b/src/TorchSharp/TorchVision/models/MobileNetV2.cs @@ -25,12 +25,12 @@ namespace Modules /// /// MobileNet V2 main class /// - public class MobileNetV2 : nn.Module + public class MobileNetV2 : nn.Module { - private class InvertedResidual : nn.Module + private class InvertedResidual : nn.Module { private readonly bool _is_cn; - private readonly nn.Module conv; + private readonly nn.Module conv; private readonly long out_channels; private readonly long stride; private readonly bool use_res_connect; @@ -41,7 +41,7 @@ public InvertedResidual( long oup, long stride, double expand_ratio, - Func? norm_layer = null) : base(name) + Func>? norm_layer = null) : base(name) { this.stride = stride; if (stride != 1 && stride != 2) { @@ -55,7 +55,7 @@ public InvertedResidual( var hidden_dim = (long)Math.Round(inp * expand_ratio); this.use_res_connect = this.stride == 1 && inp == oup; - var layers = new List(); + var layers = new List>(); if (expand_ratio != 1) { // pw layers.Add( @@ -66,7 +66,7 @@ public InvertedResidual( norm_layer: norm_layer, activation_layer: (inplace) => nn.ReLU6(inplace))); } - layers.AddRange(new List { + layers.AddRange(new List> { // dw ops.Conv2dNormActivation( hidden_dim, @@ -95,8 +95,8 @@ public override Tensor forward(Tensor x) } } - private readonly nn.Module classifier; - private readonly nn.Module features; + private readonly nn.Module classifier; + private readonly nn.Module features; private readonly long last_channel; internal MobileNetV2( @@ -105,8 +105,8 @@ internal MobileNetV2( double width_mult = 1.0, long[][]? inverted_residual_setting = null, long round_nearest = 8, - Func, nn.Module>? block = null, - Func? norm_layer = null, + Func>, nn.Module>? block = null, + Func>? norm_layer = null, double dropout = 0.2) : base(name) { if (block == null) { @@ -141,7 +141,7 @@ internal MobileNetV2( // building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest); this.last_channel = _make_divisible(last_channel * Math.Max(1.0, width_mult), round_nearest); - var features = new List { + var features = new List> { ops.Conv2dNormActivation(3, input_channel, stride: 2, norm_layer: norm_layer, activation_layer: (inplace) => nn.ReLU6(inplace)) }; // building inverted residual blocks @@ -233,8 +233,8 @@ public static Modules.MobileNetV2 mobilenet_v2( double width_mult = 1.0, long[][]? inverted_residual_setting = null, long round_nearest = 8, - Func, nn.Module>? block = null, - Func? norm_layer = null, + Func>, nn.Module>? block = null, + Func>? norm_layer = null, double dropout = 0.2) { return new Modules.MobileNetV2( diff --git a/src/TorchSharp/TorchVision/models/MobileNetV3.cs b/src/TorchSharp/TorchVision/models/MobileNetV3.cs index d2706d063..7d5ce0bfc 100644 --- a/src/TorchSharp/TorchVision/models/MobileNetV3.cs +++ b/src/TorchSharp/TorchVision/models/MobileNetV3.cs @@ -25,7 +25,7 @@ namespace TorchSharp { namespace Modules { - public class MobileNetV3 : nn.Module + public class MobileNetV3 : nn.Module { /// /// Stores information listed at Tables 1 and 2 of the MobileNetV3 paper @@ -71,18 +71,18 @@ internal static long adjust_channels(long channels, double width_mult) /// /// Implemented as described at section 5 of MobileNetV3 paper /// - private class InvertedResidual : nn.Module + private class InvertedResidual : nn.Module { private readonly bool _is_cn; - private readonly nn.Module block; + private readonly nn.Module block; private readonly long out_channels; private readonly bool use_res_connect; public InvertedResidual( string name, InvertedResidualConfig cnf, - Func norm_layer, - Func? se_layer = null) : base(name) + Func> norm_layer, + Func>? se_layer = null) : base(name) { if (!(1 <= cnf.stride && cnf.stride <= 2)) { throw new ArgumentException("illegal stride value"); @@ -90,8 +90,8 @@ public InvertedResidual( this.use_res_connect = cnf.stride == 1 && cnf.input_channels == cnf.out_channels; - var layers = new List(); - Func activation_layer = ( + var layers = new List>(); + Func> activation_layer = ( cnf.use_hs ? (inplace) => nn.Hardswish(inplace) : (inplace) => nn.ReLU(inplace)); // expand @@ -152,9 +152,9 @@ public override Tensor forward(Tensor input) } } - private readonly nn.Module avgpool; - private readonly nn.Module classifier; - private readonly nn.Module features; + private readonly nn.Module avgpool; + private readonly nn.Module classifier; + private readonly nn.Module features; /// /// MobileNet V3 main class @@ -172,8 +172,8 @@ internal MobileNetV3( InvertedResidualConfig[] inverted_residual_setting, long last_channel, long num_classes = 1000, - Func, nn.Module>? block = null, - Func? norm_layer = null, + Func>, nn.Module>? block = null, + Func>? norm_layer = null, double dropout = 0.2) : base(name) { if (inverted_residual_setting == null || inverted_residual_setting.Length == 0) { @@ -188,7 +188,7 @@ internal MobileNetV3( norm_layer = (features) => nn.BatchNorm2d(features, eps: 0.001, momentum: 0.01); } - var layers = new List(); + var layers = new List>(); // building first layer var firstconv_output_channels = inverted_residual_setting[0].input_channels; diff --git a/src/TorchSharp/TorchVision/models/ResNet.cs b/src/TorchSharp/TorchVision/models/ResNet.cs index b7b1d8642..011db3c73 100644 --- a/src/TorchSharp/TorchVision/models/ResNet.cs +++ b/src/TorchSharp/TorchVision/models/ResNet.cs @@ -213,25 +213,25 @@ public static Modules.ResNet resnet152(int num_classes = 1000, namespace Modules { - public class ResNet : Module + public class ResNet : Module { // The code here is based on // https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py // Licence and copypright notice at: https://github.com/pytorch/vision/blob/main/LICENSE - private readonly Module conv1; - private readonly Module bn1; - private readonly Module relu; - private readonly Module maxpool; + private readonly Module conv1; + private readonly Module bn1; + private readonly Module relu; + private readonly Module maxpool; private readonly Sequential layer1 = Sequential(); private readonly Sequential layer2 = Sequential(); private readonly Sequential layer3 = Sequential(); private readonly Sequential layer4 = Sequential(); - private readonly Module avgpool; - private readonly Module flatten; - private readonly Module fc; + private readonly Module avgpool; + private readonly Module flatten; + private readonly Module fc; private int in_planes = 64; @@ -311,14 +311,14 @@ public static ResNet ResNet152(int numClasses, } public ResNet(string name, - Func block, + Func> block, int expansion, IList num_blocks, int numClasses, string weights_file = null, bool skipfc = true, Device device = null) : base(name) { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); conv1 = Conv2d(3, 64, kernelSize: 7, stride: 2, padding: 3, bias: false); bn1 = BatchNorm2d(64); @@ -358,7 +358,7 @@ public ResNet(string name, this.to(device); } - private void MakeLayer(Sequential modules, Func block, int expansion, int planes, int num_blocks, int stride) + private void MakeLayer(Sequential modules, Func> block, int expansion, int planes, int num_blocks, int stride) { var strides = new List(); strides.Add(stride); @@ -388,11 +388,11 @@ public override Tensor forward(Tensor input) } } - class BasicBlock : Module + class BasicBlock : Module { public BasicBlock(int in_planes, int planes, int stride) : base("BasicBlock") { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); conv1 = Conv2d(in_planes, planes, kernelSize: 3, stride: stride, padding: 1, bias: false); bn1 = BatchNorm2d(planes); @@ -416,25 +416,25 @@ public override Tensor forward(Tensor input) x = bn2.forward(conv2.forward(x)); var y = input; - foreach (var m in downsample) y = m.forward(y); + foreach (var m in downsample) y = ((nn.Module)m).forward(y); return x.add_(y).relu_(); } public static int expansion = 1; - private readonly Module conv1; - private readonly Module bn1; - private readonly Module conv2; + private readonly Module conv1; + private readonly Module bn1; + private readonly Module conv2; private readonly TorchSharp.Modules.BatchNorm2d bn2; - private readonly Module relu1; + private readonly Module relu1; private readonly TorchSharp.Modules.ModuleList downsample = new TorchSharp.Modules.ModuleList(); } - class Bottleneck : Module + class Bottleneck : Module { public Bottleneck(int in_planes, int planes, int stride) : base("Bottleneck") { - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); conv1 = Conv2d(in_planes, planes, kernelSize: 1, bias: false); bn1 = BatchNorm2d(planes); @@ -462,21 +462,21 @@ public override Tensor forward(Tensor input) x = bn3.forward(conv3.forward(x)); var y = input; - foreach (var m in downsample) y = m.forward(y); + foreach (var m in downsample) y = ((nn.Module)m).forward(y); return x.add_(y).relu_(); } public static int expansion = 4; - private readonly Module conv1; - private readonly Module bn1; - private readonly Module conv2; - private readonly Module bn2; - private readonly Module conv3; + private readonly Module conv1; + private readonly Module bn1; + private readonly Module conv2; + private readonly Module bn2; + private readonly Module conv3; private readonly TorchSharp.Modules.BatchNorm2d bn3; - private readonly Module relu1; - private readonly Module relu2; + private readonly Module relu1; + private readonly Module relu2; private readonly TorchSharp.Modules.ModuleList downsample = new TorchSharp.Modules.ModuleList(); } diff --git a/src/TorchSharp/TorchVision/models/VGG.cs b/src/TorchSharp/TorchVision/models/VGG.cs index 5896be377..fc48c0441 100644 --- a/src/TorchSharp/TorchVision/models/VGG.cs +++ b/src/TorchSharp/TorchVision/models/VGG.cs @@ -322,7 +322,7 @@ public static Modules.VGG vgg19_bn(int num_classes = 1000, float dropout = 0.5f, namespace Modules { - public class VGG : Module + public class VGG : Module { // The code here is based on // https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py @@ -335,9 +335,9 @@ public class VGG : Module { "VGG19", new long[] { 64, 64, 0, 128, 128, 0, 256, 256, 256, 256, 0, 512, 512, 512, 512, 0, 512, 512, 512, 512, 0 } } }; - private readonly Module features; - private readonly Module avgpool; - private readonly Module classifier; + private readonly Module features; + private readonly Module avgpool; + private readonly Module classifier; public VGG(string name, int numClasses, @@ -347,7 +347,7 @@ public VGG(string name, bool skipfc = true, Device device = null) : base(name) { - var layers = new List(); + var layers = new List>(); var channels = _channels[name]; diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index cc3d298e3..923178a4a 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -22,6 +22,7 @@ + @@ -63,6 +64,12 @@ PreserveNewest + + PreserveNewest + + + PreserveNewest + diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index 45dfd8b15..e3cc72327 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -1534,12 +1534,12 @@ public void TestSetGrad() Assert.False(x.requires_grad); } - private class CondModel : Module + private class CondModel : Module { - private Module fb = Linear(1000, 100, false); - private Module fbT1 = Linear(100, 10, false); - private Module fbF1 = Linear(100, 50, false); - private Module fbF2 = Linear(50, 10, false); + private Module fb = Linear(1000, 100, false); + private Module fbT1 = Linear(100, 10, false); + private Module fbF1 = Linear(100, 50, false); + private Module fbF2 = Linear(50, 10, false); private bool _isTrue = false; public CondModel(string name, bool isTrue) : base(name) @@ -2073,7 +2073,7 @@ public void TestCustomModule4() Assert.True(seq.has_parameter("0.dict.second")); } - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1(Tensor tensor, bool withGrad) : base("TestModule1") @@ -2095,7 +2095,7 @@ public override Tensor forward(Tensor input) private ParameterDict dict = new ParameterDict(); } - private class TestModule2 : Module + private class TestModule2 : Module { public TestModule2(Tensor tensor, bool withGrad) : base("TestModule1") @@ -2110,11 +2110,11 @@ public TestModule2(Tensor tensor, bool withGrad) public override Tensor forward(Tensor input) { - for (int i = 0; i < list.Count; i++) { input = list[i].forward(input); } + for (int i = 0; i < list.Count; i++) { input = ((nn.Module)list[i]).forward(input); } throw new NotImplementedException(); } - public Module submodule; + public Module submodule; private ModuleList list = new ModuleList(); private ModuleDict dict = new ModuleDict(); } @@ -2172,7 +2172,7 @@ public void TestDeviceTo() } } - private class TestModule3 : Module + private class TestModule3 : Module { public TestModule3() : base(nameof(TestModule3)) { RegisterComponents(); } diff --git a/test/TorchSharpTest/TestJIT.cs b/test/TorchSharpTest/TestJIT.cs new file mode 100644 index 000000000..a00d011fa --- /dev/null +++ b/test/TorchSharpTest/TestJIT.cs @@ -0,0 +1,207 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.IO; +using System.Linq; +using TorchSharp.Modules; +using static TorchSharp.torch.nn; +using Xunit; +using Google.Protobuf; +using Tensorboard; +using static TorchSharp.torch.utils.tensorboard; +using ICSharpCode.SharpZipLib; +using System.Collections.Generic; + +#nullable enable + +namespace TorchSharp +{ +#if NET472_OR_GREATER + [Collection("Sequential")] +#endif // NET472_OR_GREATER + public class TestJIT + { + + + [Fact] + public void TestLoadJIT_Func() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"func.script.dat"); + + var sms = m.named_modules().ToArray(); + Assert.Empty(sms); + + var kids = m.named_children().ToArray(); + Assert.Empty(kids); + + var t = m.forward(torch.ones(10), torch.ones(10)); + + Assert.Equal(new long[] { 10 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 }).allclose(t)); + } + + [Fact] + public void TestLoadJIT_1() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"linrelu.script.dat"); + var t = m.forward(torch.ones(10)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); + } + + [Fact] + public void TestSaveJIT() + { + var location = "TestSaveJIT.ts"; + if (File.Exists(location)) File.Delete(location); + + try { + + // One linear layer followed by ReLU. + using var m1 = torch.jit.load(@"linrelu.script.dat"); + + torch.jit.save(m1, location); + using var m2 = torch.jit.load(location); + + var t = m2.forward(torch.ones(10)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); + + } finally { + if (File.Exists(location)) File.Delete(location); + } + } + + [Fact] + public void TestLoadJIT_2() + { + // One linear layer followed by ReLU. + using var m = torch.jit.load(@"scripted.script.dat"); + var t = m.forward(torch.ones(6)); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 1.554085f, 1.01024628f, -1.35086036f, -1.84021854f, 0.0127189457f, 0.5994258f }).allclose(t)); + } + + [Fact] + public void TestLoadJIT_3() + { + // Two linear layers, nested Sequential, ReLU in between. + using var m = torch.jit.load(@"l1000_100_10.script.dat"); + + var sms = m.named_modules().ToArray(); + Assert.Equal(4, sms.Length); + + var kids = m.named_children().ToArray(); + Assert.Equal(2, kids.Length); + + var t = m.forward(torch.ones(1000)); + + Assert.Equal(new long[] { 10 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.564213157f, -0.04519982f, -0.005117342f, 0.395530462f, -0.3780813f, -0.004734449f, -0.3221216f, -0.289159119f, 0.268511474f, 0.180702567f }).allclose(t)); + + Assert.Throws(() => m.forward(torch.ones(100))); + } + + [Fact] + public void TestLoadJIT_4() + { + // Definitely not a TorchScript file. Let's see what the runtime does with it. + Assert.Throws(() => torch.jit.load(@"bug510.dat")); + } + + [Fact] + public void TestSaveLoadJITCUDA() + { + if (torch.cuda.is_available()) { + + using var m = torch.jit.load(@"linrelu.script.dat"); + + m.to(DeviceType.CUDA); + var params0 = m.parameters().ToArray(); + foreach (var p in params0) + Assert.Equal(DeviceType.CUDA, p.device_type); + + var t = m.forward(torch.ones(10).cuda()).cpu(); + + Assert.Equal(new long[] { 6 }, t.shape); + Assert.Equal(torch.float32, t.dtype); + Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); + } + } + + [Fact] + public void TestJIT_TupleOut() + { + // def a(x, y): + // return x + y, x - y + // + using var m = torch.jit.load<(torch.Tensor, torch.Tensor)>(@"tuple_out.dat"); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + var output = m.forward(x, y); + + Assert.Multiple( + () => Assert.Equal(x.shape, output.Item1.shape), + () => Assert.Equal(x.shape, output.Item2.shape), + () => Assert.Equal(x + y, output.Item1), + () => Assert.Equal(x - y, output.Item2) + ); + } + + [Fact] + public void TestJIT_TupleOutError() + { + // def a(x, y): + // return x + y, x - y + // + using var m = torch.jit.load< (torch.Tensor, torch.Tensor)>(@"func.script.dat"); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + Assert.Throws(() => m.forward(x, y)); + } + + [Fact] + public void TestJIT_ListOut() + { + // def a(x, y): + // return [x + y, x - y] + // + using var m = torch.jit.load(@"list_out.dat"); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + var output = m.forward(x, y); + + Assert.Multiple( + () => Assert.Equal(x.shape, output[0].shape), + () => Assert.Equal(x.shape, output[1].shape), + () => Assert.Equal(x + y, output[0]), + () => Assert.Equal(x - y, output[1]) + ); + } + + [Fact] + public void TestJIT_ListOutError() + { + // def a(x, y): + // return x + y, x - y + // + using var m = torch.jit.load(@"func.script.dat"); + + var x = torch.rand(3, 4); + var y = torch.rand(3, 4); + Assert.Throws(() => m.forward(x, y)); + } + } +} diff --git a/test/TorchSharpTest/TestLoadSave.cs b/test/TorchSharpTest/TestLoadSave.cs index 924419e38..ad9a2640a 100644 --- a/test/TorchSharpTest/TestLoadSave.cs +++ b/test/TorchSharpTest/TestLoadSave.cs @@ -98,124 +98,6 @@ public void TestSaveLoadLinear3() } } - - - [Fact] - public void TestLoadJIT_Func() - { - // One linear layer followed by ReLU. - using var m = torch.jit.load(@"func.script.dat"); - - var sms = m.named_modules().ToArray(); - Assert.Empty(sms); - - var kids = m.named_children().ToArray(); - Assert.Empty(kids); - - var t = m.forward(torch.ones(10), torch.ones(10)); - - Assert.Equal(new long[] { 10 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 }).allclose(t)); - } - - [Fact] - public void TestLoadJIT_1() - { - // One linear layer followed by ReLU. - using var m = torch.jit.load(@"linrelu.script.dat"); - var t = m.forward(torch.ones(10)); - - Assert.Equal(new long[] { 6 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); - } - - [Fact] - public void TestSaveJIT() - { - var location = "TestSaveJIT.ts"; - if (File.Exists(location)) File.Delete(location); - - try { - - // One linear layer followed by ReLU. - using var m1 = torch.jit.load(@"linrelu.script.dat"); - - torch.jit.save(m1, location); - using var m2 = torch.jit.load(location); - - var t = m2.forward(torch.ones(10)); - - Assert.Equal(new long[] { 6 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); - - } finally { - if (File.Exists(location)) File.Delete(location); - } - } - - [Fact] - public void TestLoadJIT_2() - { - // One linear layer followed by ReLU. - using var m = torch.jit.load(@"scripted.script.dat"); - var t = m.forward(torch.ones(6)); - - Assert.Equal(new long[] { 6 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 1.554085f, 1.01024628f, -1.35086036f, -1.84021854f, 0.0127189457f, 0.5994258f }).allclose(t)); - } - - [Fact] - public void TestLoadJIT_3() - { - // Two linear layers, nested Sequential, ReLU in between. - using var m = torch.jit.load(@"l1000_100_10.script.dat"); - - var sms = m.named_modules().ToArray(); - Assert.Equal(4, sms.Length); - - var kids = m.named_children().ToArray(); - Assert.Equal(2, kids.Length); - - var t = m.forward(torch.ones(1000)); - - Assert.Equal(new long[] { 10 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 0.564213157f, -0.04519982f, -0.005117342f, 0.395530462f, -0.3780813f, -0.004734449f, -0.3221216f, -0.289159119f, 0.268511474f, 0.180702567f }).allclose(t)); - - Assert.Throws(() => m.forward(torch.ones(100))); - } - - [Fact] - public void TestLoadJIT_4() - { - // Definitely not a TorchScript file. Let's see what the runtime does with it. - Assert.Throws(() => torch.jit.load(@"bug510.dat")); - } - - [Fact] - public void TestSaveLoadJITCUDA() - { - if (torch.cuda.is_available()) { - - using var m = torch.jit.load(@"linrelu.script.dat"); - - m.to(DeviceType.CUDA); - var params0 = m.parameters().ToArray(); - foreach (var p in params0) - Assert.Equal(DeviceType.CUDA, p.device_type); - - var t = m.forward(torch.ones(10).cuda()).cpu(); - - Assert.Equal(new long[] { 6 }, t.shape); - Assert.Equal(torch.float32, t.dtype); - Assert.True(torch.tensor(new float[] { 0.313458264f, 0, 0.9996568f, 0, 0, 0 }).allclose(t)); - } - } - [Fact] public void TestSaveLoadConv2D() { @@ -368,7 +250,7 @@ public void TestSaveLoadCustomWithParameters() } } - private class TestModule1 : Module + private class TestModule1 : Module { public TestModule1() : base("TestModule1") { diff --git a/test/TorchSharpTest/TestSaveSD.cs b/test/TorchSharpTest/TestSaveSD.cs index ca818faf8..7ccd1b21b 100644 --- a/test/TorchSharpTest/TestSaveSD.cs +++ b/test/TorchSharpTest/TestSaveSD.cs @@ -12,17 +12,17 @@ namespace TorchSharp #endif // NET472_OR_GREATER public class TestSaveSD { - private class LSTMModel : Module + private class LSTMModel : nn.Module { public static int NUM_WORDS = 100; public static int EMBEDDING_VEC_LEN = 100; public static int HIDDEN_SIZE = 128; - private Module embedding; + private Module embedding; private LSTM lstm; - private Module dropout; - private Module dense; - private Module sigmoid; + private Module dropout; + private Module dense; + private Module sigmoid; private Device _device; public LSTMModel(string name, Device device = null) : base(name) @@ -59,11 +59,11 @@ public void TestSaveSDData_LSTM() lstm.save("./lstm.dat"); } - class LeNet1Model : Module + class LeNet1Model : Module { // The names of properties should be the same in C# and Python // in this case, we both name the Sequential as layers - private readonly Module layers; + private readonly Module layers; private Device _device; public LeNet1Model(string name, Device device = null) : base(name) @@ -71,7 +71,7 @@ public LeNet1Model(string name, Device device = null) : base(name) _device = device; // the names of each layer should also be the same in C# and Python - var modules = new List<(string, Module)>(); + var modules = new List<(string, Module)>(); modules.Add(("conv-1", Conv2d(1, 4, 5, padding: 2))); modules.Add(("bnrm2d-1", BatchNorm2d(4))); modules.Add(("relu-1", ReLU())); diff --git a/test/TorchSharpTest/TestTorchSharp.cs b/test/TorchSharpTest/TestTorchSharp.cs index 933b25947..288655ec0 100644 --- a/test/TorchSharpTest/TestTorchSharp.cs +++ b/test/TorchSharpTest/TestTorchSharp.cs @@ -179,7 +179,7 @@ public void TestUtilsPtoV() var lin1 = nn.Linear(1000, 100); var lin2 = nn.Linear(100, 10); - var submodules = new List<(string name, torch.nn.Module submodule)>(); + var submodules = new List<(string name, torch.nn.Module submodule)>(); submodules.Add(("lin1", lin1)); submodules.Add(("lin2", lin2)); @@ -195,7 +195,7 @@ public void TestUtilsVtoP() var lin1 = nn.Linear(1000, 100); var lin2 = nn.Linear(100, 10); - var submodules = new List<(string name, torch.nn.Module submodule)>(); + var submodules = new List<(string name, torch.nn.Module submodule)>(); submodules.Add(("lin1", lin1)); submodules.Add(("relu1", nn.ReLU())); submodules.Add(("lin2", lin2)); diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs index 1477f902c..59d138e69 100644 --- a/test/TorchSharpTest/TestTorchTensorBugs.cs +++ b/test/TorchSharpTest/TestTorchTensorBugs.cs @@ -56,7 +56,7 @@ public void ValidateIssue145() } } - class DoubleIt : nn.Module + class DoubleIt : nn.Module { public DoubleIt() : base("double") { } @@ -178,7 +178,7 @@ public void ValidateIssue315_4() } } - class TestModule : Module + class TestModule : Module { public TestModule() : base(nameof(TestModule)) { } @@ -199,7 +199,7 @@ public static void Reproduce() } [MethodImpl(MethodImplOptions.NoInlining)] - static Module Make() => Sequential(("t", new TestModule()), ("d", Linear(10, 10))); + static Module Make() => Sequential(("t", new TestModule()), ("d", Linear(10, 10))); } [Fact] @@ -476,9 +476,9 @@ public void ValidateIssue500() } } - class Module500 : Module + class Module500 : Module { - private Module bn1 = BatchNorm1d(28); + private Module bn1 = BatchNorm1d(28); public Module500() : base(nameof(TestModule)) { RegisterComponents(); } @@ -518,9 +518,9 @@ public void ValidateIssue510() Assert.Equal(0, nm_.item()); } - internal class Module510 : Module + internal class Module510 : Module { - private readonly Module stack; + private readonly Module stack; public Module510(int in_channels, int out_channels, int kernel_size = 3, int stride = 1, int padding = 0) : base(String.Empty) { @@ -574,7 +574,7 @@ public void ValidateIssue516() } } - internal abstract class BaseModule : torch.nn.Module + internal abstract class BaseModule : torch.nn.Module { public int? InstanceId = null; @@ -583,7 +583,7 @@ protected BaseModule(string name) : base(name) } } - public class TestGradWarningModel : torch.nn.Module + public class TestGradWarningModel : torch.nn.Module { public readonly Modules.Parameter Weight; @@ -612,11 +612,11 @@ public void Validate532() Assert.Equal(pB.Length + pC.Length, p.Length); } - internal class Module532 : Module + internal class Module532 : Module { - public Module conv; - public Module batch; - private Module seq; + public Module conv; + public Module batch; + private Module seq; public Module532(int in_channels, int out_channels) : base(String.Empty) { @@ -656,9 +656,9 @@ public void Validate538() File.Delete("bug538.dat"); } - internal class Module538 : Module + internal class Module538 : Module { - private Module seq; + private Module seq; public Module538(int in_channels, int out_channels) : base(String.Empty) { @@ -779,7 +779,7 @@ public void ValidateBug715() { var resnet = resnet18(); var resnetlist = resnet.named_children(); - var list = resnetlist.Take(6); + var list = resnetlist.Take(6).Select(x => (x.name, (nn.Module)x.module)); var bone = nn.Sequential(list); var x = torch.zeros(1, 3, 64, 160); diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs index 98e40126b..ab7d033ec 100644 --- a/test/TorchSharpTest/TestTraining.cs +++ b/test/TorchSharpTest/TestTraining.cs @@ -29,7 +29,7 @@ public void TestTraining1() var lin1 = Linear(1000, 100); var lin2 = Linear(100, 10); - var submodules = new List<(string name, torch.nn.Module submodule)>(); + var submodules = new List<(string name, torch.nn.Module submodule)>(); submodules.Add(("lin1", lin1)); submodules.Add(("relu1", ReLU())); submodules.Add(("lin2", lin2)); @@ -186,7 +186,7 @@ private static void ReInitializeLinear(Generator gen, Linear linear) } } - private static float TrainLoop(Module seq, Tensor x, Tensor y, optim.Optimizer optimizer) + private static float TrainLoop(IModule seq, Tensor x, Tensor y, optim.Optimizer optimizer) { var loss = MSELoss(Reduction.Sum); @@ -213,7 +213,7 @@ private static float TrainLoop(Module seq, Tensor x, Tensor y, optim.Optimizer o return finalLoss; } - private static float TrainLoop(Module seq, Tensor x, Tensor y, optim.Optimizer optimizer, optim.lr_scheduler.LRScheduler scheduler, bool check_lr = true, int iters = 10) + private static float TrainLoop(IModule seq, Tensor x, Tensor y, optim.Optimizer optimizer, optim.lr_scheduler.LRScheduler scheduler, bool check_lr = true, int iters = 10) { var loss = MSELoss(Reduction.Sum); @@ -1727,7 +1727,7 @@ public void TestTrainingLoadedTorchScript() var gen = new Generator(4711); CreateDataAndLabels(gen, out var x, out var y); - var seq = torch.jit.load(@"l1000_100_10.script.dat"); + var seq = torch.jit.load(@"l1000_100_10.script.dat"); double learning_rate = 0.00004f; var optimizer = torch.optim.SGD(seq.parameters(), learning_rate); @@ -1770,7 +1770,7 @@ public void TestTrainingConv2dCUDA() if (torch.cuda.is_available()) { var device = torch.CUDA; - using (Module conv1 = Conv2d(3, 4, 3, stride: 2), + using (Module conv1 = Conv2d(3, 4, 3, stride: 2), lin1 = Linear(4 * 13 * 13, 32), lin2 = Linear(32, 10)) diff --git a/test/TorchSharpTest/TorchSharpTest.csproj b/test/TorchSharpTest/TorchSharpTest.csproj index 4525d8c55..516d0eb95 100644 --- a/test/TorchSharpTest/TorchSharpTest.csproj +++ b/test/TorchSharpTest/TorchSharpTest.csproj @@ -38,9 +38,15 @@ PreserveNewest + + PreserveNewest + PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/test/TorchSharpTest/list_out.dat b/test/TorchSharpTest/list_out.dat new file mode 100644 index 000000000..d6bc9ce27 Binary files /dev/null and b/test/TorchSharpTest/list_out.dat differ diff --git a/test/TorchSharpTest/tuple_out.dat b/test/TorchSharpTest/tuple_out.dat new file mode 100644 index 000000000..a3051d5cb Binary files /dev/null and b/test/TorchSharpTest/tuple_out.dat differ