diff --git a/eng/Versions.props b/eng/Versions.props
index 052539cb30..0c6252b6f5 100644
--- a/eng/Versions.props
+++ b/eng/Versions.props
@@ -87,7 +87,7 @@
0.0.6-test
0.0.7-test
4.6.1
- 1.0.112.2
+ 1.0.113
1.2.7
2.4.2
diff --git a/eng/helix.proj b/eng/helix.proj
index b7d3b60086..ef55768a73 100644
--- a/eng/helix.proj
+++ b/eng/helix.proj
@@ -96,7 +96,7 @@
- $(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $(whoami) $HELIX_WORKITEM_ROOT
+ $(HelixPreCommands);export ML_TEST_DATADIR=$HELIX_CORRELATION_PAYLOAD;export MICROSOFTML_RESOURCE_PATH=$HELIX_WORKITEM_ROOT;sudo chmod -R 777 $HELIX_WORKITEM_ROOT;sudo chown -R $USER $HELIX_WORKITEM_ROOT
$(HelixPreCommands);set ML_TEST_DATADIR=%HELIX_CORRELATION_PAYLOAD%;set MICROSOFTML_RESOURCE_PATH=%HELIX_WORKITEM_ROOT%
$(HelixPreCommands);install_name_tool -change "/usr/local/opt/libomp/lib/libomp.dylib" "@loader_path/libomp.dylib" libSymSgdNative.dylib
diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs
index 04cacc99a8..e525624998 100644
--- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs
+++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs
@@ -4,9 +4,12 @@
using System;
using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Text;
+using System.Threading.Tasks;
namespace Microsoft.Data.Analysis
{
@@ -109,12 +112,158 @@ public static DataFrame LoadCsv(string filename,
}
}
+ public static DataFrame LoadFrom(IEnumerable> vals, IList<(string, Type)> columnInfos)
+ {
+ var columnsCount = columnInfos.Count;
+ var columns = new List(columnsCount);
+
+ foreach (var (name, type) in columnInfos)
+ {
+ var column = CreateColumn(type, name);
+ columns.Add(column);
+ }
+
+ var res = new DataFrame(columns);
+
+ foreach (var items in vals)
+ {
+ for (var c = 0; c < items.Count; c++)
+ {
+ items[c] = items[c];
+ }
+ res.Append(items, inPlace: true);
+ }
+
+ return res;
+ }
+
+ public void SaveTo(DataTable table)
+ {
+ var columnsCount = Columns.Count;
+
+ if (table.Columns.Count == 0)
+ {
+ foreach (var column in Columns)
+ {
+ table.Columns.Add(column.Name, column.DataType);
+ }
+ }
+ else
+ {
+ if (table.Columns.Count != columnsCount)
+ throw new ArgumentException();
+ for (var c = 0; c < columnsCount; c++)
+ {
+ if (table.Columns[c].DataType != Columns[c].DataType)
+ throw new ArgumentException();
+ }
+ }
+
+ var items = new object[columnsCount];
+ foreach (var row in Rows)
+ {
+ for (var c = 0; c < columnsCount; c++)
+ {
+ items[c] = row[c] ?? DBNull.Value;
+ }
+ table.Rows.Add(items);
+ }
+ }
+
+ public DataTable ToTable()
+ {
+ var res = new DataTable();
+ SaveTo(res);
+ return res;
+ }
+
+ public static DataFrame FromSchema(DbDataReader reader)
+ {
+ var columnsCount = reader.FieldCount;
+ var columns = new DataFrameColumn[columnsCount];
+
+ for (var c = 0; c < columnsCount; c++)
+ {
+ var type = reader.GetFieldType(c);
+ var name = reader.GetName(c);
+ var column = CreateColumn(type, name);
+ columns[c] = column;
+ }
+
+ var res = new DataFrame(columns);
+ return res;
+ }
+
+ public static async Task LoadFrom(DbDataReader reader)
+ {
+ var res = FromSchema(reader);
+ var columnsCount = reader.FieldCount;
+
+ var items = new object[columnsCount];
+ while (await reader.ReadAsync())
+ {
+ for (var c = 0; c < columnsCount; c++)
+ {
+ items[c] = reader.IsDBNull(c)
+ ? null
+ : reader[c];
+ }
+ res.Append(items, inPlace: true);
+ }
+
+ reader.Close();
+
+ return res;
+ }
+
+ public static async Task LoadFrom(DbDataAdapter adapter)
+ {
+ using var reader = await adapter.SelectCommand.ExecuteReaderAsync();
+ return await LoadFrom(reader);
+ }
+
+ public void SaveTo(DbDataAdapter dataAdapter, DbProviderFactory factory)
+ {
+ using var commandBuilder = factory.CreateCommandBuilder();
+ commandBuilder.DataAdapter = dataAdapter;
+ dataAdapter.InsertCommand = commandBuilder.GetInsertCommand();
+ dataAdapter.UpdateCommand = commandBuilder.GetUpdateCommand();
+ dataAdapter.DeleteCommand = commandBuilder.GetDeleteCommand();
+
+ using var table = ToTable();
+
+ var connection = dataAdapter.SelectCommand.Connection;
+ var needClose = connection.TryOpen();
+
+ try
+ {
+ using var transaction = connection.BeginTransaction();
+ try
+ {
+ dataAdapter.Update(table);
+ }
+ catch
+ {
+ transaction.Rollback();
+ transaction.Dispose();
+ throw;
+ }
+ transaction.Commit();
+ }
+ finally
+ {
+ if (needClose)
+ connection.Close();
+ }
+ }
+
///
/// return of if not null or empty, otherwise return "Column{i}" where i is .
///
/// column names.
/// column index.
///
+
private static string GetColumnName(string[] columnNames, int columnIndex)
{
var defaultColumnName = "Column" + columnIndex.ToString();
@@ -126,68 +275,68 @@ private static string GetColumnName(string[] columnNames, int columnIndex)
return defaultColumnName;
}
- private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
+ private static DataFrameColumn CreateColumn(Type kind, string columnName)
{
DataFrameColumn ret;
if (kind == typeof(bool))
{
- ret = new BooleanDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new BooleanDataFrameColumn(columnName);
}
else if (kind == typeof(int))
{
- ret = new Int32DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new Int32DataFrameColumn(columnName);
}
else if (kind == typeof(float))
{
- ret = new SingleDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new SingleDataFrameColumn(columnName);
}
else if (kind == typeof(string))
{
- ret = new StringDataFrameColumn(GetColumnName(columnNames, columnIndex), 0);
+ ret = new StringDataFrameColumn(columnName, 0);
}
else if (kind == typeof(long))
{
- ret = new Int64DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new Int64DataFrameColumn(columnName);
}
else if (kind == typeof(decimal))
{
- ret = new DecimalDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new DecimalDataFrameColumn(columnName);
}
else if (kind == typeof(byte))
{
- ret = new ByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new ByteDataFrameColumn(columnName);
}
else if (kind == typeof(char))
{
- ret = new CharDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new CharDataFrameColumn(columnName);
}
else if (kind == typeof(double))
{
- ret = new DoubleDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new DoubleDataFrameColumn(columnName);
}
else if (kind == typeof(sbyte))
{
- ret = new SByteDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new SByteDataFrameColumn(columnName);
}
else if (kind == typeof(short))
{
- ret = new Int16DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new Int16DataFrameColumn(columnName);
}
else if (kind == typeof(uint))
{
- ret = new UInt32DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new UInt32DataFrameColumn(columnName);
}
else if (kind == typeof(ulong))
{
- ret = new UInt64DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new UInt64DataFrameColumn(columnName);
}
else if (kind == typeof(ushort))
{
- ret = new UInt16DataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new UInt16DataFrameColumn(columnName);
}
else if (kind == typeof(DateTime))
{
- ret = new PrimitiveDataFrameColumn(GetColumnName(columnNames, columnIndex));
+ ret = new PrimitiveDataFrameColumn(columnName);
}
else
{
@@ -196,6 +345,11 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int
return ret;
}
+ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int columnIndex)
+ {
+ return CreateColumn(kind, GetColumnName(columnNames, columnIndex));
+ }
+
private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader,
char separator = ',', bool header = true,
string[] columnNames = null, Type[] dataTypes = null,
diff --git a/src/Microsoft.Data.Analysis/Extensions.cs b/src/Microsoft.Data.Analysis/Extensions.cs
new file mode 100644
index 0000000000..3e3d20b4a4
--- /dev/null
+++ b/src/Microsoft.Data.Analysis/Extensions.cs
@@ -0,0 +1,37 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
+using System.Text;
+
+namespace Microsoft.Data.Analysis
+{
+ public static class Extensions
+ {
+ public static DbDataAdapter CreateDataAdapter(this DbProviderFactory factory, DbConnection connection, string tableName)
+ {
+ var query = connection.CreateCommand();
+ query.CommandText = $"SELECT * FROM {tableName}";
+ var res = factory.CreateDataAdapter();
+ res.SelectCommand = query;
+ return res;
+ }
+
+ public static bool TryOpen(this DbConnection connection)
+ {
+ if (connection.State == ConnectionState.Closed)
+ {
+ connection.Open();
+ return true;
+ }
+ else
+ {
+ return false;
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
index 697e38f9e4..398e849907 100644
--- a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
+++ b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs
@@ -4,11 +4,16 @@
using System;
using System.Collections.Generic;
+using System.Data;
+using System.Data.Common;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text;
+using System.Data.SQLite;
+using System.Data.SQLite.EF6;
using Xunit;
+using Microsoft.ML.TestFramework.Attributes;
namespace Microsoft.Data.Analysis.Tests
{
@@ -1021,6 +1026,119 @@ public void TestMixedDataTypesInCsv()
}
}
+ [Fact]
+ public void TestLoadFromEnumerable()
+ {
+ var (columns, vals) = GetTestData();
+ var dataFrame = DataFrame.LoadFrom(vals, columns);
+ AssertEqual(dataFrame, columns, vals);
+ }
+
+ [Fact]
+ public void TestSaveToDataTable()
+ {
+ var (columns, vals) = GetTestData();
+ var dataFrame = DataFrame.LoadFrom(vals, columns);
+
+ using var table = dataFrame.ToTable();
+
+ var resColumns = table.Columns.Cast().Select(column => (column.ColumnName, column.DataType)).ToArray();
+ Assert.Equal(columns, resColumns);
+
+ var resVals = table.Rows.Cast().Select(row => row.ItemArray).ToArray();
+ Assert.Equal(vals, resVals);
+ }
+
+ [X86X64FactAttribute("The SQLite un-managed code, SQLite.interop, only supports x86/x64 architectures.")]
+ public async void TestSQLite()
+ {
+ var (columns, vals) = GetTestData();
+ var dataFrame = DataFrame.LoadFrom(vals, columns);
+
+ try
+ {
+ var (factory, connection) = InitSQLiteDb();
+ using (factory)
+ {
+ using (connection)
+ {
+ using var dataAdapter = factory.CreateDataAdapter(connection, TableName);
+ dataFrame.SaveTo(dataAdapter, factory);
+
+ var resDataFrame = await DataFrame.LoadFrom(dataAdapter);
+
+ AssertEqual(resDataFrame, columns, vals);
+ }
+ }
+ }
+ finally
+ {
+ CleanupSQLiteDb();
+ }
+ }
+
+ static void AssertEqual(DataFrame dataFrame, (string name, Type type)[] columns, object[][] vals)
+ {
+ var resColumns = dataFrame.Columns.Select(column => (column.Name, column.DataType)).ToArray();
+ Assert.Equal(columns, resColumns);
+ var resVals = dataFrame.Rows.Select(row => row.ToArray()).ToArray();
+ Assert.Equal(vals, resVals);
+ }
+
+ static ((string name, Type type)[] columns, object[][] vals) GetTestData()
+ {
+ const int RowsCount = 10_000;
+
+ var columns = new[]
+ {
+ ("ID", typeof(long)),
+ ("Text", typeof(string))
+ };
+
+ var vals = new object[RowsCount][];
+ for (var i = 0L; i < RowsCount; i++)
+ {
+ var row = new object[columns.Length];
+ row[0] = i;
+ row[1] = $"test {i}";
+ vals[i] = row;
+ }
+
+ return (columns, vals);
+ }
+
+ static (SQLiteProviderFactory factory, DbConnection connection) InitSQLiteDb()
+ {
+ var connectionString = $"DataSource={SQLitePath};Version=3;New=True;Compress=True;";
+
+ SQLiteConnection.CreateFile(SQLitePath);
+ var factory = new SQLiteProviderFactory();
+
+ var connection = factory.CreateConnection();
+ connection.ConnectionString = connectionString;
+ connection.Open();
+
+ using var command = connection.CreateCommand();
+ command.CommandText = $"CREATE TABLE {TableName} (ID INTEGER NOT NULL PRIMARY KEY ASC, Text VARCHAR(25))";
+ command.ExecuteNonQuery();
+
+ return (factory, connection);
+ }
+
+ static void CleanupSQLiteDb()
+ {
+ if (File.Exists(SQLitePath))
+ File.Delete(SQLitePath);
+ }
+
+ static readonly string BasePath =
+ Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) + "/";
+
+ const string DbName = "TestDb";
+ const string TableName = "TestTable";
+
+ static readonly string SQLitePath = $@"{BasePath}/{DbName}.sqlite";
+
public readonly struct LoadCsvVerifyingHelper
{
private readonly int _columnCount;
diff --git a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj
index 1a570416e4..dafae0d942 100644
--- a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj
+++ b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj
@@ -8,6 +8,8 @@
+
+
@@ -44,4 +46,9 @@
+
+
+
+
+