Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Examples/AdversarialExampleGeneration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ private static Tensor Attack(Tensor image, double ε, Tensor data_grad)

private static double Test(
MNIST.Model model,
Loss criterion,
Loss<Tensor, Tensor, Tensor> criterion,
double ε,
Device device,
Dataset dataset,
Expand Down
8 changes: 4 additions & 4 deletions src/Examples/AlexNet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ namespace TorchSharp.Examples
/// <summary>
/// Modified version of original AlexNet to fix CIFAR10 32x32 images.
/// </summary>
class AlexNet : Module
class AlexNet : Module<Tensor, Tensor>
{
private readonly Module features;
private readonly Module avgPool;
private readonly Module classifier;
private readonly Module<Tensor, Tensor> features;
private readonly Module<Tensor, Tensor> avgPool;
private readonly Module<Tensor, Tensor> classifier;

public AlexNet(string name, int numClasses, torch.Device device = null) : base(name)
{
Expand Down
10 changes: 5 additions & 5 deletions src/Examples/CIFAR10.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ internal static void Main(string[] args)

Console.WriteLine($"\tCreating the model...");

Module model = null;
Module<torch.Tensor, torch.Tensor> model = null;

switch (modelName.ToLower()) {
case "alexnet":
Expand Down Expand Up @@ -134,9 +134,9 @@ internal static void Main(string[] args)
}

private static void Train(
Module model,
Module<torch.Tensor, torch.Tensor> model,
torch.optim.Optimizer optimizer,
Loss loss,
Loss<torch.Tensor, torch.Tensor, torch.Tensor> loss,
DataLoader dataLoader,
int epoch,
long batchSize,
Expand Down Expand Up @@ -182,8 +182,8 @@ private static void Train(
}

private static void Test(
Module model,
Loss loss,
Module<torch.Tensor, torch.Tensor> model,
Loss<torch.Tensor, torch.Tensor, torch.Tensor> loss,
DataLoader dataLoader,
long size)
{
Expand Down
30 changes: 15 additions & 15 deletions src/Examples/MNIST.cs
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,26 @@ internal static void TrainingLoop(string dataset, Device device, Model model, Da
model.save(dataset + ".model.bin");
}

internal class Model : Module
internal class Model : Module<Tensor, Tensor>
{
private Module conv1 = Conv2d(1, 32, 3);
private Module conv2 = Conv2d(32, 64, 3);
private Module fc1 = Linear(9216, 128);
private Module fc2 = Linear(128, 10);
private Module<Tensor, Tensor> conv1 = Conv2d(1, 32, 3);
private Module<Tensor, Tensor> conv2 = Conv2d(32, 64, 3);
private Module<Tensor, Tensor> fc1 = Linear(9216, 128);
private Module<Tensor, Tensor> fc2 = Linear(128, 10);

// These don't have any parameters, so the only reason to instantiate
// them is performance, since they will be used over and over.
private Module pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 });
private Module<Tensor, Tensor> pool1 = MaxPool2d(kernelSize: new long[] { 2, 2 });

private Module relu1 = ReLU();
private Module relu2 = ReLU();
private Module relu3 = ReLU();
private Module<Tensor, Tensor> relu1 = ReLU();
private Module<Tensor, Tensor> relu2 = ReLU();
private Module<Tensor, Tensor> relu3 = ReLU();

private Module dropout1 = Dropout(0.25);
private Module dropout2 = Dropout(0.5);
private Module<Tensor, Tensor> dropout1 = Dropout(0.25);
private Module<Tensor, Tensor> dropout2 = Dropout(0.5);

private Module flatten = Flatten();
private Module logsm = LogSoftmax(1);
private Module<Tensor, Tensor> flatten = Flatten();
private Module<Tensor, Tensor> logsm = LogSoftmax(1);

public Model(string name, torch.Device device = null) : base(name)
{
Expand Down Expand Up @@ -151,7 +151,7 @@ public override Tensor forward(Tensor input)
private static void Train(
Model model,
torch.optim.Optimizer optimizer,
Loss loss,
Loss<torch.Tensor, torch.Tensor, torch.Tensor> loss,
DataLoader dataLoader,
int epoch,
long size)
Expand Down Expand Up @@ -191,7 +191,7 @@ private static void Train(

private static void Test(
Model model,
Loss loss,
Loss<torch.Tensor, torch.Tensor, torch.Tensor> loss,
DataLoader dataLoader,
long size)
{
Expand Down
8 changes: 4 additions & 4 deletions src/Examples/MobileNet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ namespace TorchSharp.Examples
/// With an unaugmented CIFAR-10 data set, the author of this saw training converge
/// at roughly 75% accuracy on the test set, over the course of 1500 epochs.
/// </remarks>
class MobileNet : Module
class MobileNet : Module<Tensor, Tensor>
{
// The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenet.py
// Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE

private readonly long[] planes = new long[] { 64, 128, 128, 256, 256, 512, 512, 512, 512, 512, 512, 1024, 1024 };
private readonly long[] strides = new long[] { 1, 2, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1 };

private readonly Module layers;
private readonly Module<Tensor, Tensor> layers;

public MobileNet(string name, int numClasses, Device device = null) : base(name)
{
if (planes.Length != strides.Length) throw new ArgumentException("'planes' and 'strides' must have the same length.");

var modules = new List<(string, Module)>();
var modules = new List<(string, Module<Tensor, Tensor>)>();

modules.Add(($"conv2d-first", Conv2d(3, 32, kernelSize: 3, stride: 1, padding: 1, bias: false)));
modules.Add(($"bnrm2d-first", BatchNorm2d(32)));
Expand All @@ -46,7 +46,7 @@ public MobileNet(string name, int numClasses, Device device = null) : base(name)
this.to(device);
}

private void MakeLayers(List<(string, Module)> modules, long in_planes)
private void MakeLayers(List<(string, Module<Tensor, Tensor>)> modules, long in_planes)
{

for (var i = 0; i < strides.Length; i++) {
Expand Down
26 changes: 13 additions & 13 deletions src/Examples/ResNet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ namespace TorchSharp.Examples
/// <summary>
/// Modified version of ResNet to classify CIFAR10 32x32 images.
/// </summary>
class ResNet : Module
class ResNet : Module<Tensor, Tensor>
{
// The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
// Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE

private readonly Module layers;
private readonly Module<Tensor, Tensor> layers;
private int in_planes = 64;

public static ResNet ResNet18(int numClasses, Device device = null)
Expand Down Expand Up @@ -68,9 +68,9 @@ public static ResNet ResNet152(int numClasses, Device device = null)
device);
}

public ResNet(string name, Func<string, int,int,int,Module> block, int expansion, IList<int> num_blocks, int numClasses, Device device = null) : base(name)
public ResNet(string name, Func<string, int,int,int,Module<Tensor, Tensor>> block, int expansion, IList<int> num_blocks, int numClasses, Device device = null) : base(name)
{
var modules = new List<(string, Module)>();
var modules = new List<(string, Module<Tensor, Tensor>)>();

modules.Add(($"conv2d-first", Conv2d(3, 64, kernelSize: 3, stride: 1, padding: 1, bias: false)));
modules.Add(($"bnrm2d-first", BatchNorm2d(64)));
Expand All @@ -91,7 +91,7 @@ public ResNet(string name, Func<string, int,int,int,Module> block, int expansion
this.to(device);
}

private void MakeLayer(List<(string, Module)> modules, Func<string, int, int, int, Module> block, int expansion, int planes, int num_blocks, int stride)
private void MakeLayer(List<(string, Module<Tensor, Tensor>)> modules, Func<string, int, int, int, Module<Tensor, Tensor>> block, int expansion, int planes, int num_blocks, int stride)
{
var strides = new List<int>();
strides.Add(stride);
Expand All @@ -109,11 +109,11 @@ public override Tensor forward(Tensor input)
return layers.forward(input);
}

class BasicBlock : Module
class BasicBlock : Module<Tensor, Tensor>
{
public BasicBlock (string name, int in_planes, int planes, int stride) : base(name)
{
var modules = new List<(string, Module)>();
var modules = new List<(string, Module<Tensor, Tensor>)>();

modules.Add(($"{name}-conv2d-1", Conv2d(in_planes, planes, kernelSize: 3, stride: stride, padding: 1, bias: false)));
modules.Add(($"{name}-bnrm2d-1", BatchNorm2d(planes)));
Expand Down Expand Up @@ -146,15 +146,15 @@ public override Tensor forward(Tensor t)

public static int expansion = 1;

private readonly Module layers;
private readonly Module shortcut;
private readonly Module<Tensor, Tensor> layers;
private readonly Module<Tensor, Tensor> shortcut;
}

class Bottleneck : Module
class Bottleneck : Module<Tensor, Tensor>
{
public Bottleneck(string name, int in_planes, int planes, int stride) : base(name)
{
var modules = new List<(string, Module)>();
var modules = new List<(string, Module<Tensor, Tensor>)>();

modules.Add(($"{name}-conv2d-1", Conv2d(in_planes, planes, kernelSize: 1, bias: false)));
modules.Add(($"{name}-bnrm2d-1", BatchNorm2d(planes)));
Expand Down Expand Up @@ -187,8 +187,8 @@ public override Tensor forward(Tensor t)

public static int expansion = 4;

private readonly Module layers;
private readonly Module shortcut;
private readonly Module<Tensor, Tensor> layers;
private readonly Module<Tensor, Tensor> shortcut;
}
}
}
17 changes: 6 additions & 11 deletions src/Examples/SequenceToSequence.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ internal static void Main(string[] args)
Console.WriteLine($"\nEnd of training | time: {totalTime.Elapsed.TotalSeconds:0.0}s | loss: {tst_loss:0.00}\n");
}

private static void train(int epoch, Tensor train_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer)
private static void train(int epoch, Tensor train_data, TransformerModel model, Loss<Tensor, Tensor, Tensor> criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer)
{
model.train();

Expand Down Expand Up @@ -149,7 +149,7 @@ private static void train(int epoch, Tensor train_data, TransformerModel model,
}
}

private static double evaluate(Tensor eval_data, TransformerModel model, Loss criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer)
private static double evaluate(Tensor eval_data, TransformerModel model, Loss<Tensor, Tensor, Tensor> criterion, int bptt, int ntokens, torch.optim.Optimizer optimizer)
{
model.eval();

Expand Down Expand Up @@ -211,9 +211,9 @@ static Tensor Batchify(Tensor data, int batch_size)
return (data, target);
}

class TransformerModel : Module
class TransformerModel : Module<Tensor, Tensor, Tensor>
{
private Module transformer_encoder;
private Modules.TransformerEncoder transformer_encoder;
private PositionalEncoding pos_encoder;
private Modules.Embedding encoder;
private Modules.Linear decoder;
Expand Down Expand Up @@ -252,11 +252,6 @@ private void InitWeights()
init.uniform_(decoder.weight, -initrange, initrange);
}

public override Tensor forward(Tensor t)
{
throw new NotImplementedException("single-argument forward()");
}

public override Tensor forward(Tensor t, Tensor mask)
{
var src = pos_encoder.forward(encoder.forward(t) * MathF.Sqrt(ninputs));
Expand All @@ -271,9 +266,9 @@ protected override Module _to(DeviceType deviceType, int deviceIndex = -1)
}
}

class PositionalEncoding : Module
class PositionalEncoding : Module<Tensor, Tensor>
{
private Module dropout;
private Module<Tensor, Tensor> dropout;
private Tensor pe;

public PositionalEncoding(long dmodel, double dropout, int maxLen = 5000) : base("PositionalEncoding")
Expand Down
32 changes: 16 additions & 16 deletions src/Examples/SpeechCommands.cs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private static void Train(
M5 model,
ITransform transform,
torch.optim.Optimizer optimizer,
Loss criteria,
Loss<Tensor,Tensor,Tensor> criteria,
DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
int epoch,
long size)
Expand Down Expand Up @@ -160,7 +160,7 @@ private static void Train(
private static void Test(
M5 model,
ITransform transform,
Loss criteria,
Loss<Tensor, Tensor, Tensor> criteria,
DataLoader<SpeechCommandsDatasetItem, BatchItem> dataLoader,
long size)
{
Expand Down Expand Up @@ -217,21 +217,21 @@ private class BatchItem
public torch.Tensor label;
}

internal class M5 : Module
internal class M5 : Module<Tensor, Tensor>
{
private readonly Module conv1;
private readonly Module bn1;
private readonly Module pool1;
private readonly Module conv2;
private readonly Module bn2;
private readonly Module pool2;
private readonly Module conv3;
private readonly Module bn3;
private readonly Module pool3;
private readonly Module conv4;
private readonly Module bn4;
private readonly Module pool4;
private readonly Module fc1;
private readonly Module<Tensor, Tensor> conv1;
private readonly Module<Tensor, Tensor> bn1;
private readonly Module<Tensor, Tensor> pool1;
private readonly Module<Tensor, Tensor> conv2;
private readonly Module<Tensor, Tensor> bn2;
private readonly Module<Tensor, Tensor> pool2;
private readonly Module<Tensor, Tensor> conv3;
private readonly Module<Tensor, Tensor> bn3;
private readonly Module<Tensor, Tensor> pool3;
private readonly Module<Tensor, Tensor> conv4;
private readonly Module<Tensor, Tensor> bn4;
private readonly Module<Tensor, Tensor> pool4;
private readonly Module<Tensor, Tensor> fc1;

public M5(string name, int n_input = 1, int n_output = 35, int stride = 16, int n_channel = 32) : base(name)
{
Expand Down
8 changes: 4 additions & 4 deletions src/Examples/TextClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ internal static void Main(string[] args)
}
}

static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, TextClassificationModel model, Loss criterion, torch.optim.Optimizer optimizer)
static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, TextClassificationModel model, Loss<torch.Tensor, torch.Tensor, torch.Tensor> criterion, torch.optim.Optimizer optimizer)
{
model.train();

Expand Down Expand Up @@ -145,7 +145,7 @@ static void train(int epoch, IEnumerable<(Tensor, Tensor, Tensor)> train_data, T
}
}

static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClassificationModel model, Loss criterion)
static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClassificationModel model, Loss<Tensor, Tensor, Tensor> criterion)
{
model.eval();

Expand All @@ -166,7 +166,7 @@ static double evaluate(IEnumerable<(Tensor, Tensor, Tensor)> test_data, TextClas
}
}

class TextClassificationModel : Module
class TextClassificationModel : Module<Tensor, Tensor>
{
private Modules.EmbeddingBag embedding;
private Modules.Linear fc;
Expand Down Expand Up @@ -194,7 +194,7 @@ public override Tensor forward(Tensor t)
throw new NotImplementedException();
}

public override Tensor forward(Tensor input, Tensor offsets)
public Tensor forward(Tensor input, Tensor offsets)
{
return fc.forward(embedding.forward(input, offsets));
}
Expand Down
6 changes: 3 additions & 3 deletions src/Examples/VGG.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace TorchSharp.Examples
/// With an unaugmented CIFAR-10 data set, the author of this saw training converge
/// at roughly 85% accuracy on the test set, after 50 epochs using VGG-16.
/// </remarks>
class VGG : Module
class VGG : Module<Tensor, Tensor>
{
// The code here is is loosely based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py
// Licence and copypright notice at: https://github.com/kuangliu/pytorch-cifar/blob/master/LICENSE
Expand All @@ -25,11 +25,11 @@ class VGG : Module
{ "VGG19", new long[] { 64, 64, 0, 128, 128, 0, 256, 256, 256, 256, 0, 512, 512, 512, 512, 0, 512, 512, 512, 512, 0 } }
};

private readonly Module layers;
private readonly Module<Tensor, Tensor> layers;

public VGG(string name, int numClasses, Device device = null) : base(name)
{
var modules = new List<(string, Module)>();
var modules = new List<(string, Module<Tensor, Tensor>)>();

var channels = _channels[name];

Expand Down
Loading