diff --git a/test/Microsoft.ML.Tests/Torch/TorchTests.cs b/test/Microsoft.ML.Tests/Torch/TorchTests.cs index 2112a2bcf8..8cc117b8c3 100644 --- a/test/Microsoft.ML.Tests/Torch/TorchTests.cs +++ b/test/Microsoft.ML.Tests/Torch/TorchTests.cs @@ -1,5 +1,7 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Linq; +using System.Runtime.InteropServices; using Microsoft.ML.Data; using Microsoft.ML.RunTests; using Microsoft.ML.TestFramework.Attributes; @@ -24,53 +26,60 @@ private class TestReLUModelData [TorchFact] public void TorchScoringReLUTest() { - var mlContext = new MLContext(); - var tensor = new float[] { -1, -1, 0, 1, 1 }.ToTorchTensor(dimensions: new long[] { 5 }); - var data = new TestReLUModelData + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - Features = tensor.Data().ToArray() - }; - var dataPoint = new List() { data }; + var mlContext = new MLContext(); + var tensor = new float[] { -1, -1, 0, 1, 1 }.ToTorchTensor(dimensions: new long[] { 5 }); + var data = new TestReLUModelData + { + Features = tensor.Data().ToArray() + }; + var dataPoint = new List() { data }; - var dataView = mlContext.Data.LoadFromEnumerable(dataPoint); + var dataView = mlContext.Data.LoadFromEnumerable(dataPoint); - var output = mlContext.Model - .LoadTorchModel(GetDataPath("Torch/relu.pt")) - .ScoreTorchModel("Features", new long[] { 5 }) - .Fit(dataView) - .Transform(dataView); + var output = mlContext.Model + .LoadTorchModel(GetDataPath("Torch/relu.pt")) + .ScoreTorchModel("Features", new long[] { 5 }) + .Fit(dataView) + .Transform(dataView); - var transformedData = mlContext.Data.CreateEnumerable(output, false).ToArray()[0].Features; - Assert.True(transformedData.Length == 5); - Assert.Equal(transformedData, new float[] { 0, 0, 0, 1, 1 }); + var transformedData = mlContext.Data.CreateEnumerable(output, false).ToArray()[0].Features; + Assert.True(transformedData.Length == 5); + Assert.Equal(transformedData, new float[] { 0, 0, 0, 1, 1 }); + } } [TorchFact] public void TorchTransformerWorkoutTest() { - var mlContext = new MLContext(); - var tensorData = FloatTensor.Random(new long[] { 5 }); - var datapoint = new TestReLUModelData + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - Features = tensorData.Data().ToArray() - }; - var data = new List() { datapoint, datapoint, datapoint, datapoint, datapoint }; + var mlContext = new MLContext(); + var tensorData = FloatTensor.Random(new long[] { 5 }); + var datapoint = new TestReLUModelData + { + Features = tensorData.Data().ToArray() + }; + var data = new List() { datapoint, datapoint, datapoint, datapoint, datapoint }; - var dataView = mlContext.Data.LoadFromEnumerable(data); + var dataView = mlContext.Data.LoadFromEnumerable(data); - var estimator = mlContext.Model.LoadTorchModel(GetDataPath("Torch/relu.pt")) - .ScoreTorchModel("TorchOutput", new long[] { 5 }, "Features"); + var estimator = mlContext.Model.LoadTorchModel(GetDataPath("Torch/relu.pt")) + .ScoreTorchModel("TorchOutput", new long[] { 5 }, "Features"); - TestEstimatorCore(estimator, dataView); + TestEstimatorCore(estimator, dataView); - var output = estimator.Fit(dataView) - .Transform(dataView); + var output = estimator.Fit(dataView) + .Transform(dataView); - var transformedData = mlContext.Data.CreateEnumerable(output, false).ToArray()[0].Features; - Assert.True(transformedData.Length == 5); - foreach (var elt in transformedData) - Assert.True(elt >= 0); + var transformedData = mlContext.Data.CreateEnumerable(output, false).ToArray()[0].Features; + + Assert.True(transformedData.Length == 5); + foreach (var elt in transformedData) + Assert.True(elt >= 0); + } } } }