From 2d909c19de3bf6c8fe24daca31c94607b1334412 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 12 Sep 2022 09:17:36 -0700 Subject: [PATCH 1/5] Fixed bug in ResNet implementation. --- RELEASENOTES.md | 3 ++- src/TorchSharp/TorchVision/models/ResNet.cs | 19 ++++++++++--------- test/TorchSharpTest/TestTorchTensorBugs.cs | 17 ++++++++++++++++- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 9faa384a2..c1b8256a5 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -6,11 +6,12 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the __Fixed Bugs:__ - +#715 How to implement the following code
__API Changes__: Add functional normalizations
+Added torch.utils.tensorboard.SummaryWriter. Support for scalars only.
## NuGet Version 0.97.3 diff --git a/src/TorchSharp/TorchVision/models/ResNet.cs b/src/TorchSharp/TorchVision/models/ResNet.cs index a167ddb63..98d923850 100644 --- a/src/TorchSharp/TorchVision/models/ResNet.cs +++ b/src/TorchSharp/TorchVision/models/ResNet.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Net; +using TorchSharp.Modules; using TorchSharp.torchvision.Modules; using static TorchSharp.torch; using static TorchSharp.torch.nn; @@ -220,10 +221,10 @@ public class ResNet : Module private readonly Module bn1; private readonly Module relu1; - private readonly TorchSharp.Modules.ModuleList layer1 = new TorchSharp.Modules.ModuleList(); - private readonly TorchSharp.Modules.ModuleList layer2 = new TorchSharp.Modules.ModuleList(); - private readonly TorchSharp.Modules.ModuleList layer3 = new TorchSharp.Modules.ModuleList(); - private readonly TorchSharp.Modules.ModuleList layer4 = new TorchSharp.Modules.ModuleList(); + private readonly Sequential layer1 = Sequential(); + private readonly Sequential layer2 = Sequential(); + private readonly Sequential layer3 = Sequential(); + private readonly Sequential layer4 = Sequential(); private readonly Module avgpool; private readonly Module maxpool; @@ -356,7 +357,7 @@ public ResNet(string name, this.to(device); } - private void MakeLayer(TorchSharp.Modules.ModuleList modules, Func block, int expansion, int planes, int num_blocks, int stride) + private void MakeLayer(Sequential modules, Func block, int expansion, int planes, int num_blocks, int stride) { var strides = new List(); strides.Add(stride); @@ -375,10 +376,10 @@ public override Tensor forward(Tensor input) var x = maxpool.forward(relu1.forward(bn1.forward(conv1.forward(input)))); - foreach (var m in layer1) x = m.forward(x); - foreach (var m in layer2) x = m.forward(x); - foreach (var m in layer3) x = m.forward(x); - foreach (var m in layer4) x = m.forward(x); + x = layer1.forward(x); + x = layer2.forward(x); + x = 
layer3.forward(x); + x = layer4.forward(x); var res = fc.forward(flatten.forward(avgpool.forward(x))); scope.MoveToOuter(res); diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs index 3b82dd360..c35a171d3 100644 --- a/test/TorchSharpTest/TestTorchTensorBugs.cs +++ b/test/TorchSharpTest/TestTorchTensorBugs.cs @@ -5,12 +5,13 @@ using System.IO; using System.Threading; +using System.Runtime.CompilerServices; using static TorchSharp.torch.nn; using Xunit; using static TorchSharp.torch; -using System.Runtime.CompilerServices; +using static TorchSharp.torchvision.models; #nullable enable @@ -772,5 +773,19 @@ public void ValidateBug679() spec = torch.rand(1, 257, 500, dtype: dtype); x = torch.istft(spec, 512, 160, 400, null); } + + [Fact] + public void ValidateBug715() + { + var resnet = resnet18(); + var resnetlist = resnet.named_children(); + var list = resnetlist.Take(6); + var bone = nn.Sequential(list); + + var x = torch.zeros(1, 3, 64, 160); + + // This should not blow up. 
+ var tmp = bone.forward(x); + } } } \ No newline at end of file From 935a700af54449d0aca6a50bbf8545386a72de75 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 12 Sep 2022 09:34:46 -0700 Subject: [PATCH 2/5] Adjusted order of declared submodule of ResNet --- src/TorchSharp/TorchVision/models/ResNet.cs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/TorchSharp/TorchVision/models/ResNet.cs b/src/TorchSharp/TorchVision/models/ResNet.cs index 98d923850..2924784a9 100644 --- a/src/TorchSharp/TorchVision/models/ResNet.cs +++ b/src/TorchSharp/TorchVision/models/ResNet.cs @@ -219,7 +219,8 @@ public class ResNet : Module private readonly Module conv1; private readonly Module bn1; - private readonly Module relu1; + private readonly Module relu; + private readonly Module maxpool; private readonly Sequential layer1 = Sequential(); private readonly Sequential layer2 = Sequential(); @@ -227,11 +228,9 @@ public class ResNet : Module private readonly Sequential layer4 = Sequential(); private readonly Module avgpool; - private readonly Module maxpool; private readonly Module flatten; private readonly Module fc; - private int in_planes = 64; public static ResNet ResNet18(int numClasses, @@ -321,7 +320,7 @@ public ResNet(string name, conv1 = Conv2d(3, 64, kernelSize: 7, stride: 2, padding: 3, bias: false); bn1 = BatchNorm2d(64); - relu1 = ReLU(inPlace: true); + relu = ReLU(inPlace: true); maxpool = MaxPool2d(kernelSize: 2, stride: 2, padding: 1); MakeLayer(layer1, block, expansion, 64, num_blocks[0], 1); MakeLayer(layer2, block, expansion, 128, num_blocks[1], 2); @@ -374,7 +373,7 @@ public override Tensor forward(Tensor input) { using (var scope = NewDisposeScope()) { - var x = maxpool.forward(relu1.forward(bn1.forward(conv1.forward(input)))); + var x = maxpool.forward(relu.forward(bn1.forward(conv1.forward(input)))); x = layer1.forward(x); x = layer2.forward(x); From ce844be3296225c47aa2147ebfad009d91d98180 Mon Sep 17 00:00:00 2001 From: 
Niklas Gustafsson Date: Mon, 12 Sep 2022 09:43:39 -0700 Subject: [PATCH 3/5] Reordered fields in additional TV models. --- src/TorchSharp/TorchVision/models/AlexNet.cs | 6 +++--- src/TorchSharp/TorchVision/models/GoogleNet.cs | 7 +++---- src/TorchSharp/TorchVision/models/InceptionV3.cs | 4 ++-- src/TorchSharp/TorchVision/models/VGG.cs | 2 +- test/TorchSharpTest/TestTorchVision.cs | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/TorchSharp/TorchVision/models/AlexNet.cs b/src/TorchSharp/TorchVision/models/AlexNet.cs index 8a7673134..2b011ae68 100644 --- a/src/TorchSharp/TorchVision/models/AlexNet.cs +++ b/src/TorchSharp/TorchVision/models/AlexNet.cs @@ -60,7 +60,7 @@ namespace Modules public class AlexNet : Module { private readonly Module features; - private readonly Module avgPool; + private readonly Module avgpool; private readonly Module classifier; public AlexNet(int numClasses, float dropout = 0.5f, string weights_file = null, bool skipfc = true, Device device = null) : base(nameof(AlexNet)) @@ -81,7 +81,7 @@ public AlexNet(int numClasses, float dropout = 0.5f, string weights_file = null, MaxPool2d(kernelSize: 3, stride: 2) ); - avgPool = AdaptiveAvgPool2d(new long[] { 6, 6 }); + avgpool = AdaptiveAvgPool2d(new long[] { 6, 6 }); classifier = Sequential( Dropout(probability: dropout), @@ -108,7 +108,7 @@ public override Tensor forward(Tensor input) { using (var _ = NewDisposeScope()) { var f = features.forward(input); - var avg = avgPool.forward(f); + var avg = avgpool.forward(f); var x = avg.flatten(1); return classifier.forward(x).MoveToOuterDisposeScope(); } diff --git a/src/TorchSharp/TorchVision/models/GoogleNet.cs b/src/TorchSharp/TorchVision/models/GoogleNet.cs index b4eaa76e7..3b2783621 100644 --- a/src/TorchSharp/TorchVision/models/GoogleNet.cs +++ b/src/TorchSharp/TorchVision/models/GoogleNet.cs @@ -71,19 +71,18 @@ public class GoogleNet : Module private readonly Module conv1; private readonly Module maxpool1; - private 
readonly Module maxpool2; - private readonly Module maxpool3; - private readonly Module maxpool4; private readonly Module conv2; private readonly Module conv3; + private readonly Module maxpool2; private readonly Module inception3a; private readonly Module inception3b; - + private readonly Module maxpool3; private readonly Module inception4a; private readonly Module inception4b; private readonly Module inception4c; private readonly Module inception4d; private readonly Module inception4e; + private readonly Module maxpool4; private readonly Module inception5a; private readonly Module inception5b; //private readonly Module aux1; diff --git a/src/TorchSharp/TorchVision/models/InceptionV3.cs b/src/TorchSharp/TorchVision/models/InceptionV3.cs index c62e710a5..398943fc9 100644 --- a/src/TorchSharp/TorchVision/models/InceptionV3.cs +++ b/src/TorchSharp/TorchVision/models/InceptionV3.cs @@ -70,11 +70,11 @@ public class InceptionV3 : Module private readonly Module Conv2d_1a_3x3; private readonly Module Conv2d_2a_3x3; private readonly Module Conv2d_2b_3x3; + private readonly Module maxpool1; private readonly Module Conv2d_3b_1x1; private readonly Module Conv2d_4a_3x3; - - private readonly Module maxpool1; private readonly Module maxpool2; + private readonly Module Mixed_5b; private readonly Module Mixed_5c; private readonly Module Mixed_5d; diff --git a/src/TorchSharp/TorchVision/models/VGG.cs b/src/TorchSharp/TorchVision/models/VGG.cs index 5c65217f7..84554d148 100644 --- a/src/TorchSharp/TorchVision/models/VGG.cs +++ b/src/TorchSharp/TorchVision/models/VGG.cs @@ -334,8 +334,8 @@ public class VGG : Module }; private readonly Module features; - private readonly Module classifier; private readonly Module avgpool; + private readonly Module classifier; public VGG(string name, int numClasses, diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index b1eaff969..3ca618a5a 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ 
b/test/TorchSharpTest/TestTorchVision.cs @@ -1,4 +1,4 @@ -using System.Collections.Generic; +using System.Linq; using TorchSharp.Modules; using static TorchSharp.torchvision.models; using Xunit; From 149d0432b85abdef69c765c228931588af0ed974 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 12 Sep 2022 09:52:46 -0700 Subject: [PATCH 4/5] Added a documentation note on the importance of the ordering of custom module layers. --- docfx/articles/modules.md | 60 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/docfx/articles/modules.md b/docfx/articles/modules.md index 4f472754a..ff35b61c4 100644 --- a/docfx/articles/modules.md +++ b/docfx/articles/modules.md @@ -230,3 +230,63 @@ Sometimes, a module needs to allocate tensor that are not trainable, i.e. their Each buffer should be declared as a field of type 'Tensor' (not 'Parameter'). This will ensure that the buffer is registered properly when `RegisterComponents()` is called. + +## Modules, 'children()' and 'named_children()' + +It is sometimes necessary to create a new model from an existing one and discard some of the final layers. The submodules will appear in the 'named_children' list in the same order that they are declared within the module itself, and when constructing a model based on the children, the layers may be reordered unless the submodules are declared in the same order that they are meant to be invoked. + +So, for example: + +```C# + private class TestModule1 : Module + { + public TestModule1() + : base("TestModule1") + { + lin1 = Linear(100, 10); + lin2 = Linear(10, 5); + RegisterComponents(); + } + + public override Tensor forward(Tensor input) + { + using (var x = lin1.forward(input)) + return lin2.forward(x); + } + + // Correct -- the layers are declared in the same order they are invoked. 
+ private Module lin1; + private Module lin2; + } + + private class TestModule2 : Module + { + public TestModule2() + : base("TestModule2") + { + lin1 = Linear(100, 10); + lin2 = Linear(10, 5); + RegisterComponents(); + } + + public override Tensor forward(Tensor input) + { + using (var x = lin1.forward(input)) + return lin2.forward(x); + } + + // Incorrect -- the layers are not declared in the same order they are invoked. + private Module lin2; + private Module lin1; + } + + ... + TestModule1 mod1 = ... + TestModule2 mod2 = ... + var seq1 = nn.Sequential(mod1.named_children()); + seq1.forward(t); // Does the same as mod1.forward(t) + var seq2 = nn.Sequential(mod2.named_children()); + seq2.forward(t); // This probably blows up. +``` + + From 4328197302ec825cfdd02c1f5105cdd831df7a8c Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Mon, 12 Sep 2022 11:26:09 -0700 Subject: [PATCH 5/5] Added additional details to TV unit tests. --- test/TorchSharpTest/TestTorchVision.cs | 178 ++++++++++++++++++++++++- 1 file changed, 177 insertions(+), 1 deletion(-) diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index 3ca618a5a..81c046673 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -10,13 +10,28 @@ namespace TorchSharp [Collection("Sequential")] #endif // NET472_OR_GREATER public class TestTorchVision - { + { [Fact] public void TestResNet18() { using var model = resnet18(); var sd = model.state_dict(); Assert.Equal(122, sd.Count); + + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("bn1", names[1]), + () => Assert.Equal("relu", names[2]), + () => Assert.Equal("maxpool", names[3]), + () => Assert.Equal("layer1", names[4]), + () => Assert.Equal("layer2", names[5]), + () => Assert.Equal("layer3", names[6]), + () => Assert.Equal("layer4", names[7]), + () => Assert.Equal("avgpool", 
names[8]), + () => Assert.Equal("flatten", names[9]), + () => Assert.Equal("fc", names[10]) + ); } [Fact] @@ -25,6 +40,21 @@ public void TestResNet34() using var model = resnet34(); var sd = model.state_dict(); Assert.Equal(218, sd.Count); + + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("bn1", names[1]), + () => Assert.Equal("relu", names[2]), + () => Assert.Equal("maxpool", names[3]), + () => Assert.Equal("layer1", names[4]), + () => Assert.Equal("layer2", names[5]), + () => Assert.Equal("layer3", names[6]), + () => Assert.Equal("layer4", names[7]), + () => Assert.Equal("avgpool", names[8]), + () => Assert.Equal("flatten", names[9]), + () => Assert.Equal("fc", names[10]) + ); } [Fact] @@ -33,6 +63,21 @@ public void TestResNet50() using var model = resnet50(); var sd = model.state_dict(); Assert.Equal(320, sd.Count); + + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("bn1", names[1]), + () => Assert.Equal("relu", names[2]), + () => Assert.Equal("maxpool", names[3]), + () => Assert.Equal("layer1", names[4]), + () => Assert.Equal("layer2", names[5]), + () => Assert.Equal("layer3", names[6]), + () => Assert.Equal("layer4", names[7]), + () => Assert.Equal("avgpool", names[8]), + () => Assert.Equal("flatten", names[9]), + () => Assert.Equal("fc", names[10]) + ); } [Fact] @@ -41,6 +86,21 @@ public void TestResNet101() using var model = resnet101(); var sd = model.state_dict(); Assert.Equal(626, sd.Count); + + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("bn1", names[1]), + () => Assert.Equal("relu", names[2]), + () => Assert.Equal("maxpool", names[3]), + () => Assert.Equal("layer1", names[4]), + () => Assert.Equal("layer2", names[5]), + () => Assert.Equal("layer3", 
names[6]), + () => Assert.Equal("layer4", names[7]), + () => Assert.Equal("avgpool", names[8]), + () => Assert.Equal("flatten", names[9]), + () => Assert.Equal("fc", names[10]) + ); } [Fact] @@ -49,6 +109,21 @@ public void TestResNet152() using var model = resnet152(); var sd = model.state_dict(); Assert.Equal(932, sd.Count); + + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("bn1", names[1]), + () => Assert.Equal("relu", names[2]), + () => Assert.Equal("maxpool", names[3]), + () => Assert.Equal("layer1", names[4]), + () => Assert.Equal("layer2", names[5]), + () => Assert.Equal("layer3", names[6]), + () => Assert.Equal("layer4", names[7]), + () => Assert.Equal("avgpool", names[8]), + () => Assert.Equal("flatten", names[9]), + () => Assert.Equal("fc", names[10]) + ); } [Fact] @@ -57,6 +132,12 @@ public void TestAlexNet() using var model = alexnet(); var sd = model.state_dict(); Assert.Equal(16, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } [Fact] @@ -66,11 +147,23 @@ public void TestVGG11() using var model = vgg11(); var sd = model.state_dict(); Assert.Equal(22, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } { using var model = vgg11_bn(); var sd = model.state_dict(); Assert.Equal(62, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } } @@ -81,11 +174,23 @@ public void TestVGG13() using var model = 
vgg13(); var sd = model.state_dict(); Assert.Equal(26, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } { using var model = vgg13_bn(); var sd = model.state_dict(); Assert.Equal(76, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } } @@ -96,11 +201,23 @@ public void TestVGG16() using var model = vgg16(); var sd = model.state_dict(); Assert.Equal(32, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } { using var model = vgg16_bn(); var sd = model.state_dict(); Assert.Equal(97, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } } @@ -111,11 +228,23 @@ public void TestVGG19() using var model = vgg19(); var sd = model.state_dict(); Assert.Equal(38, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } { using var model = vgg19_bn(); var sd = model.state_dict(); Assert.Equal(118, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("features", names[0]), + () => Assert.Equal("avgpool", names[1]), + () => Assert.Equal("classifier", names[2]) + ); } } @@ -125,6 +254,31 @@ public void 
TestInception() using var model = inception_v3(); var sd = model.state_dict(); Assert.Equal(580, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("Conv2d_1a_3x3", names[0]), + () => Assert.Equal("Conv2d_2a_3x3", names[1]), + () => Assert.Equal("Conv2d_2b_3x3", names[2]), + () => Assert.Equal("maxpool1", names[3]), + () => Assert.Equal("Conv2d_3b_1x1", names[4]), + () => Assert.Equal("Conv2d_4a_3x3", names[5]), + () => Assert.Equal("maxpool2", names[6]), + () => Assert.Equal("Mixed_5b", names[7]), + () => Assert.Equal("Mixed_5c", names[8]), + () => Assert.Equal("Mixed_5d", names[9]), + () => Assert.Equal("Mixed_6a", names[10]), + () => Assert.Equal("Mixed_6b", names[11]), + () => Assert.Equal("Mixed_6c", names[12]), + () => Assert.Equal("Mixed_6d", names[13]), + () => Assert.Equal("Mixed_6e", names[14]), + () => Assert.Equal("AuxLogits", names[15]), + () => Assert.Equal("Mixed_7a", names[16]), + () => Assert.Equal("Mixed_7b", names[17]), + () => Assert.Equal("Mixed_7c", names[18]), + () => Assert.Equal("avgpool", names[19]), + () => Assert.Equal("dropout", names[20]), + () => Assert.Equal("fc", names[21]) + ); } [Fact] @@ -133,6 +287,28 @@ public void TestGoogLeNet() using var model = googlenet(); var sd = model.state_dict(); Assert.Equal(344, sd.Count); + var names = model.named_children().Select(nm => nm.name).ToArray(); + Assert.Multiple( + () => Assert.Equal("conv1", names[0]), + () => Assert.Equal("maxpool1", names[1]), + () => Assert.Equal("conv2", names[2]), + () => Assert.Equal("conv3", names[3]), + () => Assert.Equal("maxpool2", names[4]), + () => Assert.Equal("inception3a", names[5]), + () => Assert.Equal("inception3b", names[6]), + () => Assert.Equal("maxpool3", names[7]), + () => Assert.Equal("inception4a", names[8]), + () => Assert.Equal("inception4b", names[9]), + () => Assert.Equal("inception4c", names[10]), + () => Assert.Equal("inception4d", names[11]), + () => 
Assert.Equal("inception4e", names[12]), + () => Assert.Equal("maxpool4", names[13]), + () => Assert.Equal("inception5a", names[14]), + () => Assert.Equal("inception5b", names[15]), + () => Assert.Equal("avgpool", names[16]), + () => Assert.Equal("dropout", names[17]), + () => Assert.Equal("fc", names[18]) + ); } } }