Skip to content

Commit a00a222

Browse files
authored
[Issue4484] Support specifying command timeout while using the database loader (#5288)
* initial checkin * update tests * fix comments * remove empty lines * resolve comments and attempt to fix test failure * fix comments * update * update * minor change
1 parent b628918 commit a00a222

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseLoaderCursor.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ public DbCommand Command
7575
{
7676
_command = Connection.CreateCommand();
7777
_command.CommandText = _source.CommandText;
78+
_command.CommandTimeout = _source.CommandTimeoutInSeconds;
7879
}
7980
return _command;
8081
}

src/Microsoft.ML.Data/DataLoadSave/Database/DatabaseSource.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,38 @@ namespace Microsoft.ML.Data
1010
/// <summary>Exposes the data required for opening a database for reading.</summary>
1111
public sealed class DatabaseSource
1212
{
13+
private const int DefaultCommandTimeoutInSeconds = 30;
14+
1315
/// <summary>Creates a new instance of the <see cref="DatabaseSource" /> class.</summary>
1416
/// <param name="providerFactory">The factory used to create the <see cref="DbConnection"/>..</param>
1517
/// <param name="connectionString">The string used to open the connection.</param>
1618
/// <param name="commandText">The text command to run against the data source.</param>
17-
public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText)
19+
public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText) :
20+
this(providerFactory, connectionString, commandText, DefaultCommandTimeoutInSeconds)
21+
{
22+
}
23+
24+
/// <summary>Creates a new instance of the <see cref="DatabaseSource" /> class.</summary>
25+
/// <param name="providerFactory">The factory used to create the <see cref="DbConnection"/>..</param>
26+
/// <param name="connectionString">The string used to open the connection.</param>
27+
/// <param name="commandText">The text command to run against the data source.</param>
28+
/// <param name="commandTimeoutInSeconds">The timeout(in seconds) for database command.</param>
29+
public DatabaseSource(DbProviderFactory providerFactory, string connectionString, string commandText, int commandTimeoutInSeconds)
1830
{
1931
Contracts.CheckValue(providerFactory, nameof(providerFactory));
2032
Contracts.CheckValue(connectionString, nameof(connectionString));
2133
Contracts.CheckValue(commandText, nameof(commandText));
34+
Contracts.CheckUserArg(commandTimeoutInSeconds >= 0, nameof(commandTimeoutInSeconds));
2235

2336
ProviderFactory = providerFactory;
2437
ConnectionString = connectionString;
2538
CommandText = commandText;
39+
CommandTimeoutInSeconds = commandTimeoutInSeconds;
2640
}
2741

42+
/// <summary>Gets the timeout for database command.</summary>
43+
public int CommandTimeoutInSeconds { get; }
44+
2845
/// <summary>Gets the text command to run against the data source.</summary>
2946
public string CommandText { get; }
3047

test/Microsoft.ML.Tests/DatabaseLoaderTests.cs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,22 @@ public DatabaseLoaderTests(ITestOutputHelper output)
2727

2828
[LightGBMFact]
2929
public void IrisLightGbm()
30+
{
31+
DatabaseSource dbs = GetIrisDatabaseSource("SELECT * FROM {0}");
32+
IrisLightGbmImpl(dbs);
33+
}
34+
35+
[LightGBMFact]
36+
public void IrisLightGbmWithTimeout()
37+
{
38+
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) //sqlite does not have built-in command for sleep
39+
return;
40+
DatabaseSource dbs = GetIrisDatabaseSource("WAITFOR DELAY '00:00:01'; SELECT * FROM {0}", 1);
41+
var ex = Assert.Throws<System.Reflection.TargetInvocationException>(() => IrisLightGbmImpl(dbs));
42+
Assert.Contains("Timeout", ex.InnerException.Message);
43+
}
44+
45+
private void IrisLightGbmImpl(DatabaseSource dbs)
3046
{
3147
var mlContext = new MLContext(seed: 1);
3248

@@ -41,7 +57,7 @@ public void IrisLightGbm()
4157

4258
var loader = mlContext.Data.CreateDatabaseLoader(loaderColumns);
4359

44-
var trainingData = loader.Load(GetIrisDatabaseSource("SELECT * FROM {0}"));
60+
var trainingData = loader.Load(dbs);
4561

4662
IEstimator<ITransformer> pipeline = mlContext.Transforms.Conversion.MapValueToKey("Label")
4763
.Append(mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth"))
@@ -211,18 +227,20 @@ public void IrisSdcaMaximumEntropy()
211227
/// SQLite database is used on Linux and MacOS builds.
212228
/// </summary>
213229
/// <returns>Return the appropiate Iris DatabaseSource according to build OS.</returns>
214-
private DatabaseSource GetIrisDatabaseSource(string command)
230+
private DatabaseSource GetIrisDatabaseSource(string command, int commandTimeoutInSeconds = 30)
215231
{
216232
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
217233
return new DatabaseSource(
218234
SqlClientFactory.Instance,
219235
GetMSSQLConnectionString(TestDatasets.irisDb.name),
220-
String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}"""));
236+
String.Format(command, $@"""{TestDatasets.irisDb.trainFilename}"""),
237+
commandTimeoutInSeconds);
221238
else
222239
return new DatabaseSource(
223240
SQLiteFactory.Instance,
224241
GetSQLiteConnectionString(TestDatasets.irisDbSQLite.name),
225-
String.Format(command, TestDatasets.irisDbSQLite.trainFilename));
242+
String.Format(command, TestDatasets.irisDbSQLite.trainFilename),
243+
commandTimeoutInSeconds);
226244
}
227245

228246
private string GetMSSQLConnectionString(string databaseName)

0 commit comments

Comments
 (0)