Skip to content

Commit 9c183fc

Browse files
authored
Fix Saving csv with VBufferDataFrameColumn (#6860)
1 parent e3ec250 commit 9c183fc

File tree

4 files changed

+106
-68
lines changed

4 files changed

+106
-68
lines changed

src/Microsoft.Data.Analysis/DataFrame.IO.cs

Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections;
67
using System.Collections.Generic;
78
using System.Data;
89
using System.Data.Common;
@@ -11,6 +12,7 @@
1112
using System.Linq;
1213
using System.Text;
1314
using System.Threading.Tasks;
15+
using Microsoft.ML.Data;
1416

1517
namespace Microsoft.Data.Analysis
1618
{
@@ -675,58 +677,7 @@ public static void SaveCsv(DataFrame dataFrame, Stream csvStream,
675677

676678
foreach (var row in dataFrame.Rows)
677679
{
678-
bool firstCell = true;
679-
foreach (var cell in row)
680-
{
681-
if (!firstCell)
682-
{
683-
record.Append(separator);
684-
}
685-
else
686-
{
687-
firstCell = false;
688-
}
689-
690-
Type t = cell?.GetType();
691-
692-
if (t == typeof(bool))
693-
{
694-
record.AppendFormat(cultureInfo, "{0}", cell);
695-
continue;
696-
}
697-
698-
if (t == typeof(float))
699-
{
700-
record.AppendFormat(cultureInfo, "{0:G9}", cell);
701-
continue;
702-
}
703-
704-
if (t == typeof(double))
705-
{
706-
record.AppendFormat(cultureInfo, "{0:G17}", cell);
707-
continue;
708-
}
709-
710-
if (t == typeof(decimal))
711-
{
712-
record.AppendFormat(cultureInfo, "{0:G31}", cell);
713-
continue;
714-
}
715-
716-
if (t == typeof(string))
717-
{
718-
string stringCell = (string)cell;
719-
if (NeedsQuotes(stringCell, separator))
720-
{
721-
record.Append('\"');
722-
record.Append(stringCell.Replace("\"", "\"\"")); // Quotations in CSV data must be escaped with another quotation
723-
record.Append('\"');
724-
continue;
725-
}
726-
}
727-
728-
record.Append(cell);
729-
}
680+
AppendValuesToRecord(record, row, separator, cultureInfo);
730681

731682
csvFile.WriteLine(record);
732683

@@ -736,6 +687,54 @@ public static void SaveCsv(DataFrame dataFrame, Stream csvStream,
736687
}
737688
}
738689

690+
/// <summary>
/// Appends every item of <paramref name="values"/> to <paramref name="record"/>, placing
/// <paramref name="separator"/> between consecutive items.
/// </summary>
/// <param name="record">The builder that receives the formatted text.</param>
/// <param name="values">The items to write; an item that is itself enumerable is written recursively.</param>
/// <param name="separator">The character written between items and used for the string quoting check.</param>
/// <param name="cultureInfo">The format provider applied to bool, float, double and decimal items.</param>
private static void AppendValuesToRecord(StringBuilder record, IEnumerable values, char separator, CultureInfo cultureInfo)
{
    bool needsSeparator = false;
    foreach (object item in values)
    {
        // The separator goes in front of every item except the very first one.
        if (needsSeparator)
        {
            record.Append(separator);
        }

        needsSeparator = true;

        if (item is bool)
        {
            record.AppendFormat(cultureInfo, "{0}", item);
        }
        else if (item is float)
        {
            record.AppendFormat(cultureInfo, "{0:G9}", item);
        }
        else if (item is double)
        {
            record.AppendFormat(cultureInfo, "{0:G17}", item);
        }
        else if (item is decimal)
        {
            record.AppendFormat(cultureInfo, "{0:G31}", item);
        }
        else if (item is string text)
        {
            // Strings must be tested before IEnumerable: a string is itself enumerable.
            if (NeedsQuotes(text, separator))
            {
                // Embedded quotation marks are escaped by doubling them.
                record.Append('\"');
                record.Append(text.Replace("\"", "\"\""));
                record.Append('\"');
            }
            else
            {
                record.Append(text);
            }
        }
        else if (item is IEnumerable nested)
        {
            // Nested collections are written as "(a b c)": parenthesized, space-separated.
            record.Append("(");
            AppendValuesToRecord(record, nested, ' ', cultureInfo);
            record.Append(")");
        }
        else
        {
            // NOTE(review): this fallback uses the item's default string conversion and
            // does not pass cultureInfo - confirm that is intended for the remaining types.
            record.Append(item);
        }
    }
}
737+
739738
private static void SaveHeader(StreamWriter csvFile, IReadOnlyList<string> columnNames, char separator)
740739
{
741740
bool firstColumn = true;

src/Microsoft.Data.Analysis/PrimitiveDataFrameColumn.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ protected virtual PrimitiveDataFrameColumn<T> CreateNewColumn(string name, long
231231
return new PrimitiveDataFrameColumn<T>(name, length);
232232
}
233233

234-
internal T? GetTypedValue(long rowIndex) => _columnContainer[rowIndex];
234+
/// <summary>
/// Reads the element stored at <paramref name="rowIndex"/> from the column container.
/// </summary>
/// <param name="rowIndex">The index of the row to read.</param>
/// <returns>The value the container holds for that row.</returns>
protected T? GetTypedValue(long rowIndex)
{
    return _columnContainer[rowIndex];
}
235235

236236
/// <summary>
/// Returns the result of <c>GetTypedValue</c> for <paramref name="rowIndex"/> as an <see cref="object"/>.
/// </summary>
/// <param name="rowIndex">The index of the row to read.</param>
protected override object GetValue(long rowIndex)
{
    return GetTypedValue(rowIndex);
}
237237

src/Microsoft.ML.DataView/VBuffer.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections;
67
using System.Collections.Generic;
78
using System.Diagnostics;
89
using Microsoft.ML.Internal.DataView;
@@ -27,7 +28,7 @@ namespace Microsoft.ML.Data
2728
/// a value is sufficient to make a completely independent copy of it. So, for example, this means that a buffer of
2829
/// buffers is not possible. But, things like <see cref="int"/>, <see cref="float"/>, and <see
2930
/// cref="ReadOnlyMemory{Char}"/>, are totally fine.</typeparam>
30-
public readonly struct VBuffer<T>
31+
public readonly struct VBuffer<T> : IEnumerable
3132
{
3233
/// <summary>
3334
/// The internal re-usable array of values.
@@ -403,6 +404,14 @@ public T GetItemOrDefault(int index)
403404
/// <summary>
/// Describes the buffer as dense or sparse, together with its size and, for a sparse
/// buffer, the number of explicit values.
/// </summary>
public override string ToString()
{
    if (IsDense)
    {
        return $"Dense vector of size {Length}";
    }

    return $"Sparse vector of size {Length}, {_count} explicit values";
}
405406

407+
/// <summary>
/// Returns an enumerator that iterates through the logical values of the VBuffer:
/// one item for every index from 0 up to, but not including, <see cref="Length"/>.
/// </summary>
/// <remarks>
/// Each item is obtained through <see cref="GetItemOrDefault(int)"/> instead of walking
/// the backing array. The backing array is documented as a re-usable array, and a sparse
/// buffer reports fewer explicit values than its size, so enumerating the array itself
/// would not produce the vector's logical contents.
/// NOTE(review): assumes <see cref="GetItemOrDefault(int)"/> yields the default value for
/// slots that are not explicitly stored - confirm against its implementation.
/// </remarks>
/// <returns>An enumerator over the <see cref="Length"/> logical values, in index order.</returns>
public IEnumerator GetEnumerator()
{
    // An iterator in a struct works on a copy of the struct, which is safe here because
    // the type is a readonly struct.
    for (int index = 0; index < Length; index++)
    {
        yield return GetItemOrDefault(index);
    }
}
414+
406415
internal VBufferEditor<T> GetEditor()
407416
{
408417
return GetEditor(Length, _count);

test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
using Xunit;
1616
using Microsoft.ML.TestFramework.Attributes;
1717
using System.Threading;
18+
using Microsoft.ML.Data;
1819

1920
namespace Microsoft.Data.Analysis.Tests
2021
{
@@ -273,7 +274,7 @@ void ReducedRowsTest(DataFrame reducedRows)
273274
[Theory]
274275
[InlineData(false)]
275276
[InlineData(true)]
276-
public void TestReadCsvNoHeader(bool useQuotes)
277+
public void TestLoadCsvNoHeader(bool useQuotes)
277278
{
278279
string CMT = useQuotes ? @"""C,MT""" : "CMT";
279280
string verifyCMT = useQuotes ? "C,MT" : "CMT";
@@ -349,7 +350,7 @@ void VerifyDataFrameWithNamedColumnsAndDataTypes(DataFrame df, bool verifyColumn
349350
[InlineData(false, 0)]
350351
[InlineData(true, 10)]
351352
[InlineData(false, 10)]
352-
public void TestReadCsvWithTypesAndGuessRows(bool header, int guessRows)
353+
public void TestLoadCsvWithTypesAndGuessRows(bool header, int guessRows)
353354
{
354355
/* Tests this matrix
355356
*
@@ -472,7 +473,7 @@ void Verify(DataFrame df)
472473
}
473474

474475
[Fact]
475-
public void TestReadCsvWithTypesDateTime()
476+
public void TestLoadCsvWithTypesDateTime()
476477
{
477478
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,date
478479
CMT,1,1,1271,3.8,CRD,17.5,1-june-2020
@@ -549,7 +550,7 @@ void Verify(DataFrame df, bool verifyDataTypes)
549550
}
550551

551552
[Fact]
552-
public void TestReadCsvWithPipeSeparator()
553+
public void TestLoadCsvWithPipeSeparator()
553554
{
554555
string data = @"vendor_id|rate_code|passenger_count|trip_time_in_secs|trip_distance|payment_type|fare_amount
555556
CMT|1|1|1271|3.8|CRD|17.5
@@ -588,7 +589,7 @@ void Verify(DataFrame df)
588589
}
589590

590591
[Fact]
591-
public void TestReadCsvWithSemicolonSeparator()
592+
public void TestLoadCsvWithSemicolonSeparator()
592593
{
593594
string data = @"vendor_id;rate_code;passenger_count;trip_time_in_secs;trip_distance;payment_type;fare_amount
594595
CMT;1;1;1271;3.8;CRD;17.5
@@ -627,7 +628,7 @@ void Verify(DataFrame df)
627628
}
628629

629630
[Fact]
630-
public void TestReadCsvWithExtraColumnInHeader()
631+
public void TestLoadCsvWithExtraColumnInHeader()
631632
{
632633
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount,extra
633634
CMT,1,1,1271,3.8,CRD,17.5
@@ -656,7 +657,7 @@ void Verify(DataFrame df)
656657
}
657658

658659
[Fact]
659-
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
660+
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
660661
{
661662
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
662663
CMT,1,1,1271,3.8,CRD,17.5,0
@@ -671,7 +672,7 @@ public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithoutGivenColumn()
671672
}
672673

673674
[Fact]
674-
public void TestReadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
675+
public void TestLoadCsvWithMultipleEmptyColumnNameInHeaderWithGivenColumn()
675676
{
676677
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,,,,
677678
CMT,1,1,1271,3.8,CRD,17.5,0
@@ -713,7 +714,7 @@ public void TestLoadCsvWithAddIndexColumn()
713714
}
714715

715716
[Fact]
716-
public void TestReadCsvWithExtraColumnInRow()
717+
public void TestLoadCsvWithExtraColumnInRow()
717718
{
718719
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
719720
CMT,1,1,1271,3.8,CRD,17.5,0
@@ -726,7 +727,7 @@ public void TestReadCsvWithExtraColumnInRow()
726727
}
727728

728729
[Fact]
729-
public void TestReadCsvWithLessColumnsInRow()
730+
public void TestLoadCsvWithLessColumnsInRow()
730731
{
731732
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,fare_amount
732733
CMT,1,1,1271,3.8,CRD
@@ -755,7 +756,7 @@ void Verify(DataFrame df)
755756
}
756757

757758
[Fact]
758-
public void TestReadCsvWithAllNulls()
759+
public void TestLoadCsvWithAllNulls()
759760
{
760761
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
761762
null,null,null,null
@@ -798,7 +799,7 @@ void Verify(DataFrame df)
798799
}
799800

800801
[Fact]
801-
public void TestReadCsvWithNullsAndDataTypes()
802+
public void TestLoadCsvWithNullsAndDataTypes()
802803
{
803804
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
804805
null,1,1,1271
@@ -860,7 +861,7 @@ void Verify(DataFrame df)
860861
}
861862

862863
[Fact]
863-
public void TestReadCsvWithNulls()
864+
public void TestLoadCsvWithNulls()
864865
{
865866
string data = @"vendor_id,rate_code,passenger_count,trip_time_in_secs
866867
null,1,1,1271
@@ -922,7 +923,36 @@ void Verify(DataFrame df)
922923
}
923924

924925
[Fact]
925-
public void TestWriteCsvWithHeader()
926+
public void TestSaveCsvVBufferColumn()
927+
{
928+
var vBuffers = new[]
929+
{
930+
new VBuffer<int> (3, new int[] { 1, 2, 3 }),
931+
new VBuffer<int> (3, new int[] { 2, 3, 4 }),
932+
new VBuffer<int> (3, new int[] { 3, 4, 5 }),
933+
};
934+
935+
var vBufferColumn = new VBufferDataFrameColumn<int>("VBuffer", vBuffers);
936+
DataFrame dataFrame = new DataFrame(vBufferColumn);
937+
938+
using MemoryStream csvStream = new MemoryStream();
939+
940+
DataFrame.SaveCsv(dataFrame, csvStream);
941+
942+
csvStream.Seek(0, SeekOrigin.Begin);
943+
DataFrame readIn = DataFrame.LoadCsv(csvStream);
944+
945+
Assert.Equal(dataFrame.Rows.Count, readIn.Rows.Count);
946+
Assert.Equal(dataFrame.Columns.Count, readIn.Columns.Count);
947+
948+
Assert.Equal(typeof(string), readIn.Columns[0].DataType);
949+
Assert.Equal("(1 2 3)", readIn[0, 0]);
950+
Assert.Equal("(2 3 4)", readIn[1, 0]);
951+
Assert.Equal("(3 4 5)", readIn[2, 0]);
952+
}
953+
954+
[Fact]
955+
public void TestSaveCsvWithHeader()
926956
{
927957
using MemoryStream csvStream = new MemoryStream();
928958
DataFrame dataFrame = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, true);

0 commit comments

Comments
 (0)