From 375939dd0be1ed9a47c78e2586c8da565cf19868 Mon Sep 17 00:00:00 2001 From: Aleksei Smirnov Date: Mon, 24 Jul 2023 16:45:27 +0300 Subject: [PATCH] Fix DataFrame.LoadCsv can not load CSV with duplicate column names --- src/Microsoft.Data.Analysis/DataFrame.IO.cs | 29 ++++++++++++-- .../DataFrame.IOTests.cs | 40 +++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs index 44448da008..6f78a94ea2 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs @@ -8,6 +8,7 @@ using System.Data.Common; using System.Globalization; using System.IO; +using System.Linq; using System.Text; using System.Threading.Tasks; @@ -349,8 +350,8 @@ private static DataFrameColumn CreateColumn(Type kind, string[] columnNames, int private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringReader wrappedReader, char separator = ',', bool header = true, string[] columnNames = null, Type[] dataTypes = null, - long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false - ) + long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false, + bool renameDuplicatedColumns = false) { if (dataTypes == null && guessRows <= 0) { @@ -376,6 +377,25 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe // First pass: schema and number of rows. while ((fields = parser.ReadFields()) != null) { + if (renameDuplicatedColumns) + { + var names = new Dictionary(); + + for (int i = 0; i < fields.Length; i++) + { + if (names.TryGetValue(fields[i], out int index)) + { + var newName = String.Format("{0}.{1}", fields[i], index); + names[fields[i]] = ++index; + fields[i] = newName; + } + else + { + names.Add(fields[i], 1); + } + } + } + if ((numberOfRowsToRead == -1) || rowline < numberOfRowsToRead) { if (linesForGuessType.Count < guessRows || (header && rowline == 0)) @@ -525,12 +545,13 @@ public static DataFrame LoadCsvFromString(string csvString, /// number of rows used to guess types /// add one column with the row index /// The character encoding. Defaults to UTF8 if not specified + /// If set to true, columns with repeated names are auto-renamed. /// public static DataFrame LoadCsv(Stream csvStream, char separator = ',', bool header = true, string[] columnNames = null, Type[] dataTypes = null, long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false, - Encoding encoding = null) + Encoding encoding = null, bool renameDuplicatedColumns = false) { if (!csvStream.CanSeek) { @@ -543,7 +564,7 @@ public static DataFrame LoadCsv(Stream csvStream, } WrappedStreamReaderOrStringReader wrappedStreamReaderOrStringReader = new WrappedStreamReaderOrStringReader(csvStream, encoding ?? Encoding.UTF8); - return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn); + return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns); } /// diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 05565673b0..646a7bceef 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -154,6 +154,46 @@ void ReducedRowsTest(DataFrame reducedRows) ReducedRowsTest(csvDf); } + [Fact] + public void TestReadCsvWithHeaderAndDuplicatedColumns_WithoutRenaming() + { + + string data = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,fare_amount +CMT,1,1,1271,3.8,CRD,CRD,17.5 +CMT,1,1,474,1.5,CRD,CRD,8 +CMT,1,1,637,1.4,CRD,CRD,8.5 +CMT,1,1,181,0.6,CSH,CSH,4.5"; + + Assert.Throws(() => DataFrame.LoadCsv(GetStream(data))); + } + + [Fact] + public void TestReadCsvWithHeaderAndDuplicatedColumns_WithDuplicateColumnRenaming() + { + + string data = @$"vendor_id,rate_code,passenger_count,trip_time_in_secs,trip_distance,payment_type,payment_type,payment_type,fare_amount +CMT,1,1,1271,3.8,CRD,CRD_1,Test,17.5 +CMT,1,1,474,1.5,CRD,CRD,Test,8 +CMT,1,1,637,1.4,CRD,CRD,Test,8.5 +CMT,1,1,181,0.6,CSH,CSH,Test,4.5"; + + DataFrame df = DataFrame.LoadCsv(GetStream(data), renameDuplicatedColumns: true); + + Assert.Equal(4, df.Rows.Count); + Assert.Equal(9, df.Columns.Count); + Assert.Equal("CMT", df.Columns["vendor_id"][3]); + + Assert.Equal("payment_type", df.Columns[5].Name); + Assert.Equal("payment_type.1", df.Columns[6].Name); + Assert.Equal("payment_type.2", df.Columns[7].Name); + + Assert.Equal("CRD", df.Columns["payment_type"][0]); + Assert.Equal("CRD_1", df.Columns["payment_type.1"][0]); + Assert.Equal("Test", df.Columns["payment_type.2"][0]); + + VerifyColumnTypes(df); + } + [Fact] public void TestReadCsvSplitAcrossMultipleLines() {