diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs index 2552c5be98..594b95bd03 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs @@ -54,8 +54,6 @@ public void MatrixFactorization_Estimator() } [MatrixFactorizationFact] - //Skipping test temporarily. This test will be re-enabled once the cause of failures has been determined - [Trait("Category", "SkipInCI")] public void MatrixFactorizationSimpleTrainAndPredict() { var mlContext = new MLContext(seed: 1); @@ -94,7 +92,7 @@ public void MatrixFactorizationSimpleTrainAndPredict() var rightMatrix = model.Model.RightFactorMatrix; Assert.Equal(leftMatrix.Count, model.Model.NumberOfRows * model.Model.ApproximationRank); Assert.Equal(rightMatrix.Count, model.Model.NumberOfColumns * model.Model.ApproximationRank); - // MF produce different matrixes on different platforms, so at least test thier content on windows. + // MF produce different matrixes on different platforms, so at least test their content on windows. if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { Assert.Equal(0.33491, leftMatrix[0], 5); @@ -124,12 +122,14 @@ public void MatrixFactorizationSimpleTrainAndPredict() var metrices = mlContext.Recommendation().Evaluate(prediction, labelColumnName: labelColumnName, scoreColumnName: scoreColumnName); // Determine if the selected metric is reasonable for different platforms - double tolerance = Math.Pow(10, -7); + double windowsTolerance = Math.Pow(10, -7); + // 1e-7 is too small for Linux, so we try 1e-4 + double linuxTolerance = Math.Pow(10, -4); if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { // Linux case var expectedUnixL2Error = 0.614457914950479; // Linux baseline - Assert.InRange(metrices.MeanSquaredError, expectedUnixL2Error - tolerance, expectedUnixL2Error + tolerance); + Assert.InRange(metrices.MeanSquaredError, expectedUnixL2Error - linuxTolerance, expectedUnixL2Error + linuxTolerance); } else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) { @@ -142,7 +142,7 @@ public void MatrixFactorizationSimpleTrainAndPredict() { // Windows case var expectedWindowsL2Error = 0.6098110249191965; // Windows baseline - Assert.InRange(metrices.MeanSquaredError, expectedWindowsL2Error - tolerance, expectedWindowsL2Error + tolerance); + Assert.InRange(metrices.MeanSquaredError, expectedWindowsL2Error - windowsTolerance, expectedWindowsL2Error + windowsTolerance); } var modelWithValidation = pipeline.Fit(data, testData);