Skip to content
Merged
78 changes: 75 additions & 3 deletions test/Microsoft.ML.Functional.Tests/Common.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
using Microsoft.ML.Data.Evaluators.Metrics;
using Microsoft.ML.Functional.Tests.Datasets;
using Xunit;

Expand Down Expand Up @@ -160,13 +161,86 @@ public static void AssertEqual(TypeTestData testType1, TypeTestData testType2)
Assert.True(testType1.Ug.Equals(testType2.Ug));
}

/// <summary>
/// Check that a <see cref="AnomalyDetectionMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(AnomalyDetectionMetrics metrics)
{
Assert.InRange(metrics.Auc, 0, 1);
Assert.InRange(metrics.DrAtK, 0, 1);
}

/// <summary>
/// Check that a <see cref="BinaryClassificationMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(BinaryClassificationMetrics metrics)
{
Assert.InRange(metrics.Accuracy, 0, 1);
Assert.InRange(metrics.Auc, 0, 1);
Assert.InRange(metrics.Auprc, 0, 1);
Assert.InRange(metrics.F1Score, 0, 1);
Assert.InRange(metrics.NegativePrecision, 0, 1);
Assert.InRange(metrics.NegativeRecall, 0, 1);
Assert.InRange(metrics.PositivePrecision, 0, 1);
Assert.InRange(metrics.PositiveRecall, 0, 1);
}

/// <summary>
/// Check that a <see cref="CalibratedBinaryClassificationMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(CalibratedBinaryClassificationMetrics metrics)
{
Assert.InRange(metrics.Entropy, double.NegativeInfinity, 1);
Assert.InRange(metrics.LogLoss, double.NegativeInfinity, 1);
Assert.InRange(metrics.LogLossReduction, double.NegativeInfinity, 100);
AssertMetrics(metrics as BinaryClassificationMetrics);
}

/// <summary>
/// Check that a <see cref="ClusteringMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(ClusteringMetrics metrics)
{
Assert.True(metrics.AvgMinScore >= 0);
Assert.True(metrics.Dbi >= 0);
if (!double.IsNaN(metrics.Nmi))
Assert.True(metrics.Nmi >= 0 && metrics.Nmi <= 1);
}

/// <summary>
/// Check that a <see cref="MultiClassClassifierMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(MultiClassClassifierMetrics metrics)
{
Assert.InRange(metrics.AccuracyMacro, 0, 1);
Assert.InRange(metrics.AccuracyMicro, 0, 1);
Assert.True(metrics.LogLoss >= 0);
Assert.InRange(metrics.TopKAccuracy, 0, 1);
}

/// <summary>
/// Check that a <see cref="RankerMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(RankerMetrics metrics)
{
foreach (var dcg in metrics.Dcg)
Assert.True(dcg >= 0);
foreach (var ndcg in metrics.Ndcg)
Assert.InRange(ndcg, 0, 100);
}

/// <summary>
/// Check that a <see cref="RegressionMetrics"/> object is valid.
/// </summary>
/// <param name="metrics">The metrics object.</param>
public static void AssertMetrics(RegressionMetrics metrics)
{
// Perform sanity checks on the metrics.
Assert.True(metrics.Rms >= 0);
Assert.True(metrics.L1 >= 0);
Assert.True(metrics.L2 >= 0);
Expand All @@ -179,7 +253,6 @@ public static void AssertMetrics(RegressionMetrics metrics)
/// <param name="metric">The <see cref="MetricStatistics"/> object.</param>
public static void AssertMetricStatistics(MetricStatistics metric)
{
// Perform sanity checks on the metrics.
Assert.True(metric.StandardDeviation >= 0);
Assert.True(metric.StandardError >= 0);
}
Expand All @@ -190,7 +263,6 @@ public static void AssertMetricStatistics(MetricStatistics metric)
/// <param name="metrics">The metrics object.</param>
public static void AssertMetricsStatistics(RegressionMetricsStatistics metrics)
{
// The mean can be any float; the standard deviation and error must be >=0.
AssertMetricStatistics(metrics.Rms);
AssertMetricStatistics(metrics.L1);
AssertMetricStatistics(metrics.L2);
Expand Down
78 changes: 78 additions & 0 deletions test/Microsoft.ML.Functional.Tests/Datasets/Iris.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.


using System;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Functional.Tests.Datasets
{
/// <summary>
/// A class for the Iris test dataset.
/// </summary>
internal sealed class Iris
{
[LoadColumn(0)]
public float Label { get; set; }

[LoadColumn(1)]
public float SepalLength { get; set; }

[LoadColumn(2)]
public float SepalWidth { get; set; }

[LoadColumn(4)]
public float PetalLength { get; set; }

[LoadColumn(5)]
public float PetalWidth { get; set; }

/// <summary>
/// The list of columns commonly used as features.
/// </summary>
public static readonly string[] Features = new string[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" };

public static IDataView LoadAsRankingProblem(MLContext mlContext, string filePath, bool hasHeader, char separatorChar, int seed = 1)
{
// Load the Iris data.
var data = mlContext.Data.ReadFromTextFile<Iris>(filePath, hasHeader: hasHeader, separatorChar: separatorChar);

// Create a function that generates a random groupId.
var rng = new Random(seed);
Action<Iris, IrisWithGroup> generateGroupId = (input, output) =>
{
output.Label = input.Label;
// The standard set used in tests has 150 rows
output.GroupId = rng.Next(0, 30);
output.PetalLength = input.PetalLength;
output.PetalWidth = input.PetalWidth;
output.SepalLength = input.SepalLength;
output.SepalWidth = input.SepalWidth;
};

// Describe a pipeline that generates a groupId and converts it to a key.
var pipeline = mlContext.Transforms.CustomMapping(generateGroupId, null)
.Append(mlContext.Transforms.Conversion.MapValueToKey("GroupId"));

// Transform the data
var transformedData = pipeline.Fit(data).Transform(data);

return transformedData;
}
}

/// <summary>
/// A class for the Iris dataset with a GroupId column.
/// </summary>
internal sealed class IrisWithGroup
{
public float Label { get; set; }
public int GroupId { get; set; }
public float SepalLength { get; set; }
public float SepalWidth { get; set; }
public float PetalLength { get; set; }
public float PetalWidth { get; set; }
}
}
29 changes: 29 additions & 0 deletions test/Microsoft.ML.Functional.Tests/Datasets/MnistOneClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;

namespace Microsoft.ML.Functional.Tests.Datasets
{
internal sealed class MnistOneClass
{
private const int _featureLength = 783;

public float Label { get; set; }

public float[] Features { get; set; }

public static TextLoader GetTextLoader(MLContext mlContext, bool hasHeader, char separatorChar)
{
return mlContext.Data.CreateTextLoader(
new[] {
new TextLoader.Column("Label", DataKind.R4, 0),
new TextLoader.Column("Features", DataKind.R4, 1, 1 + _featureLength)
},
separatorChar: separatorChar,
hasHeader: hasHeader,
allowSparse: true);
}
}
}
20 changes: 20 additions & 0 deletions test/Microsoft.ML.Functional.Tests/Datasets/Sentiment.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;

namespace Microsoft.ML.Functional.Tests.Datasets
{
/// <summary>
/// A class for reading in the Sentiment test dataset.
/// </summary>
internal sealed class TweetSentiment
{
[LoadColumn(0), ColumnName("Label")]
public bool Sentiment { get; set; }

[LoadColumn(1)]
public string SentimentText { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.


using System;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Functional.Tests.Datasets
{
/// <summary>
/// A class describing the TrivialMatrixFactorization test dataset.
/// </summary>
internal sealed class TrivialMatrixFactorization
{
[LoadColumn(0)]
public float Label { get; set; }

[LoadColumn(1)]
public uint MatrixColumnIndex { get; set; }

[LoadColumn(2)]
public uint MatrixRowIndex { get; set; }

public static IDataView LoadAndFeaturizeFromTextFile(MLContext mlContext, string filePath, bool hasHeader, char separatorChar)
{
// Load the data from a textfile.
var data = mlContext.Data.ReadFromTextFile<TrivialMatrixFactorization>(filePath, hasHeader: hasHeader, separatorChar: separatorChar);

// Describe a pipeline to translate the uints to keys.
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("MatrixColumnIndex")
.Append(mlContext.Transforms.Conversion.MapValueToKey("MatrixRowIndex"));

// Transform the data.
var transformedData = pipeline.Fit(data).Transform(data);

return transformedData;
}
}
}
Loading