diff --git a/src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs b/src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs index 21d29b4c19..85282a7930 100644 --- a/src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs +++ b/src/Microsoft.Data.Analysis/DataFrameColumnCollection.cs @@ -58,13 +58,15 @@ public void Insert(int columnIndex, IEnumerable column, string columnName) protected override void InsertItem(int columnIndex, DataFrameColumn column) { column = column ?? throw new ArgumentNullException(nameof(column)); - if (RowCount > 0 && column.Length != RowCount) + + if (Count == 0) { - throw new ArgumentException(Strings.MismatchedColumnLengths, nameof(column)); + //the first column inserted into an empty collection defines RowCount + RowCount = column.Length; } - - if (Count >= 1 && RowCount == 0 && column.Length != RowCount) + else if (column.Length != RowCount) { + //every column in the dataframe must have the same length (number of rows) throw new ArgumentException(Strings.MismatchedColumnLengths, nameof(column)); } @@ -72,7 +74,9 @@ protected override void InsertItem(int columnIndex, DataFrameColumn column) { throw new ArgumentException(string.Format(Strings.DuplicateColumnName, column.Name), nameof(column)); } + RowCount = column.Length; + _columnNameToIndexDictionary[column.Name] = columnIndex; for (int i = columnIndex + 1; i < Count; i++) { @@ -108,6 +112,11 @@ protected override void RemoveItem(int columnIndex) _columnNameToIndexDictionary[this[i].Name]--; } base.RemoveItem(columnIndex); + + //Reset RowCount to 0 when removing this column left the collection empty + if (Count == 0) + RowCount = 0; + ColumnsChanged?.Invoke(); } @@ -138,6 +147,9 @@ protected override void ClearItems() base.ClearItems(); ColumnsChanged?.Invoke(); _columnNameToIndexDictionary.Clear(); + + //Reset RowCount to 0 because the collection no longer holds any columns + RowCount = 0; } /// diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 
1695ded747..4510530299 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -271,29 +271,44 @@ public void TestIndexer() [Fact] public void ColumnAndTableCreationTest() { - DataFrameColumn intColumn = new Int32DataFrameColumn("IntColumn", Enumerable.Range(0, 10).Select(x => x)); - DataFrameColumn floatColumn = new SingleDataFrameColumn("FloatColumn", Enumerable.Range(0, 10).Select(x => (float)x)); + const int rowCount = 10; + DataFrameColumn intColumn = new Int32DataFrameColumn("IntColumn", Enumerable.Range(0, rowCount).Select(x => x)); + DataFrameColumn floatColumn = new SingleDataFrameColumn("FloatColumn", Enumerable.Range(0, rowCount).Select(x => (float)x)); DataFrame dataFrame = new DataFrame(); dataFrame.Columns.Insert(0, intColumn); dataFrame.Columns.Insert(1, floatColumn); - Assert.Equal(10, dataFrame.Rows.Count); + Assert.Equal(rowCount, dataFrame.Rows.Count); Assert.Equal(2, dataFrame.Columns.Count); - Assert.Equal(10, dataFrame.Columns[0].Length); + Assert.Equal(2, dataFrame.Columns.LongCount()); + Assert.Equal(rowCount, dataFrame.Columns[0].Length); Assert.Equal("IntColumn", dataFrame.Columns[0].Name); - Assert.Equal(10, dataFrame.Columns[1].Length); + Assert.Equal(rowCount, dataFrame.Columns[1].Length); Assert.Equal("FloatColumn", dataFrame.Columns[1].Name); - DataFrameColumn bigColumn = new SingleDataFrameColumn("BigColumn", Enumerable.Range(0, 11).Select(x => (float)x)); - DataFrameColumn repeatedName = new SingleDataFrameColumn("FloatColumn", Enumerable.Range(0, 10).Select(x => (float)x)); + //add a column that is longer than the columns already in the dataframe + DataFrameColumn bigColumn = new SingleDataFrameColumn("BigColumn", Enumerable.Range(0, rowCount + 1).Select(x => (float)x)); Assert.Throws(() => dataFrame.Columns.Insert(2, bigColumn)); + Assert.Throws(() => dataFrame.Columns.Add(bigColumn)); + + //add a column that is shorter than the columns already in the dataframe + DataFrameColumn 
smallColumn = new SingleDataFrameColumn("SmallColumn", Enumerable.Range(0, rowCount - 1).Select(x => (float)x)); + Assert.Throws(() => dataFrame.Columns.Insert(2, smallColumn)); + Assert.Throws(() => dataFrame.Columns.Add(smallColumn)); + + //add a column whose name duplicates an existing column + DataFrameColumn repeatedName = new SingleDataFrameColumn("FloatColumn", Enumerable.Range(0, rowCount).Select(x => (float)x)); Assert.Throws(() => dataFrame.Columns.Insert(2, repeatedName)); - Assert.Throws(() => dataFrame.Columns.Insert(10, repeatedName)); + + //Insert a valid, uniquely named column at an out-of-range index so that only the index can cause the failure + DataFrameColumn extraColumn = new SingleDataFrameColumn("OtherFloatColumn", Enumerable.Range(0, rowCount).Select(x => (float)x)); + var columnCount = dataFrame.Columns.Count; + Assert.Throws(() => dataFrame.Columns.Insert(columnCount + 1, extraColumn)); Assert.Equal(2, dataFrame.Columns.Count); - DataFrameColumn intColumnCopy = new Int32DataFrameColumn("IntColumn", Enumerable.Range(0, 10).Select(x => x)); + DataFrameColumn intColumnCopy = new Int32DataFrameColumn("IntColumn", Enumerable.Range(0, rowCount).Select(x => x)); Assert.Throws(() => dataFrame.Columns[1] = intColumnCopy); - DataFrameColumn differentIntColumn = new Int32DataFrameColumn("IntColumn1", Enumerable.Range(0, 10).Select(x => x)); + DataFrameColumn differentIntColumn = new Int32DataFrameColumn("IntColumn1", Enumerable.Range(0, rowCount).Select(x => x)); dataFrame.Columns[1] = differentIntColumn; Assert.True(object.ReferenceEquals(differentIntColumn, dataFrame.Columns[1])); @@ -309,18 +324,68 @@ public void ColumnAndTableCreationTest() } [Fact] - public void InsertAndRemoveColumnTests() + public void InsertAndRemoveColumnToTheEndOfNotEmptyDataFrameTests() { DataFrame dataFrame = MakeDataFrameWithAllMutableColumnTypes(10); - DataFrameColumn intColumn = new Int32DataFrameColumn("IntColumn", Enumerable.Range(0, 10).Select(x => x)); - DataFrameColumn charColumn = dataFrame.Columns["Char"]; - int insertedIndex = dataFrame.Columns.Count; - 
dataFrame.Columns.Insert(dataFrame.Columns.Count, intColumn); + DataFrameColumn intColumn = new Int32DataFrameColumn("NewIntColumn", Enumerable.Range(0, 10).Select(x => x)); + + int columnCount = dataFrame.Columns.Count; + DataFrameColumn originalLastColumn = dataFrame.Columns[columnCount - 1]; + + //Insert the new column at the end + dataFrame.Columns.Insert(columnCount, intColumn); + Assert.Equal(columnCount + 1, dataFrame.Columns.Count); + + //Remove the first column dataFrame.Columns.RemoveAt(0); - DataFrameColumn intColumn_1 = dataFrame.Columns["IntColumn"]; - DataFrameColumn charColumn_1 = dataFrame.Columns["Char"]; + Assert.Equal(columnCount, dataFrame.Columns.Count); + + //Check that the inserted int column is still present and is the same instance + DataFrameColumn intColumn_1 = dataFrame.Columns["NewIntColumn"]; Assert.True(ReferenceEquals(intColumn, intColumn_1)); - Assert.True(ReferenceEquals(charColumn, charColumn_1)); + + //Check that the last column of the original dataframe was not removed + DataFrameColumn lastColumn_1 = dataFrame.Columns[originalLastColumn.Name]; + Assert.True(ReferenceEquals(originalLastColumn, lastColumn_1)); + + //Check that the new column is the last one + int newIndex = dataFrame.Columns.IndexOf("NewIntColumn"); + Assert.Equal(columnCount - 1, newIndex); + + //Check that the original last column moved down by one index after the removal + int newIndexForOriginalLastColumn = dataFrame.Columns.IndexOf(originalLastColumn.Name); + Assert.Equal(columnCount - 2, newIndexForOriginalLastColumn); + } + + [Fact] + public void AddAndRemoveColumnToTheEmptyDataFrameTests() + { + DataFrame dataFrame = new DataFrame(); + DataFrameColumn intColumn = new Int32DataFrameColumn("NewIntColumn", Enumerable.Range(0, 10).Select(x => x)); + + dataFrame.Columns.Add(intColumn); + Assert.Single(dataFrame.Columns); + Assert.Equal(10, dataFrame.Rows.Count); + + dataFrame.Columns.Remove(intColumn); + Assert.Empty(dataFrame.Columns); + Assert.Equal(0, dataFrame.Rows.Count); + } + + [Fact] + public void ClearColumnsTests() + { + //Arrange + 
DataFrame dataFrame = MakeDataFrameWithAllMutableColumnTypes(10); + + //Act + dataFrame.Columns.Clear(); + + //Assert: no columns and no rows remain after clearing + Assert.Empty(dataFrame.Columns); + + Assert.Equal(0, dataFrame.Rows.Count); + Assert.Equal(0, dataFrame.Columns.LongCount()); } [Fact]