Skip to content

Commit 1658ce2

Browse files
committed
Merge branch 'frank-dong-ms-retry-fail-tests'
2 parents b045e4d + 0d9a3c2 commit 1658ce2

File tree

13 files changed

+303
-16
lines changed

13 files changed

+303
-16
lines changed

test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
using Microsoft.ML.Runtime;
2020
using Microsoft.ML.TestFramework.Attributes;
2121
using Microsoft.ML.TestFrameworkCommon;
22+
using Microsoft.ML.TestFrameworkCommon.Attributes;
2223
using Microsoft.ML.Trainers;
2324
using Microsoft.ML.Trainers.Ensemble;
2425
using Microsoft.ML.Trainers.FastTree;
@@ -2758,7 +2759,7 @@ public void EntryPointSDCAMulticlass()
27582759
TestEntryPointRoutine("iris.txt", "Trainers.StochasticDualCoordinateAscentClassifier");
27592760
}
27602761

2761-
[Fact()]
2762+
[RetryFact]
27622763
public void EntryPointSDCARegression()
27632764
{
27642765
TestEntryPointRoutine(TestDatasets.generatedRegressionDatasetmacro.trainFilename, "Trainers.StochasticDualCoordinateAscentRegressor", loader: TestDatasets.generatedRegressionDatasetmacro.loaderSettings);
@@ -3845,7 +3846,7 @@ public void EntryPointChainedTrainTestMacros()
38453846
validateAuc(metrics);
38463847
}
38473848

3848-
[Fact]
3849+
[RetryFact]
38493850
public void EntryPointChainedCrossValMacros()
38503851
{
38513852
string inputGraph = @"
@@ -6025,7 +6026,7 @@ public void TestCrossValidationMacroWithNonDefaultNames()
60256026
}
60266027
}
60276028

6028-
[Fact]
6029+
[RetryFact]
60296030
public void TestOvaMacro()
60306031
{
60316032
var dataPath = GetDataPath(@"iris.txt");
@@ -6189,7 +6190,7 @@ public void TestOvaMacro()
61896190
}
61906191
}
61916192

6192-
[Fact]
6193+
[RetryFact]
61936194
public void TestOvaMacroWithUncalibratedLearner()
61946195
{
61956196
var dataPath = GetDataPath(@"iris.txt");

test/Microsoft.ML.Predictor.Tests/TestPredictors.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace Microsoft.ML.RunTests
2424
using Xunit.Abstractions;
2525
using TestLearners = TestLearnersBase;
2626
using Microsoft.ML.TestFrameworkCommon;
27+
using Microsoft.ML.TestFrameworkCommon.Attributes;
2728

2829
/// <summary>
2930
/// Tests using maml commands (IDV) functionality.
@@ -181,7 +182,7 @@ public void MulticlassSdcaTest()
181182
/// <summary>
182183
/// Multiclass Logistic Regression test with a tree featurizer.
183184
/// </summary>
184-
[X64Fact("x86 output differs from Baseline")]
185+
[RetryX64Fact("x86 output differs from Baseline")]
185186
[TestCategory("Multiclass")]
186187
[TestCategory("Logistic Regression")]
187188
[TestCategory("FastTree")]
@@ -270,7 +271,7 @@ public void BinaryClassifierLogisticRegressionTest()
270271
Done();
271272
}
272273

273-
[X64Fact("x86 output differs from Baseline")]
274+
[RetryX64Fact("x86 output differs from Baseline")]
274275
[TestCategory("Binary")]
275276
public void BinaryClassifierSymSgdTest()
276277
{
@@ -321,7 +322,7 @@ public void BinaryClassifierLogisticRegressionNonNegativeTest()
321322
/// <summary>
322323
///A test for binary classifiers
323324
///</summary>
324-
[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")]
325+
[RetryLessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline")]
325326
[TestCategory("Binary")]
326327
public void BinaryClassifierLogisticRegressionBinNormTest()
327328
{
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using Xunit;
7+
using Xunit.Sdk;
8+
9+
namespace Microsoft.ML.TestFrameworkCommon.Attributes
10+
{
11+
/// <summary>
/// Marks a test method as an ML.NET fact that is automatically re-run when it fails.
/// Intended for flaky test cases; discovery is delegated to <c>RetryFactDiscoverer</c>.
/// </summary>
[XunitTestCaseDiscoverer("Microsoft.ML.TestFrameworkCommon.RetryFactDiscoverer", "Microsoft.ML.TestFrameworkCommon")]
public class RetryFactAttribute : FactAttribute
{
    /// <summary>
    /// Upper bound on how many times the test is executed before a failure is reported.
    /// When left unset (or set to a value below 1) the discoverer substitutes 2.
    /// </summary>
    public int MaxRetries { get; set; }
}
23+
24+
25+
/// <summary>
/// A retrying fact that only runs when the current app domain carries no
/// "FX_PRODUCT_VERSION" data; otherwise the test is handled by the
/// <see cref="EnvironmentSpecificFactAttribute"/> base class using the supplied skip message.
/// NOTE(review): the type name suggests this means "below .NET Core 3.0, or not .NET Core" - confirm.
/// </summary>
[XunitTestCaseDiscoverer("Microsoft.ML.TestFrameworkCommon.RetryFactDiscoverer", "Microsoft.ML.TestFrameworkCommon")]
public class RetryLessThanNetCore30OrNotNetCoreFactAttribute : EnvironmentSpecificFactAttribute
{
    /// <summary>
    /// Creates the attribute.
    /// </summary>
    /// <param name="skipMessage">Message handed to the base attribute for unsupported environments.</param>
    public RetryLessThanNetCore30OrNotNetCoreFactAttribute(string skipMessage)
        : base(skipMessage)
    {
    }

    /// <summary>
    /// Upper bound on how many times the test is executed before a failure is reported.
    /// When left unset (or set to a value below 1) the discoverer substitutes 2.
    /// </summary>
    public int MaxRetries { get; set; }

    /// <inheritdoc />
    protected override bool IsEnvironmentSupported()
    {
        // Supported exactly when the app domain has no "FX_PRODUCT_VERSION" entry.
        object productVersion = AppDomain.CurrentDomain.GetData("FX_PRODUCT_VERSION");
        return productVersion == null;
    }
}
46+
47+
/// <summary>
/// A retrying fact for tests that require a 64-bit process. Failed runs are
/// repeated through <c>RetryFactDiscoverer</c>.
/// </summary>
[XunitTestCaseDiscoverer("Microsoft.ML.TestFrameworkCommon.RetryFactDiscoverer", "Microsoft.ML.TestFrameworkCommon")]
public sealed class RetryX64FactAttribute : EnvironmentSpecificFactAttribute
{
    /// <summary>
    /// Creates the attribute.
    /// </summary>
    /// <param name="skipMessage">Message handed to the base attribute for unsupported environments.</param>
    public RetryX64FactAttribute(string skipMessage)
        : base(skipMessage)
    {
    }

    /// <summary>
    /// Upper bound on how many times the test is executed before a failure is reported.
    /// When left unset (or set to a value below 1) the discoverer substitutes 2.
    /// </summary>
    public int MaxRetries { get; set; }

    /// <inheritdoc />
    protected override bool IsEnvironmentSupported() => Environment.Is64BitProcess;
}
65+
66+
/// <summary>
/// A retrying fact for tests that require TensorFlow, which this attribute treats as
/// available only in a 64-bit process. Failed runs are repeated through
/// <c>RetryFactDiscoverer</c>.
/// </summary>
[XunitTestCaseDiscoverer("Microsoft.ML.TestFrameworkCommon.RetryFactDiscoverer", "Microsoft.ML.TestFrameworkCommon")]
public sealed class RetryTensorFlowFactAttribute : EnvironmentSpecificFactAttribute
{
    /// <summary>
    /// Creates the attribute with a fixed skip message.
    /// </summary>
    public RetryTensorFlowFactAttribute()
        : base("TensorFlow is 64-bit only")
    {
    }

    /// <summary>
    /// Upper bound on how many times the test is executed before a failure is reported.
    /// When left unset (or set to a value below 1) the discoverer substitutes 2.
    /// </summary>
    public int MaxRetries { get; set; }

    /// <inheritdoc />
    protected override bool IsEnvironmentSupported() => Environment.Is64BitProcess;
}
84+
}
85+
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using System.Reflection;
7+
using Xunit.Abstractions;
8+
using Xunit.Sdk;
9+
10+
namespace Microsoft.ML.TestFrameworkCommon
11+
{
12+
/// <summary>
/// Used to capture messages to potentially be forwarded later. Messages are forwarded by
/// disposing of the message bus; a bus that is never disposed simply drops what it captured.
/// </summary>
public class DelayedMessageBus : IMessageBus
{
    private readonly IMessageBus innerBus;

    /// <summary>
    /// Messages captured so far, in arrival order. Also used as the lock object that
    /// guards every access made by this class.
    /// </summary>
    public readonly List<IMessageSinkMessage> messages = new List<IMessageSinkMessage>();

    /// <summary>
    /// Creates a bus that buffers messages destined for <paramref name="innerBus"/>.
    /// </summary>
    /// <param name="innerBus">The bus that receives the buffered messages on <see cref="Dispose"/>.</param>
    public DelayedMessageBus(IMessageBus innerBus)
    {
        this.innerBus = innerBus;
    }

    /// <summary>
    /// Buffers <paramref name="message"/> instead of delivering it.
    /// </summary>
    /// <param name="message">The message to capture.</param>
    /// <returns>Always <c>true</c>, i.e. execution is never cancelled from here.</returns>
    public bool QueueMessage(IMessageSinkMessage message)
    {
        lock (messages)
        {
            messages.Add(message);
        }

        // No way to ask the inner bus if they want to cancel without sending them the message, so
        // we just go ahead and continue always.
        return true;
    }

    /// <summary>
    /// Forwards every captured message to the inner bus, in arrival order.
    /// </summary>
    public void Dispose()
    {
        // Take the snapshot under the same lock QueueMessage uses so that a concurrent add
        // cannot invalidate the enumeration; forward outside the lock to keep it short.
        IMessageSinkMessage[] pending;
        lock (messages)
        {
            pending = messages.ToArray();
        }

        foreach (var message in pending)
        {
            innerBus.QueueMessage(message);
        }
    }
}
42+
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Collections.Generic;
6+
using Xunit.Abstractions;
7+
using Xunit.Sdk;
8+
9+
namespace Microsoft.ML.TestFrameworkCommon
10+
{
11+
/// <summary>
/// Test case discoverer that turns a fact attribute into a single <c>RetryTestCase</c>,
/// reading the attempt limit from the attribute's "MaxRetries" named argument.
/// </summary>
public class RetryFactDiscoverer : IXunitTestCaseDiscoverer
{
    private readonly IMessageSink diagnosticMessageSink;

    /// <summary>
    /// Creates the discoverer.
    /// </summary>
    /// <param name="diagnosticMessageSink">Sink passed on to every test case this discoverer creates.</param>
    public RetryFactDiscoverer(IMessageSink diagnosticMessageSink)
    {
        this.diagnosticMessageSink = diagnosticMessageSink;
    }

    /// <summary>
    /// Produces the retrying test case for <paramref name="testMethod"/>.
    /// </summary>
    /// <param name="discoveryOptions">Supplies the method display setting.</param>
    /// <param name="testMethod">The method being discovered.</param>
    /// <param name="factAttribute">The attribute that triggered discovery.</param>
    /// <returns>A sequence containing exactly one test case.</returns>
    public IEnumerable<IXunitTestCase> Discover(ITestFrameworkDiscoveryOptions discoveryOptions,
        ITestMethod testMethod, IAttributeInfo factAttribute)
    {
        // An unset or non-positive value falls back to 2 runs in total.
        var requested = factAttribute.GetNamedArgument<int>("MaxRetries");
        var attemptLimit = requested < 1 ? 2 : requested;

        yield return new RetryTestCase(diagnosticMessageSink, discoveryOptions.MethodDisplayOrDefault(), testMethod, attemptLimit);
    }
}
31+
}
32+
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.ComponentModel;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
using Xunit.Abstractions;
10+
using Xunit.Sdk;
11+
12+
namespace Microsoft.ML.TestFrameworkCommon
13+
{
14+
/// <summary>
/// A test case that re-runs its test method when it fails, up to a fixed number of runs.
/// Only the messages of the accepted (final) run are forwarded to the real message bus.
/// </summary>
[Serializable]
public class RetryTestCase : XunitTestCase
{
    // Maximum number of times the test is executed in total (not the number of extra runs).
    private int maxRetries;

    [EditorBrowsable(EditorBrowsableState.Never)]
    [Obsolete("Called by the de-serializer", true)]
    public RetryTestCase() { }

    /// <summary>
    /// Creates a retrying test case.
    /// </summary>
    /// <param name="diagnosticMessageSink">Sink for diagnostic messages.</param>
    /// <param name="testMethodDisplay">How the test method name is displayed.</param>
    /// <param name="testMethod">The method under test.</param>
    /// <param name="maxRetries">Maximum number of times the test is run before the failure is accepted.</param>
    public RetryTestCase(IMessageSink diagnosticMessageSink, TestMethodDisplay testMethodDisplay,
        ITestMethod testMethod, int maxRetries)
        : base(diagnosticMessageSink, testMethodDisplay, TestMethodDisplayOptions.None, testMethod, testMethodArguments: null)
    {
        this.maxRetries = maxRetries;
    }

    /// <summary>
    /// Runs the test, repeating it until the aggregator has an error (meaning that some
    /// internal error condition happened), or the test runs without failure, or the maximum
    /// number of runs has been reached. Each failed attempt is reported to the diagnostic
    /// sink and the console.
    /// </summary>
    /// <returns>The summary of the last run that was executed.</returns>
    public override async Task<RunSummary> RunAsync(IMessageSink diagnosticMessageSink,
        IMessageBus messageBus,
        object[] constructorArguments,
        ExceptionAggregator aggregator,
        CancellationTokenSource cancellationTokenSource)
    {
        var runCount = 0;

        while (true)
        {
            // This is really the only tricky bit: we need to capture and delay messages (since those will
            // contain run status) until we know we've decided to accept the final result.
            var delayedMessageBus = new DelayedMessageBus(messageBus);

            RunSummary summary = await base.RunAsync(diagnosticMessageSink, delayedMessageBus, constructorArguments, aggregator, cancellationTokenSource);
            if (aggregator.HasExceptions || summary.Failed > 0)
            {
                var details = ExtractTestFailDetailsFromMessageBus(delayedMessageBus);
                var errorMessage = $"Execution of '{DisplayName}' failed (attempt #{runCount + 1}) with details {details}.";

                diagnosticMessageSink.OnMessage(new DiagnosticMessage(errorMessage));
                Console.WriteLine(errorMessage);
            }

            // The aggregator is shared across attempts, so once it holds an exception another
            // run cannot succeed; accept the result immediately in that case.
            if (aggregator.HasExceptions || summary.Failed == 0 || ++runCount >= maxRetries)
            {
                delayedMessageBus.Dispose(); // Sends all the delayed messages
                return summary;
            }
        }
    }

    /// <summary>
    /// Builds a single-line description of every test failure captured by
    /// <paramref name="delayedMessageBus"/>: messages, exception types and stack traces.
    /// </summary>
    /// <returns>The concatenated details, or an empty string when no failure was captured.</returns>
    private static string ExtractTestFailDetailsFromMessageBus(DelayedMessageBus delayedMessageBus)
    {
        string details = "";

        foreach (var message in delayedMessageBus.messages)
        {
            // Failure messages expose their data through ITestFailed, so no type-name
            // comparison or reflection is needed to read them.
            var failure = message as ITestFailed;
            if (failure == null)
                continue;

            details += FormatSection("Messages: ", failure.Messages, ". ");
            details += FormatSection("ExceptionTypes: ", failure.ExceptionTypes, ". ");
            details += FormatSection("StackTraces: ", failure.StackTraces, ".");
        }

        return details;
    }

    // Renders one labelled, ';'-joined section, or nothing when there are no values.
    private static string FormatSection(string label, string[] values, string terminator)
    {
        if (values == null || values.Length == 0)
            return "";

        return label + string.Join(";", values) + terminator;
    }

    /// <summary>
    /// Stores the base test case data plus the run limit under the "MaxRetries" key.
    /// </summary>
    public override void Serialize(IXunitSerializationInfo data)
    {
        base.Serialize(data);

        data.AddValue("MaxRetries", maxRetries);
    }

    /// <summary>
    /// Restores the base test case data plus the run limit stored under the "MaxRetries" key.
    /// </summary>
    public override void Deserialize(IXunitSerializationInfo data)
    {
        base.Deserialize(data);

        maxRetries = data.GetValue<int>("MaxRetries");
    }
}
119+
}
120+

test/Microsoft.ML.Tests/FeatureContributionTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
using Microsoft.ML.Internal.Utilities;
1111
using Microsoft.ML.RunTests;
1212
using Microsoft.ML.TestFramework.Attributes;
13+
using Microsoft.ML.TestFrameworkCommon.Attributes;
1314
using Microsoft.ML.Trainers;
1415
using Xunit;
1516
using Xunit.Abstractions;
@@ -89,7 +90,7 @@ public void TestPoissonRegression()
8990
new LbfgsPoissonRegressionTrainer.Options { NumberOfThreads = 1 }), GetSparseDataset(numberOfInstances: 100), "PoissonRegression");
9091
}
9192

92-
[Fact]
93+
[RetryFact]
9394
public void TestGAMRegression()
9495
{
9596
TestFeatureContribution(ML.Regression.Trainers.Gam(), GetSparseDataset(numberOfInstances: 100), "GAMRegression");

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ public void KmeansOnnxConversionTest()
200200
Done();
201201
}
202202

203-
[Fact]
203+
[RetryFact]
204204
public void RegressionTrainersOnnxConversionTest()
205205
{
206206
var mlContext = new MLContext(seed: 1);

0 commit comments

Comments
 (0)