Skip to content

Commit 8001ccc

Browse files
authored
Adding functional tests for all training and evaluation tasks (#2646)
* Adding functional tests for all training and evaluation tasks
1 parent 22844f6 commit 8001ccc

File tree

10 files changed

+570
-45
lines changed

10 files changed

+570
-45
lines changed

test/Microsoft.ML.Functional.Tests/Common.cs

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Linq;
88
using Microsoft.Data.DataView;
99
using Microsoft.ML.Data;
10+
using Microsoft.ML.Data.Evaluators.Metrics;
1011
using Microsoft.ML.Functional.Tests.Datasets;
1112
using Xunit;
1213

@@ -160,13 +161,86 @@ public static void AssertEqual(TypeTestData testType1, TypeTestData testType2)
160161
Assert.True(testType1.Ug.Equals(testType2.Ug));
161162
}
162163

164+
/// <summary>
165+
/// Check that a <see cref="AnomalyDetectionMetrics"/> object is valid.
166+
/// </summary>
167+
/// <param name="metrics">The metrics object.</param>
168+
public static void AssertMetrics(AnomalyDetectionMetrics metrics)
169+
{
170+
Assert.InRange(metrics.Auc, 0, 1);
171+
Assert.InRange(metrics.DrAtK, 0, 1);
172+
}
173+
174+
/// <summary>
175+
/// Check that a <see cref="BinaryClassificationMetrics"/> object is valid.
176+
/// </summary>
177+
/// <param name="metrics">The metrics object.</param>
178+
public static void AssertMetrics(BinaryClassificationMetrics metrics)
179+
{
180+
Assert.InRange(metrics.Accuracy, 0, 1);
181+
Assert.InRange(metrics.Auc, 0, 1);
182+
Assert.InRange(metrics.Auprc, 0, 1);
183+
Assert.InRange(metrics.F1Score, 0, 1);
184+
Assert.InRange(metrics.NegativePrecision, 0, 1);
185+
Assert.InRange(metrics.NegativeRecall, 0, 1);
186+
Assert.InRange(metrics.PositivePrecision, 0, 1);
187+
Assert.InRange(metrics.PositiveRecall, 0, 1);
188+
}
189+
190+
/// <summary>
191+
/// Check that a <see cref="CalibratedBinaryClassificationMetrics"/> object is valid.
192+
/// </summary>
193+
/// <param name="metrics">The metrics object.</param>
194+
public static void AssertMetrics(CalibratedBinaryClassificationMetrics metrics)
195+
{
196+
Assert.InRange(metrics.Entropy, double.NegativeInfinity, 1);
197+
Assert.InRange(metrics.LogLoss, double.NegativeInfinity, 1);
198+
Assert.InRange(metrics.LogLossReduction, double.NegativeInfinity, 100);
199+
AssertMetrics(metrics as BinaryClassificationMetrics);
200+
}
201+
202+
/// <summary>
203+
/// Check that a <see cref="ClusteringMetrics"/> object is valid.
204+
/// </summary>
205+
/// <param name="metrics">The metrics object.</param>
206+
public static void AssertMetrics(ClusteringMetrics metrics)
207+
{
208+
Assert.True(metrics.AvgMinScore >= 0);
209+
Assert.True(metrics.Dbi >= 0);
210+
if (!double.IsNaN(metrics.Nmi))
211+
Assert.True(metrics.Nmi >= 0 && metrics.Nmi <= 1);
212+
}
213+
214+
/// <summary>
215+
/// Check that a <see cref="MultiClassClassifierMetrics"/> object is valid.
216+
/// </summary>
217+
/// <param name="metrics">The metrics object.</param>
218+
public static void AssertMetrics(MultiClassClassifierMetrics metrics)
219+
{
220+
Assert.InRange(metrics.AccuracyMacro, 0, 1);
221+
Assert.InRange(metrics.AccuracyMicro, 0, 1);
222+
Assert.True(metrics.LogLoss >= 0);
223+
Assert.InRange(metrics.TopKAccuracy, 0, 1);
224+
}
225+
226+
/// <summary>
227+
/// Check that a <see cref="RankerMetrics"/> object is valid.
228+
/// </summary>
229+
/// <param name="metrics">The metrics object.</param>
230+
public static void AssertMetrics(RankerMetrics metrics)
231+
{
232+
foreach (var dcg in metrics.Dcg)
233+
Assert.True(dcg >= 0);
234+
foreach (var ndcg in metrics.Ndcg)
235+
Assert.InRange(ndcg, 0, 100);
236+
}
237+
163238
/// <summary>
164239
/// Check that a <see cref="RegressionMetrics"/> object is valid.
165240
/// </summary>
166241
/// <param name="metrics">The metrics object.</param>
167242
public static void AssertMetrics(RegressionMetrics metrics)
168243
{
169-
// Perform sanity checks on the metrics.
170244
Assert.True(metrics.Rms >= 0);
171245
Assert.True(metrics.L1 >= 0);
172246
Assert.True(metrics.L2 >= 0);
@@ -179,7 +253,6 @@ public static void AssertMetrics(RegressionMetrics metrics)
179253
/// <param name="metric">The <see cref="MetricStatistics"/> object.</param>
180254
public static void AssertMetricStatistics(MetricStatistics metric)
181255
{
182-
// Perform sanity checks on the metrics.
183256
Assert.True(metric.StandardDeviation >= 0);
184257
Assert.True(metric.StandardError >= 0);
185258
}
@@ -190,7 +263,6 @@ public static void AssertMetricStatistics(MetricStatistics metric)
190263
/// <param name="metrics">The metrics object.</param>
191264
public static void AssertMetricsStatistics(RegressionMetricsStatistics metrics)
192265
{
193-
// The mean can be any float; the standard deviation and error must be >=0.
194266
AssertMetricStatistics(metrics.Rms);
195267
AssertMetricStatistics(metrics.L1);
196268
AssertMetricStatistics(metrics.L2);
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
6+
using System;
7+
using Microsoft.Data.DataView;
8+
using Microsoft.ML.Data;
9+
10+
namespace Microsoft.ML.Functional.Tests.Datasets
11+
{
12+
/// <summary>
13+
/// A class for the Iris test dataset.
14+
/// </summary>
15+
internal sealed class Iris
16+
{
17+
[LoadColumn(0)]
18+
public float Label { get; set; }
19+
20+
[LoadColumn(1)]
21+
public float SepalLength { get; set; }
22+
23+
[LoadColumn(2)]
24+
public float SepalWidth { get; set; }
25+
26+
[LoadColumn(4)]
27+
public float PetalLength { get; set; }
28+
29+
[LoadColumn(5)]
30+
public float PetalWidth { get; set; }
31+
32+
/// <summary>
33+
/// The list of columns commonly used as features.
34+
/// </summary>
35+
public static readonly string[] Features = new string[] { "SepalLength", "SepalWidth", "PetalLength", "PetalWidth" };
36+
37+
public static IDataView LoadAsRankingProblem(MLContext mlContext, string filePath, bool hasHeader, char separatorChar, int seed = 1)
38+
{
39+
// Load the Iris data.
40+
var data = mlContext.Data.ReadFromTextFile<Iris>(filePath, hasHeader: hasHeader, separatorChar: separatorChar);
41+
42+
// Create a function that generates a random groupId.
43+
var rng = new Random(seed);
44+
Action<Iris, IrisWithGroup> generateGroupId = (input, output) =>
45+
{
46+
output.Label = input.Label;
47+
// The standard set used in tests has 150 rows
48+
output.GroupId = rng.Next(0, 30);
49+
output.PetalLength = input.PetalLength;
50+
output.PetalWidth = input.PetalWidth;
51+
output.SepalLength = input.SepalLength;
52+
output.SepalWidth = input.SepalWidth;
53+
};
54+
55+
// Describe a pipeline that generates a groupId and converts it to a key.
56+
var pipeline = mlContext.Transforms.CustomMapping(generateGroupId, null)
57+
.Append(mlContext.Transforms.Conversion.MapValueToKey("GroupId"));
58+
59+
// Transform the data
60+
var transformedData = pipeline.Fit(data).Transform(data);
61+
62+
return transformedData;
63+
}
64+
}
65+
66+
/// <summary>
67+
/// A class for the Iris dataset with a GroupId column.
68+
/// </summary>
69+
internal sealed class IrisWithGroup
70+
{
71+
public float Label { get; set; }
72+
public int GroupId { get; set; }
73+
public float SepalLength { get; set; }
74+
public float SepalWidth { get; set; }
75+
public float PetalLength { get; set; }
76+
public float PetalWidth { get; set; }
77+
}
78+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.Functional.Tests.Datasets
8+
{
9+
internal sealed class MnistOneClass
10+
{
11+
private const int _featureLength = 783;
12+
13+
public float Label { get; set; }
14+
15+
public float[] Features { get; set; }
16+
17+
public static TextLoader GetTextLoader(MLContext mlContext, bool hasHeader, char separatorChar)
18+
{
19+
return mlContext.Data.CreateTextLoader(
20+
new[] {
21+
new TextLoader.Column("Label", DataKind.R4, 0),
22+
new TextLoader.Column("Features", DataKind.R4, 1, 1 + _featureLength)
23+
},
24+
separatorChar: separatorChar,
25+
hasHeader: hasHeader,
26+
allowSparse: true);
27+
}
28+
}
29+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Data;
6+
7+
namespace Microsoft.ML.Functional.Tests.Datasets
8+
{
9+
/// <summary>
10+
/// A class for reading in the Sentiment test dataset.
11+
/// </summary>
12+
internal sealed class TweetSentiment
13+
{
14+
[LoadColumn(0), ColumnName("Label")]
15+
public bool Sentiment { get; set; }
16+
17+
[LoadColumn(1)]
18+
public string SentimentText { get; set; }
19+
}
20+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
6+
using System;
7+
using Microsoft.Data.DataView;
8+
using Microsoft.ML.Data;
9+
10+
namespace Microsoft.ML.Functional.Tests.Datasets
11+
{
12+
/// <summary>
13+
/// A class describing the TrivialMatrixFactorization test dataset.
14+
/// </summary>
15+
internal sealed class TrivialMatrixFactorization
16+
{
17+
[LoadColumn(0)]
18+
public float Label { get; set; }
19+
20+
[LoadColumn(1)]
21+
public uint MatrixColumnIndex { get; set; }
22+
23+
[LoadColumn(2)]
24+
public uint MatrixRowIndex { get; set; }
25+
26+
public static IDataView LoadAndFeaturizeFromTextFile(MLContext mlContext, string filePath, bool hasHeader, char separatorChar)
27+
{
28+
// Load the data from a textfile.
29+
var data = mlContext.Data.ReadFromTextFile<TrivialMatrixFactorization>(filePath, hasHeader: hasHeader, separatorChar: separatorChar);
30+
31+
// Describe a pipeline to translate the uints to keys.
32+
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("MatrixColumnIndex")
33+
.Append(mlContext.Transforms.Conversion.MapValueToKey("MatrixRowIndex"));
34+
35+
// Transform the data.
36+
var transformedData = pipeline.Fit(data).Transform(data);
37+
38+
return transformedData;
39+
}
40+
}
41+
}

0 commit comments

Comments
 (0)