From 2ef9147b4f65fdb506bf46333bc981014cc000e4 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Wed, 7 Sep 2022 16:33:11 -0500 Subject: [PATCH 01/13] vbuffer file --- .../VBufferDataFrameColumn.cs | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs new file mode 100644 index 0000000000..0d8544dc04 --- /dev/null +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections; +using System.Linq; +using System.Text; +using Microsoft.ML.Data; + +namespace Microsoft.Data.Analysis +{ + internal class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> + { + public VBufferDataFrameColumn(string name, long length, Type type) : base(name, length, type) + { + } + + public override long NullCount => throw new NotImplementedException(); + + public IEnumerator> GetEnumerator() + { + throw new NotImplementedException(); + } + + public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) + { + throw new NotImplementedException(); + } + + protected override IEnumerator GetEnumeratorCore() + { + throw new NotImplementedException(); + } + + protected override object GetValue(long rowIndex) + { + throw new NotImplementedException(); + } + + protected override IReadOnlyList GetValues(long startIndex, int length) + { + throw new NotImplementedException(); + } + + protected override void SetValue(long rowIndex, object value) + { + if (value == null || value is VBuffer) + { + int bufferIndex = GetBufferIndexContainingRowIndex(ref rowIndex); + 
var oldValue = this[rowIndex]; + _stringBuffers[bufferIndex][(int)rowIndex] = (string)value; + if (oldValue != (string)value) + { + if (value == null) + _nullCount++; + if (oldValue == null && _nullCount > 0) + _nullCount--; + } + } + else + { + throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); + } + } + } +} From 0e81765b95ebd7e408f4ea3d86855454698e8f24 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Mon, 19 Sep 2022 11:44:03 -0500 Subject: [PATCH 02/13] updates to vbuffer file --- .../VBufferDataFrameColumn.cs | 75 ++++++++++++++----- 1 file changed, 58 insertions(+), 17 deletions(-) diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index 0d8544dc04..f135688787 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -3,27 +3,62 @@ // See the LICENSE file in the project root for more information. 
using System; -using System.Collections.Generic; using System.Collections; -using System.Linq; -using System.Text; +using System.Collections.Generic; +using System.Diagnostics; +using Apache.Arrow; +using Apache.Arrow.Types; +using Microsoft.ML; using Microsoft.ML.Data; namespace Microsoft.Data.Analysis { - internal class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> + public class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> { + private readonly List>> _vBuffers = new List>>(); // To store more than intMax number of strings + public VBufferDataFrameColumn(string name, long length, Type type) : base(name, length, type) { + int numberOfBuffersRequired = Math.Max((int)(length / int.MaxValue), 1); + for (int i = 0; i < numberOfBuffersRequired; i++) + { + long bufferLen = length - _vBuffers.Count * int.MaxValue; + List> buffer = new List>((int)Math.Min(int.MaxValue, bufferLen)); + _vBuffers.Add(buffer); + for (int j = 0; j < bufferLen; j++) + { + buffer.Add(default); + } + } } - public override long NullCount => throw new NotImplementedException(); + public VBufferDataFrameColumn(string name, IEnumerable> values) : base(name, 0, typeof(VBuffer)) + { + values = values ?? 
throw new ArgumentNullException(nameof(values)); + if (_vBuffers.Count == 0) + { + _vBuffers.Add(new List>()); + } + foreach (var value in values) + { + Append(value); + } + } - public IEnumerator> GetEnumerator() + public void Append(VBuffer value) { - throw new NotImplementedException(); + List> lastBuffer = _vBuffers[_vBuffers.Count - 1]; + if (lastBuffer.Count == int.MaxValue) + { + lastBuffer = new List>(); + _vBuffers.Add(lastBuffer); + } + lastBuffer.Add(value); + Length++; } + public override long NullCount => throw new NotImplementedException(); + public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) { throw new NotImplementedException(); @@ -46,23 +81,29 @@ protected override IReadOnlyList GetValues(long startIndex, int length) protected override void SetValue(long rowIndex, object value) { - if (value == null || value is VBuffer) + if (value == null || value is VBuffer) { int bufferIndex = GetBufferIndexContainingRowIndex(ref rowIndex); - var oldValue = this[rowIndex]; - _stringBuffers[bufferIndex][(int)rowIndex] = (string)value; - if (oldValue != (string)value) - { - if (value == null) - _nullCount++; - if (oldValue == null && _nullCount > 0) - _nullCount--; - } + _vBuffers[bufferIndex][(int)rowIndex] = (VBuffer)value; } else { throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); } } + + private int GetBufferIndexContainingRowIndex(ref long rowIndex) + { + if (rowIndex > Length) + { + throw new ArgumentOutOfRangeException(Strings.ColumnIndexOutOfRange, nameof(rowIndex)); + } + return (int)(rowIndex / int.MaxValue); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + throw new NotImplementedException(); + } } } From 6322303fe33a8812d5a78c1f802ccd5fea9da42c Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Mon, 3 Oct 2022 12:05:41 -0400 Subject: [PATCH 03/13] vector --- .../IDataView.Extension.cs | 8 + .../VBufferDataFrameColumn.cs | 242 
+++++++++++++++--- src/Microsoft.ML.DataView/VectorType.cs | 13 + .../DataFrameIDataViewTests.cs | 58 +++++ .../DataFrameTests.cs | 11 + .../UnitTests/TestVBuffer.cs | 14 +- 6 files changed, 303 insertions(+), 43 deletions(-) diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 32b97d365a..c5a9743926 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -112,6 +112,14 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param { dataFrameColumns.Add(new StringDataFrameColumn(dataViewColumn.Name)); } + else if (type is VectorDataViewType vectoryType) + // type.ToString() == "Vector") //== VectorDataViewType.Instance) + { + var itemType = vectoryType.ItemType; + //type.ItemType && type.Size + var subType = dataViewColumn.Annotations; + dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); + } else { throw new NotSupportedException(String.Format(Microsoft.Data.Strings.NotSupportedColumnType, type.RawType.Name)); diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index f135688787..bcb3b62b64 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -6,6 +6,9 @@ using System.Collections; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; using Apache.Arrow; using Apache.Arrow.Types; using Microsoft.ML; @@ -13,23 +16,27 @@ namespace Microsoft.Data.Analysis { - public class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> + /// + /// An immutable column to hold Arrow style strings + /// + public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> { + private readonly IList>> _dataBuffers; + private readonly List>> _vBuffers = new List>>(); // 
To store more than intMax number of strings - public VBufferDataFrameColumn(string name, long length, Type type) : base(name, length, type) + /// + /// Constructs an empty with the given . + /// + /// The name of the column. + public VBufferDataFrameColumn(string name) : base(name, 0, typeof(VBuffer)) { - int numberOfBuffersRequired = Math.Max((int)(length / int.MaxValue), 1); - for (int i = 0; i < numberOfBuffersRequired; i++) - { - long bufferLen = length - _vBuffers.Count * int.MaxValue; - List> buffer = new List>((int)Math.Min(int.MaxValue, bufferLen)); - _vBuffers.Add(buffer); - for (int j = 0; j < bufferLen; j++) - { - buffer.Add(default); - } - } + _dataBuffers = new List>>(); + } + + public VBufferDataFrameColumn(string name, Type T) : base(name, 0, typeof(VBuffer)) + { + _dataBuffers = new List>>(); } public VBufferDataFrameColumn(string name, IEnumerable> values) : base(name, 0, typeof(VBuffer)) @@ -45,65 +52,228 @@ public VBufferDataFrameColumn(string name, IEnumerable> values) : bas } } - public void Append(VBuffer value) + /// + /// Constructs an with the given , and . The , and are the contents of the column in the Arrow format. + /// + /// The name of the column. + /// The Arrow formatted string values in this column. + /// The Arrow formatted offsets in this column. + /// The Arrow formatted null bits in this column. + /// The length of the column. + /// The number of values in this column. + public VBufferDataFrameColumn(string name, List> values, ReadOnlyMemory offsets, ReadOnlyMemory nullBits, int length, int nullCount) : base(name, length, typeof(string)) { - List> lastBuffer = _vBuffers[_vBuffers.Count - 1]; - if (lastBuffer.Count == int.MaxValue) + List> dataBuffer = new List>(values); + + _dataBuffers = new List>>(); + _dataBuffers.Add(dataBuffer); + + _nullCount = nullCount; + } + + private readonly long _nullCount; + + /// + public override long NullCount => _nullCount; + + /// + /// Indicates if the value at this is . 
+ /// + /// The index to look up. + /// A boolean value indicating the validity at this . + public bool IsValid(long index) => NullCount == 0; + + /// + /// Returns an enumeration of immutable buffers representing the underlying values in the Apache Arrow format + /// + /// values are encoded in the buffers returned by GetReadOnlyNullBitmapBuffers in the Apache Arrow format + /// The offsets buffers returned by GetReadOnlyOffsetBuffers can be used to delineate each value + /// An enumeration of whose elements are the raw data buffers for the UTF8 string values. + public IEnumerable>> GetReadOnlyDataBuffers() + { + for (int i = 0; i < _dataBuffers.Count; i++) { - lastBuffer = new List>(); - _vBuffers.Add(lastBuffer); + // todo - performance + List> buffer = _dataBuffers.ElementAt(i); + yield return buffer; } - lastBuffer.Add(value); + } + + private void Append(VBuffer value) + { Length++; + _dataBuffers.Add(new List>() { value }); } - public override long NullCount => throw new NotImplementedException(); + /// + protected override object GetValue(long rowIndex) => GetValueImplementation(rowIndex); - public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) + private List> GetValueImplementation(long rowIndex) { - throw new NotImplementedException(); + if (!IsValid(rowIndex)) + { + throw new ArgumentOutOfRangeException(nameof(rowIndex)); + } + return _dataBuffers.ElementAt((int)rowIndex); } - protected override IEnumerator GetEnumeratorCore() + /// + protected override IReadOnlyList GetValues(long startIndex, int length) { - throw new NotImplementedException(); + var ret = new List(); + while (ret.Count < length) + { + ret.Add(GetValueImplementation(startIndex++)); + } + return ret; } - protected override object GetValue(long rowIndex) + /// + protected override void SetValue(long rowIndex, object value) => throw new NotSupportedException(Strings.ImmutableColumn); + + + /// + /// Indexer to get values. 
This is an immutable column + /// + /// Zero based row index + /// The value stored at this + public new List> this[long rowIndex] { - throw new NotImplementedException(); + get => GetValueImplementation(rowIndex); + set => throw new NotSupportedException(Strings.ImmutableColumn); } - protected override IReadOnlyList GetValues(long startIndex, int length) + /// + /// Returns an enumerator that iterates through the string values in this column. + /// + public IEnumerator>> GetEnumerator() + { + for (long i = 0; i < Length; i++) + { + yield return this[i]; + } + } + + /// + protected override IEnumerator GetEnumeratorCore() => GetEnumerator(); + + /// + public override DataFrameColumn Sort(bool ascending = true) => throw new NotSupportedException(); + + /// + public override DataFrameColumn Clone(DataFrameColumn mapIndices = null, bool invertMapIndices = false, long numberOfNullsToAppend = 0) { throw new NotImplementedException(); } - protected override void SetValue(long rowIndex, object value) + /// + public VBufferDataFrameColumn FillNulls(VBuffer value, bool inPlace = false) { - if (value == null || value is VBuffer) + if (inPlace) { - int bufferIndex = GetBufferIndexContainingRowIndex(ref rowIndex); - _vBuffers[bufferIndex][(int)rowIndex] = (VBuffer)value; + // For now throw an exception if inPlace = true. 
+ throw new NotSupportedException(); + } + + VBufferDataFrameColumn ret = new VBufferDataFrameColumn(Name); + for (long i = 0; i < Length; i++) + { + ret.Append(value); + } + return ret; + } + + protected override DataFrameColumn FillNullsImplementation(object value, bool inPlace) + { + if (value is VBuffer valueBuffer) + { + return FillNulls(valueBuffer, inPlace); } else { - throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); + throw new ArgumentException(String.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); } } - private int GetBufferIndexContainingRowIndex(ref long rowIndex) + public override DataFrameColumn Clamp(U min, U max, bool inPlace = false) => throw new NotSupportedException(); + + public override DataFrameColumn Filter(U min, U max) => throw new NotSupportedException(); + + /// + protected internal override void AddDataViewColumn(DataViewSchema.Builder builder) + { + builder.AddColumn(Name, TextDataViewType.Instance); + } + + /// + protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) + { + return CreateValueGetterDelegate(cursor); + } + + + private ValueGetter>> CreateValueGetterDelegate(DataViewRowCursor cursor) => + (ref List> value) => value = this[cursor.Position]; + + /// + /// Returns a boolean column that is the result of an elementwise equality comparison of each value in the column with + /// + public PrimitiveDataFrameColumn ElementwiseEquals(string value) + { + throw new NotImplementedException(); + } + + /// + public override PrimitiveDataFrameColumn ElementwiseEquals(U value) { - if (rowIndex > Length) + if (value is DataFrameColumn column) { - throw new ArgumentOutOfRangeException(Strings.ColumnIndexOutOfRange, nameof(rowIndex)); + return ElementwiseEquals(column); } - return (int)(rowIndex / int.MaxValue); + return ElementwiseEquals(value.ToString()); + } + + /// + public override PrimitiveDataFrameColumn 
ElementwiseEquals(DataFrameColumn column) + { + return StringDataFrameColumn.ElementwiseEqualsImplementation(this, column); + } + + public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) + { + return GetGroupedOccurrences(other, out otherColumnNullIndices); + } + + protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) + { + return cursor.GetGetter>(schemaColumn); } IEnumerator> IEnumerable>.GetEnumerator() { throw new NotImplementedException(); } + + protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter) + { + long row = cursor.Position; + VBuffer value = default; + Debug.Assert(getter != null, "Excepted getter to be valid"); + + (getter as ValueGetter>)(ref value); + + if (Length > row) + { + this[row] = new List>() { value }; + } + else if (Length == row) + { + Append(value); + } + else + { + throw new IndexOutOfRangeException(nameof(row)); + } + } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 574a473f1e..6c178ae86a 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -5,7 +5,9 @@ using System; using System.Collections.Immutable; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; +using System.Threading; using Microsoft.ML.Internal.DataView; using Microsoft.ML.Internal.Utilities; @@ -34,6 +36,7 @@ public sealed class VectorDataViewType : StructuredDataViewType /// public ImmutableArray Dimensions { get; } + private static volatile VectorDataViewType _instance; /// /// Constructs a new single-dimensional vector type. /// @@ -85,6 +88,16 @@ public VectorDataViewType(PrimitiveDataViewType itemType, ImmutableArray di Size = ComputeSize(Dimensions); } + public static VectorDataViewType Instance + { + get + { + return _instance ?? 
+ Interlocked.CompareExchange(ref _instance, new VectorDataViewType(NumberDataViewType.Single, 2), null) ?? + _instance; + } + } + private static Type GetRawType(PrimitiveDataViewType itemType) { Contracts.CheckValue(itemType, nameof(itemType)); diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 3264815293..aa0cc10500 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Linq; using Microsoft.ML; using Microsoft.ML.Data; @@ -10,6 +11,16 @@ namespace Microsoft.Data.Analysis.Tests { + public class VectorInput + { + [LoadColumn(0, 1)] + [VectorType(2)] + public float[] Features { get; set; } + + [LoadColumn(3)] + public bool Label { get; set; } + } + public partial class DataFrameIDataViewTests { [Fact] @@ -338,6 +349,25 @@ private IDataView GetASampleIDataView() return data; } + private IDataView GetASampleIDataViewVBuffer() + { + var mlContext = new MLContext(); + + // Get a small dataset as an IEnumerable. 
+ var enumerableOfData = new[] + { + new InputData() { Name = "Joey", FilterNext = false, Value = 1.0f }, + new InputData() { Name = "Chandler", FilterNext = false , Value = 2.0f}, + new InputData() { Name = "Ross", FilterNext = false , Value = 3.0f}, + new InputData() { Name = "Monica", FilterNext = true , Value = 4.0f}, + new InputData() { Name = "Rachel", FilterNext = true , Value = 5.0f}, + new InputData() { Name = "Phoebe", FilterNext = false , Value = 6.0f}, + }; + + IDataView data = mlContext.Data.LoadFromEnumerable(enumerableOfData); + return data; + } + private void VerifyDataFrameColumnAndDataViewColumnValues(string columnName, IDataView data, DataFrame df, int maxRows = -1) { int cc = 0; @@ -419,5 +449,33 @@ public void TestDataFrameFromIDataView_MLData_SelectColumnsAndRows() VerifyDataFrameColumnAndDataViewColumnValues("Name", data, df, 3); VerifyDataFrameColumnAndDataViewColumnValues("Value", data, df, 3); } + + [Fact] + public void VBufferTest() + { + var mlContext = new MLContext(); + + List inputs = new List() + { + new VectorInput() + { + Features = new float[] {33, 44}, + Label = true + }, + new VectorInput() + { + Features = new float[] {55, 66}, + Label = false + } + }; + + var data = mlContext.Data.LoadFromEnumerable(inputs); + + var df = data.ToDataFrame(); + + Assert.Equal(2, df.Columns.Count); + Assert.Equal(2, df.Rows.Count); + } + } } diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 8f0e7bb00b..9b921b2d72 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -8,6 +8,7 @@ using System.Text; using Apache.Arrow; using Microsoft.ML; +using Microsoft.ML.Data; using Xunit; namespace Microsoft.Data.Analysis.Tests @@ -211,6 +212,16 @@ public DataFrame SplitTrainTest(DataFrame input, float testRatio, out DataFrame return input[trainIndices]; } + [Fact] + public void VBufferDataFrameTest() + { + var 
vbuf1 = new VBuffer(); + var vbuf2 = new VBuffer(); + var l1 = new List>() { vbuf1, vbuf2 }; + var column = new VBufferDataFrameColumn("vbuff", l1); + + Assert.Equal(2, column.Length); + } [Fact] public void TestIndexer() diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs index be2af7f5a4..7948d6be22 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs @@ -1098,7 +1098,7 @@ private static void GeneratePair(Random rgen, int len, out VBuffer a, out const int cases = 8; Contracts.Assert(cases == Enum.GetValues(typeof(GenLogic)).Length); subcase = (GenLogic)rgen.Next(cases); - VBufferEditor bEditor; + // VBufferEditor bEditor; switch (subcase) { case GenLogic.BothDense: @@ -1116,21 +1116,21 @@ private static void GeneratePair(Random rgen, int len, out VBuffer a, out case GenLogic.BothSparseASameB: GenerateVBuffer(rgen, len, rgen.Next(len), out a); GenerateVBuffer(rgen, len, a.GetValues().Length, out b); - bEditor = VBufferEditor.CreateFromBuffer(ref b); + /*bEditor = VBufferEditor.CreateFromBuffer(ref b); for (int i = 0; i < a.GetIndices().Length; ++i) bEditor.Indices[i] = a.GetIndices()[i]; - b = bEditor.Commit(); + b = bEditor.Commit();*/ break; case GenLogic.BothSparseASubsetB: case GenLogic.BothSparseBSubsetA: GenerateVBuffer(rgen, len, rgen.Next(len), out a); GenerateVBuffer(rgen, a.GetValues().Length, rgen.Next(a.GetValues().Length), out b); - bEditor = VBufferEditor.Create(ref b, len, b.GetValues().Length); + /*bEditor = VBufferEditor.Create(ref b, len, b.GetValues().Length); for (int i = 0; i < bEditor.Values.Length; ++i) bEditor.Indices[i] = a.GetIndices()[bEditor.Indices[i]]; b = bEditor.Commit(); if (subcase == GenLogic.BothSparseASubsetB) - Utils.Swap(ref a, ref b); + Utils.Swap(ref a, ref b);*/ break; case GenLogic.BothSparseAUnrelatedB: GenerateVBuffer(rgen, len, rgen.Next(len), out a); @@ -1143,14 +1143,14 @@ 
private static void GeneratePair(Random rgen, int len, out VBuffer a, out if (a.GetValues().Length != 0 && b.GetValues().Length != 0 && a.GetValues().Length != b.GetValues().Length) { var aEditor = VBufferEditor.CreateFromBuffer(ref a); - bEditor = VBufferEditor.CreateFromBuffer(ref b); + /*bEditor = VBufferEditor.CreateFromBuffer(ref b); Utils.Shuffle(rgen, aEditor.Indices); aEditor.Indices.Slice(boundary).CopyTo(bEditor.Indices); GenericSpanSortHelper.Sort(aEditor.Indices, 0, boundary); GenericSpanSortHelper.Sort(bEditor.Indices, 0, bEditor.Indices.Length); a = aEditor.CommitTruncated(boundary); - b = bEditor.Commit(); + b = bEditor.Commit();*/ } if (rgen.Next(2) == 0) Utils.Swap(ref a, ref b); From 76d7a48510d40a4c605f071075fbf1037832df28 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Fri, 7 Oct 2022 13:06:10 -0500 Subject: [PATCH 04/13] edit to single list --- .../IDataView.Extension.cs | 13 +++- .../VBufferDataFrameColumn.cs | 77 ++++--------------- src/Microsoft.ML.DataView/VectorType.cs | 11 --- 3 files changed, 29 insertions(+), 72 deletions(-) diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index c5a9743926..8415f4d3e2 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -118,7 +118,18 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param var itemType = vectoryType.ItemType; //type.ItemType && type.Size var subType = dataViewColumn.Annotations; - dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); + if (itemType.RawType.FullName == "System.Single") + { + dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); + } + else if (itemType.RawType.FullName == "System.Int32") + { + dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); + } + else if (itemType.RawType.FullName == "System.String") + { + dataFrameColumns.Add(new 
VBufferDataFrameColumn(dataViewColumn.Name)); + } } else { diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index bcb3b62b64..74ba7df7b3 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -17,13 +17,13 @@ namespace Microsoft.Data.Analysis { /// - /// An immutable column to hold Arrow style strings + /// Column to hold VBuffer /// public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> { - private readonly IList>> _dataBuffers; + // private readonly IList>> _dataBuffers; - private readonly List>> _vBuffers = new List>>(); // To store more than intMax number of strings + private readonly List> _vBuffers = new List>(); /// /// Constructs an empty with the given . @@ -31,20 +31,18 @@ public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerableThe name of the column. public VBufferDataFrameColumn(string name) : base(name, 0, typeof(VBuffer)) { - _dataBuffers = new List>>(); - } - - public VBufferDataFrameColumn(string name, Type T) : base(name, 0, typeof(VBuffer)) - { - _dataBuffers = new List>>(); + _vBuffers = new List>(); + _nullCount = 0; } public VBufferDataFrameColumn(string name, IEnumerable> values) : base(name, 0, typeof(VBuffer)) { + _vBuffers = new List>(); + values = values ?? throw new ArgumentNullException(nameof(values)); if (_vBuffers.Count == 0) { - _vBuffers.Add(new List>()); + _vBuffers.Add(new VBuffer()); } foreach (var value in values) { @@ -52,25 +50,6 @@ public VBufferDataFrameColumn(string name, IEnumerable> values) : bas } } - /// - /// Constructs an with the given , and . The , and are the contents of the column in the Arrow format. - /// - /// The name of the column. - /// The Arrow formatted string values in this column. - /// The Arrow formatted offsets in this column. - /// The Arrow formatted null bits in this column. - /// The length of the column. 
- /// The number of values in this column. - public VBufferDataFrameColumn(string name, List> values, ReadOnlyMemory offsets, ReadOnlyMemory nullBits, int length, int nullCount) : base(name, length, typeof(string)) - { - List> dataBuffer = new List>(values); - - _dataBuffers = new List>>(); - _dataBuffers.Add(dataBuffer); - - _nullCount = nullCount; - } - private readonly long _nullCount; /// @@ -83,38 +62,22 @@ public VBufferDataFrameColumn(string name, List> values, ReadOnlyMemo /// A boolean value indicating the validity at this . public bool IsValid(long index) => NullCount == 0; - /// - /// Returns an enumeration of immutable buffers representing the underlying values in the Apache Arrow format - /// - /// values are encoded in the buffers returned by GetReadOnlyNullBitmapBuffers in the Apache Arrow format - /// The offsets buffers returned by GetReadOnlyOffsetBuffers can be used to delineate each value - /// An enumeration of whose elements are the raw data buffers for the UTF8 string values. 
- public IEnumerable>> GetReadOnlyDataBuffers() - { - for (int i = 0; i < _dataBuffers.Count; i++) - { - // todo - performance - List> buffer = _dataBuffers.ElementAt(i); - yield return buffer; - } - } - private void Append(VBuffer value) { Length++; - _dataBuffers.Add(new List>() { value }); + _vBuffers.Add(value); } /// protected override object GetValue(long rowIndex) => GetValueImplementation(rowIndex); - private List> GetValueImplementation(long rowIndex) + private VBuffer GetValueImplementation(long rowIndex) { if (!IsValid(rowIndex)) { throw new ArgumentOutOfRangeException(nameof(rowIndex)); } - return _dataBuffers.ElementAt((int)rowIndex); + return _vBuffers.ElementAt((int)rowIndex); } /// @@ -137,7 +100,7 @@ protected override IReadOnlyList GetValues(long startIndex, int length) /// /// Zero based row index /// The value stored at this - public new List> this[long rowIndex] + public new VBuffer this[long rowIndex] { get => GetValueImplementation(rowIndex); set => throw new NotSupportedException(Strings.ImmutableColumn); @@ -146,7 +109,7 @@ protected override IReadOnlyList GetValues(long startIndex, int length) /// /// Returns an enumerator that iterates through the string values in this column. 
/// - public IEnumerator>> GetEnumerator() + public IEnumerator> GetEnumerator() { for (long i = 0; i < Length; i++) { @@ -212,13 +175,13 @@ protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) } - private ValueGetter>> CreateValueGetterDelegate(DataViewRowCursor cursor) => - (ref List> value) => value = this[cursor.Position]; + private ValueGetter> CreateValueGetterDelegate(DataViewRowCursor cursor) => + (ref VBuffer value) => value = this[cursor.Position]; /// /// Returns a boolean column that is the result of an elementwise equality comparison of each value in the column with /// - public PrimitiveDataFrameColumn ElementwiseEquals(string value) + public PrimitiveDataFrameColumn ElementwiseEquals(VBuffer value) { throw new NotImplementedException(); } @@ -233,12 +196,6 @@ public override PrimitiveDataFrameColumn ElementwiseEquals(U value) return ElementwiseEquals(value.ToString()); } - /// - public override PrimitiveDataFrameColumn ElementwiseEquals(DataFrameColumn column) - { - return StringDataFrameColumn.ElementwiseEqualsImplementation(this, column); - } - public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) { return GetGroupedOccurrences(other, out otherColumnNullIndices); @@ -264,7 +221,7 @@ protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, D if (Length > row) { - this[row] = new List>() { value }; + this[row] = value; } else if (Length == row) { diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 6c178ae86a..18d4876c29 100644 --- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -36,7 +36,6 @@ public sealed class VectorDataViewType : StructuredDataViewType /// public ImmutableArray Dimensions { get; } - private static volatile VectorDataViewType _instance; /// /// Constructs a new single-dimensional vector type. 
/// @@ -88,16 +87,6 @@ public VectorDataViewType(PrimitiveDataViewType itemType, ImmutableArray di Size = ComputeSize(Dimensions); } - public static VectorDataViewType Instance - { - get - { - return _instance ?? - Interlocked.CompareExchange(ref _instance, new VectorDataViewType(NumberDataViewType.Single, 2), null) ?? - _instance; - } - } - private static Type GetRawType(PrimitiveDataViewType itemType) { Contracts.CheckValue(itemType, nameof(itemType)); From 59e152a34907a8f23183aad72e8a6a263cb6663a Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Wed, 12 Oct 2022 14:26:40 -0500 Subject: [PATCH 05/13] update to list --- .../IDataView.Extension.cs | 83 ++++++-- .../VBufferDataFrameColumn.cs | 196 +++++++++++++----- .../DataFrameIDataViewTests.cs | 25 +++ 3 files changed, 237 insertions(+), 67 deletions(-) diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 8415f4d3e2..6b99f6f2d9 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -112,24 +112,9 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param { dataFrameColumns.Add(new StringDataFrameColumn(dataViewColumn.Name)); } - else if (type is VectorDataViewType vectoryType) - // type.ToString() == "Vector") //== VectorDataViewType.Instance) + else if (type is VectorDataViewType vectorType) { - var itemType = vectoryType.ItemType; - //type.ItemType && type.Size - var subType = dataViewColumn.Annotations; - if (itemType.RawType.FullName == "System.Single") - { - dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); - } - else if (itemType.RawType.FullName == "System.Int32") - { - dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); - } - else if (itemType.RawType.FullName == "System.String") - { - dataFrameColumns.Add(new VBufferDataFrameColumn(dataViewColumn.Name)); - } + 
dataFrameColumns.Add(GetVectorDataFrame(vectorType, dataViewColumn.Name)); } else { @@ -158,6 +143,70 @@ public static DataFrame ToDataFrame(this IDataView dataView, long maxRows, param return new DataFrame(dataFrameColumns); } + + private static DataFrameColumn GetVectorDataFrame(VectorDataViewType vectorType, string name) + { + var itemType = vectorType.ItemType; + + if (itemType.RawType == typeof(bool)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(byte)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(double)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(float)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(int)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(long)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(sbyte)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(short)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(uint)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(ulong)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(ushort)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(char)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(decimal)) + { + return new VBufferDataFrameColumn(name); + } + else if (itemType.RawType == typeof(String)) + { + return new VBufferDataFrameColumn(name); + } + + throw new NotSupportedException("Specified vector subtype " + itemType.ToString() + " is not supported."); + } } } diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index 74ba7df7b3..d5d64fa2eb 100644 
--- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -21,28 +21,36 @@ namespace Microsoft.Data.Analysis /// public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerable> { - // private readonly IList>> _dataBuffers; + private readonly List>> _vBuffers = new List>>(); // To store more than intMax number of vbuffers - private readonly List> _vBuffers = new List>(); + //private readonly List> _vBuffers2 = new List>(); /// - /// Constructs an empty with the given . + /// Constructs an empty VBufferDataFrameColumn with the given . /// /// The name of the column. - public VBufferDataFrameColumn(string name) : base(name, 0, typeof(VBuffer)) + /// Length of values + public VBufferDataFrameColumn(string name, long length = 0) : base(name, 0, typeof(VBuffer)) { - _vBuffers = new List>(); - _nullCount = 0; + int numberOfBuffersRequired = Math.Max((int)(length / int.MaxValue), 1); + for (int i = 0; i < numberOfBuffersRequired; i++) + { + long bufferLen = length - _vBuffers.Count * int.MaxValue; + List> buffer = new List>((int)Math.Min(int.MaxValue, bufferLen)); + _vBuffers.Add(buffer); + for (int j = 0; j < bufferLen; j++) + { + buffer.Add(default); + } + } } public VBufferDataFrameColumn(string name, IEnumerable> values) : base(name, 0, typeof(VBuffer)) { - _vBuffers = new List>(); - values = values ?? 
throw new ArgumentNullException(nameof(values)); if (_vBuffers.Count == 0) { - _vBuffers.Add(new VBuffer()); + _vBuffers.Add(new List>()); } foreach (var value in values) { @@ -50,11 +58,21 @@ public VBufferDataFrameColumn(string name, IEnumerable> values) : bas } } - private readonly long _nullCount; + private long _nullCount; - /// public override long NullCount => _nullCount; + protected internal override void Resize(long length) + { + if (length < Length) + throw new ArgumentException(Strings.CannotResizeDown, nameof(length)); + + for (long i = Length; i < length; i++) + { + Append(default); + } + } + /// /// Indicates if the value at this is . /// @@ -62,58 +80,90 @@ public VBufferDataFrameColumn(string name, IEnumerable> values) : bas /// A boolean value indicating the validity at this . public bool IsValid(long index) => NullCount == 0; - private void Append(VBuffer value) + public void Append(VBuffer value) { + List> lastBuffer = _vBuffers[_vBuffers.Count - 1]; + if (lastBuffer.Count == int.MaxValue) + { + lastBuffer = new List>(); + _vBuffers.Add(lastBuffer); + } + lastBuffer.Add(value); + if (value.Length == 0) //TODO + _nullCount++; + Length++; Length++; - _vBuffers.Add(value); } - /// - protected override object GetValue(long rowIndex) => GetValueImplementation(rowIndex); - - private VBuffer GetValueImplementation(long rowIndex) + private int GetBufferIndexContainingRowIndex(ref long rowIndex) { - if (!IsValid(rowIndex)) + if (rowIndex > Length) { - throw new ArgumentOutOfRangeException(nameof(rowIndex)); + throw new ArgumentOutOfRangeException(Strings.ColumnIndexOutOfRange, nameof(rowIndex)); } - return _vBuffers.ElementAt((int)rowIndex); + return (int)(rowIndex / int.MaxValue); + } + + protected override object GetValue(long rowIndex) + { + int bufferIndex = GetBufferIndexContainingRowIndex(ref rowIndex); + return _vBuffers[bufferIndex][(int)rowIndex]; } - /// protected override IReadOnlyList GetValues(long startIndex, int length) { var ret = new 
List(); - while (ret.Count < length) + int bufferIndex = GetBufferIndexContainingRowIndex(ref startIndex); + while (ret.Count < length && bufferIndex < _vBuffers.Count) { - ret.Add(GetValueImplementation(startIndex++)); + for (int i = (int)startIndex; ret.Count < length && i < _vBuffers[bufferIndex].Count; i++) + { + ret.Add(_vBuffers[bufferIndex][i]); + } + bufferIndex++; + startIndex = 0; } return ret; } - /// - protected override void SetValue(long rowIndex, object value) => throw new NotSupportedException(Strings.ImmutableColumn); - + protected override void SetValue(long rowIndex, object value) + { + if (value == null || value is VBuffer) + { + int bufferIndex = GetBufferIndexContainingRowIndex(ref rowIndex); + var oldValue = this[rowIndex]; + _vBuffers[bufferIndex][(int)rowIndex] = (VBuffer)value; + if (!oldValue.Equals((VBuffer)value)) + { + if (value == null) + _nullCount++; + if (oldValue.Length == 0 && _nullCount > 0) + _nullCount--; + } + } + else + { + throw new ArgumentException(string.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); + } + } - /// - /// Indexer to get values. This is an immutable column - /// - /// Zero based row index - /// The value stored at this public new VBuffer this[long rowIndex] { - get => GetValueImplementation(rowIndex); - set => throw new NotSupportedException(Strings.ImmutableColumn); + get => (VBuffer)GetValue(rowIndex); + set => SetValue(rowIndex, value); } /// - /// Returns an enumerator that iterates through the string values in this column. + /// Returns an enumerator that iterates through the VBuffer values in this column. 
/// public IEnumerator> GetEnumerator() { - for (long i = 0; i < Length; i++) + foreach (List> buffer in _vBuffers) { - yield return this[i]; + foreach (VBuffer value in buffer) + { + yield return value; + } } } @@ -165,7 +215,65 @@ protected override DataFrameColumn FillNullsImplementation(object value, bool in /// protected internal override void AddDataViewColumn(DataViewSchema.Builder builder) { - builder.AddColumn(Name, TextDataViewType.Instance); + builder.AddColumn(Name, GetDataViewType()); + } + + private static VectorDataViewType GetDataViewType() + { + if (typeof(T) == typeof(bool)) + { + return new VectorDataViewType(BooleanDataViewType.Instance); + } + else if (typeof(T) == typeof(byte)) + { + return new VectorDataViewType(NumberDataViewType.Byte); + } + else if (typeof(T) == typeof(double)) + { + return new VectorDataViewType(NumberDataViewType.Double); + } + else if (typeof(T) == typeof(float)) + { + return new VectorDataViewType(NumberDataViewType.Single); + } + else if (typeof(T) == typeof(int)) + { + return new VectorDataViewType(NumberDataViewType.Int32); + } + else if (typeof(T) == typeof(long)) + { + return new VectorDataViewType(NumberDataViewType.Int64); + } + else if (typeof(T) == typeof(sbyte)) + { + return new VectorDataViewType(NumberDataViewType.SByte); + } + else if (typeof(T) == typeof(short)) + { + return new VectorDataViewType(NumberDataViewType.Int16); + } + else if (typeof(T) == typeof(uint)) + { + return new VectorDataViewType(NumberDataViewType.UInt32); + } + else if (typeof(T) == typeof(ulong)) + { + return new VectorDataViewType(NumberDataViewType.UInt64); + } + else if (typeof(T) == typeof(ushort)) + { + return new VectorDataViewType(NumberDataViewType.UInt16); + } + else if (typeof(T) == typeof(char)) + { + return new VectorDataViewType(NumberDataViewType.UInt16); + } + else if (typeof(T) == typeof(decimal)) + { + return new VectorDataViewType(NumberDataViewType.Double); + } + + throw new NotSupportedException(); } /// @@ 
-174,26 +282,14 @@ protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) return CreateValueGetterDelegate(cursor); } - private ValueGetter> CreateValueGetterDelegate(DataViewRowCursor cursor) => (ref VBuffer value) => value = this[cursor.Position]; - /// - /// Returns a boolean column that is the result of an elementwise equality comparison of each value in the column with - /// - public PrimitiveDataFrameColumn ElementwiseEquals(VBuffer value) - { - throw new NotImplementedException(); - } /// public override PrimitiveDataFrameColumn ElementwiseEquals(U value) { - if (value is DataFrameColumn column) - { - return ElementwiseEquals(column); - } - return ElementwiseEquals(value.ToString()); + throw new NotImplementedException(); } public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index aa0cc10500..1fd4233290 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Microsoft.Data.Analysis.Tests; using Microsoft.ML; using Microsoft.ML.Data; using Xunit; @@ -475,7 +476,31 @@ public void VBufferTest() Assert.Equal(2, df.Columns.Count); Assert.Equal(2, df.Rows.Count); + + var value = df[0, 0]; + var a = df.Preview(); } + [Fact] + public void VBufferTest2() + { + var data = new[] + { + new { + NumericData=Enumerable.Range(0,10).ToArray(), + TextData=Enumerable.Repeat("html",15).ToArray() + }, + new { + NumericData=Enumerable.Range(5,10).ToArray(), + TextData=Enumerable.Repeat("div",10).ToArray() + } + }; + + var ctx = new MLContext(); + + var idv = ctx.Data.LoadFromEnumerable(data); + + var df = idv.ToDataFrame(); + } } } From 2c7553e5d6aa20174ced7b0ee0b295c76a43198f Mon Sep 17 
00:00:00 2001 From: Becca McHenry Date: Fri, 21 Oct 2022 16:46:49 -0500 Subject: [PATCH 06/13] fix tests --- .../IDataView.Extension.cs | 4 -- .../VBufferDataFrameColumn.cs | 28 +------------- .../DataFrameIDataViewTests.cs | 38 ++++++------------- 3 files changed, 12 insertions(+), 58 deletions(-) diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 976c5783cb..58eb1ddacb 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -204,10 +204,6 @@ private static DataFrameColumn GetVectorDataFrame(VectorDataViewType vectorType, { return new VBufferDataFrameColumn(name); } - else if (itemType.RawType == typeof(String)) - { - return new VBufferDataFrameColumn(name); - } throw new NotSupportedException("Specified vector subtype " + itemType.ToString() + " is not supported."); } diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index d5d64fa2eb..9000f4e89d 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -23,8 +23,6 @@ public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerable>> _vBuffers = new List>>(); // To store more than intMax number of vbuffers - //private readonly List> _vBuffers2 = new List>(); - /// /// Constructs an empty VBufferDataFrameColumn with the given . /// @@ -92,7 +90,6 @@ public void Append(VBuffer value) if (value.Length == 0) //TODO _nullCount++; Length++; - Length++; } private int GetBufferIndexContainingRowIndex(ref long rowIndex) @@ -182,30 +179,7 @@ public override DataFrameColumn Clone(DataFrameColumn mapIndices = null, bool in /// public VBufferDataFrameColumn FillNulls(VBuffer value, bool inPlace = false) { - if (inPlace) - { - // For now throw an exception if inPlace = true. 
- throw new NotSupportedException(); - } - - VBufferDataFrameColumn ret = new VBufferDataFrameColumn(Name); - for (long i = 0; i < Length; i++) - { - ret.Append(value); - } - return ret; - } - - protected override DataFrameColumn FillNullsImplementation(object value, bool inPlace) - { - if (value is VBuffer valueBuffer) - { - return FillNulls(valueBuffer, inPlace); - } - else - { - throw new ArgumentException(String.Format(Strings.MismatchedValueType, typeof(VBuffer)), nameof(value)); - } + throw new NotImplementedException(); } public override DataFrameColumn Clamp(U min, U max, bool inPlace = false) => throw new NotSupportedException(); diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 28254067d0..896c7f3827 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -16,9 +16,13 @@ public class VectorInput { [LoadColumn(0, 1)] [VectorType(2)] - public float[] Features { get; set; } + public float[] FloatFeatures { get; set; } - [LoadColumn(3)] + [LoadColumn(3, 4)] + [VectorType(2)] + public Int32[] IntFeatures { get; set; } + + [LoadColumn(5)] public bool Label { get; set; } } @@ -471,12 +475,14 @@ public void VBufferTest() { new VectorInput() { - Features = new float[] {33, 44}, + FloatFeatures = new float[] {33, 44}, + IntFeatures = new int[] {5, 6}, Label = true }, new VectorInput() { - Features = new float[] {55, 66}, + FloatFeatures = new float[] {55, 66}, + IntFeatures = new int[] {5, 6}, Label = false } }; @@ -485,33 +491,11 @@ public void VBufferTest() var df = data.ToDataFrame(); - Assert.Equal(2, df.Columns.Count); + Assert.Equal(3, df.Columns.Count); Assert.Equal(2, df.Rows.Count); var value = df[0, 0]; var a = df.Preview(); } - - [Fact] - public void VBufferTest2() - { - var data = new[] - { - new { - NumericData=Enumerable.Range(0,10).ToArray(), - 
TextData=Enumerable.Repeat("html",15).ToArray() - }, - new { - NumericData=Enumerable.Range(5,10).ToArray(), - TextData=Enumerable.Repeat("div",10).ToArray() - } - }; - - var ctx = new MLContext(); - - var idv = ctx.Data.LoadFromEnumerable(data); - - var df = idv.ToDataFrame(); - } } } From fe9578495ed2c50370c6f5bab63cfbedaca02654 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Mon, 24 Oct 2022 14:40:27 -0500 Subject: [PATCH 07/13] cleanup PR --- .../VBufferDataFrameColumn.cs | 108 ++++++------------ src/Microsoft.ML.DataView/VectorType.cs | 2 - .../DataFrameIDataViewTests.cs | 82 +++++-------- .../UnitTests/TestVBuffer.cs | 14 +-- 4 files changed, 72 insertions(+), 134 deletions(-) diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index 9000f4e89d..d9611e621c 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -71,13 +71,6 @@ protected internal override void Resize(long length) } } - /// - /// Indicates if the value at this is . - /// - /// The index to look up. - /// A boolean value indicating the validity at this . 
- public bool IsValid(long index) => NullCount == 0; - public void Append(VBuffer value) { List> lastBuffer = _vBuffers[_vBuffers.Count - 1]; @@ -87,8 +80,6 @@ public void Append(VBuffer value) _vBuffers.Add(lastBuffer); } lastBuffer.Add(value); - if (value.Length == 0) //TODO - _nullCount++; Length++; } @@ -168,28 +159,50 @@ public IEnumerator> GetEnumerator() protected override IEnumerator GetEnumeratorCore() => GetEnumerator(); /// - public override DataFrameColumn Sort(bool ascending = true) => throw new NotSupportedException(); - - /// - public override DataFrameColumn Clone(DataFrameColumn mapIndices = null, bool invertMapIndices = false, long numberOfNullsToAppend = 0) + protected internal override void AddDataViewColumn(DataViewSchema.Builder builder) { - throw new NotImplementedException(); + builder.AddColumn(Name, GetDataViewType()); } /// - public VBufferDataFrameColumn FillNulls(VBuffer value, bool inPlace = false) + protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) { - throw new NotImplementedException(); + return CreateValueGetterDelegate(cursor); } - public override DataFrameColumn Clamp(U min, U max, bool inPlace = false) => throw new NotSupportedException(); + private ValueGetter> CreateValueGetterDelegate(DataViewRowCursor cursor) => + (ref VBuffer value) => value = this[cursor.Position]; - public override DataFrameColumn Filter(U min, U max) => throw new NotSupportedException(); + public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) + { + return GetGroupedOccurrences(other, out otherColumnNullIndices); + } - /// - protected internal override void AddDataViewColumn(DataViewSchema.Builder builder) + protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) { - builder.AddColumn(Name, GetDataViewType()); + return cursor.GetGetter>(schemaColumn); + } + + protected internal override void 
AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter) + { + long row = cursor.Position; + VBuffer value = default; + Debug.Assert(getter != null, "Excepted getter to be valid"); + + (getter as ValueGetter>)(ref value); + + if (Length > row) + { + this[row] = value; + } + else if (Length == row) + { + Append(value); + } + else + { + throw new IndexOutOfRangeException(nameof(row)); + } } private static VectorDataViewType GetDataViewType() @@ -249,58 +262,5 @@ private static VectorDataViewType GetDataViewType() throw new NotSupportedException(); } - - /// - protected internal override Delegate GetDataViewGetter(DataViewRowCursor cursor) - { - return CreateValueGetterDelegate(cursor); - } - - private ValueGetter> CreateValueGetterDelegate(DataViewRowCursor cursor) => - (ref VBuffer value) => value = this[cursor.Position]; - - - /// - public override PrimitiveDataFrameColumn ElementwiseEquals(U value) - { - throw new NotImplementedException(); - } - - public override Dictionary> GetGroupedOccurrences(DataFrameColumn other, out HashSet otherColumnNullIndices) - { - return GetGroupedOccurrences(other, out otherColumnNullIndices); - } - - protected internal override Delegate GetValueGetterUsingCursor(DataViewRowCursor cursor, DataViewSchema.Column schemaColumn) - { - return cursor.GetGetter>(schemaColumn); - } - - IEnumerator> IEnumerable>.GetEnumerator() - { - throw new NotImplementedException(); - } - - protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, Delegate getter) - { - long row = cursor.Position; - VBuffer value = default; - Debug.Assert(getter != null, "Excepted getter to be valid"); - - (getter as ValueGetter>)(ref value); - - if (Length > row) - { - this[row] = value; - } - else if (Length == row) - { - Append(value); - } - else - { - throw new IndexOutOfRangeException(nameof(row)); - } - } } } diff --git a/src/Microsoft.ML.DataView/VectorType.cs b/src/Microsoft.ML.DataView/VectorType.cs index 18d4876c29..574a473f1e 100644 
--- a/src/Microsoft.ML.DataView/VectorType.cs +++ b/src/Microsoft.ML.DataView/VectorType.cs @@ -5,9 +5,7 @@ using System; using System.Collections.Immutable; using System.Linq; -using System.Runtime.CompilerServices; using System.Text; -using System.Threading; using Microsoft.ML.Internal.DataView; using Microsoft.ML.Internal.Utilities; diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 896c7f3827..8e13ef9aa0 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -9,23 +9,11 @@ using Microsoft.ML; using Microsoft.ML.Data; using Xunit; +using Microsoft.ML.Trainers; + namespace Microsoft.Data.Analysis.Tests { - public class VectorInput - { - [LoadColumn(0, 1)] - [VectorType(2)] - public float[] FloatFeatures { get; set; } - - [LoadColumn(3, 4)] - [VectorType(2)] - public Int32[] IntFeatures { get; set; } - - [LoadColumn(5)] - public bool Label { get; set; } - } - public partial class DataFrameIDataViewTests { [Fact] @@ -365,25 +353,6 @@ private IDataView GetASampleIDataView() return data; } - private IDataView GetASampleIDataViewVBuffer() - { - var mlContext = new MLContext(); - - // Get a small dataset as an IEnumerable. 
- var enumerableOfData = new[] - { - new InputData() { Name = "Joey", FilterNext = false, Value = 1.0f }, - new InputData() { Name = "Chandler", FilterNext = false , Value = 2.0f}, - new InputData() { Name = "Ross", FilterNext = false , Value = 3.0f}, - new InputData() { Name = "Monica", FilterNext = true , Value = 4.0f}, - new InputData() { Name = "Rachel", FilterNext = true , Value = 5.0f}, - new InputData() { Name = "Phoebe", FilterNext = false , Value = 6.0f}, - }; - - IDataView data = mlContext.Data.LoadFromEnumerable(enumerableOfData); - return data; - } - private void VerifyDataFrameColumnAndDataViewColumnValues(string columnName, IDataView data, DataFrame df, int maxRows = -1) { int cc = 0; @@ -467,35 +436,46 @@ public void TestDataFrameFromIDataView_MLData_SelectColumnsAndRows() } [Fact] - public void VBufferTest() + public void TestDataFrameFromIDataView_VBufferType() { var mlContext = new MLContext(); - List inputs = new List() + var inputData = new[] { - new VectorInput() - { - FloatFeatures = new float[] {33, 44}, - IntFeatures = new int[] {5, 6}, - Label = true + new { + boolFeature = new bool[] {false, false}, + byteFeatures = new byte[] {0, 0}, + doubleFeatures = new double[] {0, 0}, + floatFeatures = new float[] {0, 0}, + intFeatures = new int[] {0, 0}, + longFeatures = new long[] {0, 0}, + sbyteFeatures = new sbyte[] {0, 0}, + shortFeatures = new short[] {0, 0}, + ushortFeatures = new ushort[] {0, 0}, + uintFeatures = new uint[] {0, 0}, + ulongFeatures = new ulong[] {0, 0}, }, - new VectorInput() - { - FloatFeatures = new float[] {55, 66}, - IntFeatures = new int[] {5, 6}, - Label = false + new { + boolFeature = new bool[] {false, false}, + byteFeatures = new byte[] {0, 0}, + doubleFeatures = new double[] {0, 0}, + floatFeatures = new float[] {1, 1}, + intFeatures = new int[] {0, 0}, + longFeatures = new long[] {0, 0}, + sbyteFeatures = new sbyte[] {0, 0}, + shortFeatures = new short[] {0, 0}, + ushortFeatures = new ushort[] {0, 0}, + uintFeatures 
= new uint[] {0, 0}, + ulongFeatures = new ulong[] {0, 0}, } }; - var data = mlContext.Data.LoadFromEnumerable(inputs); - + var data = mlContext.Data.LoadFromEnumerable(inputData); var df = data.ToDataFrame(); - Assert.Equal(3, df.Columns.Count); + Assert.Equal(11, df.Columns.Count); Assert.Equal(2, df.Rows.Count); - - var value = df[0, 0]; - var a = df.Preview(); } } } + diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs index 7948d6be22..be2af7f5a4 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestVBuffer.cs @@ -1098,7 +1098,7 @@ private static void GeneratePair(Random rgen, int len, out VBuffer a, out const int cases = 8; Contracts.Assert(cases == Enum.GetValues(typeof(GenLogic)).Length); subcase = (GenLogic)rgen.Next(cases); - // VBufferEditor bEditor; + VBufferEditor bEditor; switch (subcase) { case GenLogic.BothDense: @@ -1116,21 +1116,21 @@ private static void GeneratePair(Random rgen, int len, out VBuffer a, out case GenLogic.BothSparseASameB: GenerateVBuffer(rgen, len, rgen.Next(len), out a); GenerateVBuffer(rgen, len, a.GetValues().Length, out b); - /*bEditor = VBufferEditor.CreateFromBuffer(ref b); + bEditor = VBufferEditor.CreateFromBuffer(ref b); for (int i = 0; i < a.GetIndices().Length; ++i) bEditor.Indices[i] = a.GetIndices()[i]; - b = bEditor.Commit();*/ + b = bEditor.Commit(); break; case GenLogic.BothSparseASubsetB: case GenLogic.BothSparseBSubsetA: GenerateVBuffer(rgen, len, rgen.Next(len), out a); GenerateVBuffer(rgen, a.GetValues().Length, rgen.Next(a.GetValues().Length), out b); - /*bEditor = VBufferEditor.Create(ref b, len, b.GetValues().Length); + bEditor = VBufferEditor.Create(ref b, len, b.GetValues().Length); for (int i = 0; i < bEditor.Values.Length; ++i) bEditor.Indices[i] = a.GetIndices()[bEditor.Indices[i]]; b = bEditor.Commit(); if (subcase == GenLogic.BothSparseASubsetB) - Utils.Swap(ref a, ref 
b);*/ + Utils.Swap(ref a, ref b); break; case GenLogic.BothSparseAUnrelatedB: GenerateVBuffer(rgen, len, rgen.Next(len), out a); @@ -1143,14 +1143,14 @@ private static void GeneratePair(Random rgen, int len, out VBuffer a, out if (a.GetValues().Length != 0 && b.GetValues().Length != 0 && a.GetValues().Length != b.GetValues().Length) { var aEditor = VBufferEditor.CreateFromBuffer(ref a); - /*bEditor = VBufferEditor.CreateFromBuffer(ref b); + bEditor = VBufferEditor.CreateFromBuffer(ref b); Utils.Shuffle(rgen, aEditor.Indices); aEditor.Indices.Slice(boundary).CopyTo(bEditor.Indices); GenericSpanSortHelper.Sort(aEditor.Indices, 0, boundary); GenericSpanSortHelper.Sort(bEditor.Indices, 0, bEditor.Indices.Length); a = aEditor.CommitTruncated(boundary); - b = bEditor.Commit();*/ + b = bEditor.Commit(); } if (rgen.Next(2) == 0) Utils.Swap(ref a, ref b); From 8ca27c9f2e2894b795b9ab310c18f1f5feb278ee Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Tue, 25 Oct 2022 16:06:07 -0500 Subject: [PATCH 08/13] fix all the tests --- .../VBufferDataFrameColumn.cs | 133 +++++++++++++++++- .../DataFrameIDataViewTests.cs | 15 +- .../DataFrameTests.cs | 17 +++ 3 files changed, 160 insertions(+), 5 deletions(-) diff --git a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs index d9611e621c..caa786afdb 100644 --- a/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs +++ b/src/Microsoft.Data.Analysis/VBufferDataFrameColumn.cs @@ -28,7 +28,7 @@ public partial class VBufferDataFrameColumn : DataFrameColumn, IEnumerable /// The name of the column. 
/// Length of values - public VBufferDataFrameColumn(string name, long length = 0) : base(name, 0, typeof(VBuffer)) + public VBufferDataFrameColumn(string name, long length = 0) : base(name, length, typeof(VBuffer)) { int numberOfBuffersRequired = Math.Max((int)(length / int.MaxValue), 1); for (int i = 0; i < numberOfBuffersRequired; i++) @@ -205,6 +205,137 @@ protected internal override void AddValueUsingCursor(DataViewRowCursor cursor, D } } + private VBufferDataFrameColumn Clone(PrimitiveDataFrameColumn boolColumn) + { + if (boolColumn.Length > Length) + throw new ArgumentException(Strings.MapIndicesExceedsColumnLenth, nameof(boolColumn)); + VBufferDataFrameColumn ret = new VBufferDataFrameColumn(Name, 0); + for (long i = 0; i < boolColumn.Length; i++) + { + bool? value = boolColumn[i]; + if (value.HasValue && value.Value == true) + ret.Append(this[i]); + } + return ret; + } + + private VBufferDataFrameColumn Clone(PrimitiveDataFrameColumn mapIndices = null, bool invertMapIndex = false) + { + if (mapIndices is null) + { + VBufferDataFrameColumn ret = new VBufferDataFrameColumn(Name, Length); + for (long i = 0; i < Length; i++) + { + ret[i] = this[i]; + } + return ret; + } + else + { + return CloneImplementation(mapIndices, invertMapIndex); + } + } + + private VBufferDataFrameColumn Clone(PrimitiveDataFrameColumn mapIndices, bool invertMapIndex = false) + { + return CloneImplementation(mapIndices, invertMapIndex); + } + + private VBufferDataFrameColumn CloneImplementation(PrimitiveDataFrameColumn mapIndices, bool invertMapIndices = false, long numberOfNullsToAppend = 0) + where U : unmanaged + { + mapIndices = mapIndices ?? 
throw new ArgumentNullException(nameof(mapIndices)); + VBufferDataFrameColumn ret = new VBufferDataFrameColumn(Name, mapIndices.Length); + + List> setBuffer = ret._vBuffers[0]; + long setBufferMinRange = 0; + long setBufferMaxRange = int.MaxValue; + List> getBuffer = _vBuffers[0]; + long getBufferMinRange = 0; + long getBufferMaxRange = int.MaxValue; + long maxCapacity = int.MaxValue; + if (mapIndices.DataType == typeof(long)) + { + PrimitiveDataFrameColumn longMapIndices = mapIndices as PrimitiveDataFrameColumn; + longMapIndices.ApplyElementwise((long? mapIndex, long rowIndex) => + { + long index = rowIndex; + if (invertMapIndices) + index = longMapIndices.Length - 1 - index; + if (index < setBufferMinRange || index >= setBufferMaxRange) + { + int bufferIndex = (int)(index / maxCapacity); + setBuffer = ret._vBuffers[bufferIndex]; + setBufferMinRange = bufferIndex * maxCapacity; + setBufferMaxRange = (bufferIndex + 1) * maxCapacity; + } + index -= setBufferMinRange; + + if (mapIndex.Value < getBufferMinRange || mapIndex.Value >= getBufferMaxRange) + { + int bufferIndex = (int)(mapIndex.Value / maxCapacity); + getBuffer = _vBuffers[bufferIndex]; + getBufferMinRange = bufferIndex * maxCapacity; + getBufferMaxRange = (bufferIndex + 1) * maxCapacity; + } + int bufferLocalMapIndex = (int)(mapIndex - getBufferMinRange); + VBuffer value = getBuffer[bufferLocalMapIndex]; + setBuffer[(int)index] = value; + + return mapIndex; + }); + } + else if (mapIndices.DataType == typeof(int)) + { + PrimitiveDataFrameColumn intMapIndices = mapIndices as PrimitiveDataFrameColumn; + intMapIndices.ApplyElementwise((int? 
mapIndex, long rowIndex) => + { + long index = rowIndex; + if (invertMapIndices) + index = intMapIndices.Length - 1 - index; + + VBuffer value = getBuffer[mapIndex.Value]; + setBuffer[(int)index] = value; + + return mapIndex; + }); + } + else + { + Debug.Assert(false, nameof(mapIndices.DataType)); + } + + return ret; + } + + public new VBufferDataFrameColumn Clone(DataFrameColumn mapIndices, bool invertMapIndices, long numberOfNullsToAppend) + { + VBufferDataFrameColumn clone; + if (!(mapIndices is null)) + { + Type dataType = mapIndices.DataType; + if (dataType != typeof(long) && dataType != typeof(int) && dataType != typeof(bool)) + throw new ArgumentException(String.Format(Strings.MultipleMismatchedValueType, typeof(long), typeof(int), typeof(bool)), nameof(mapIndices)); + if (mapIndices.DataType == typeof(long)) + clone = Clone(mapIndices as PrimitiveDataFrameColumn, invertMapIndices); + else if (dataType == typeof(int)) + clone = Clone(mapIndices as PrimitiveDataFrameColumn, invertMapIndices); + else + clone = Clone(mapIndices as PrimitiveDataFrameColumn); + } + else + { + clone = Clone(); + } + + return clone; + } + + protected override DataFrameColumn CloneImplementation(DataFrameColumn mapIndices = null, bool invertMapIndices = false, long numberOfNullsToAppend = 0) + { + return Clone(mapIndices, invertMapIndices, numberOfNullsToAppend); + } + private static VectorDataViewType GetDataViewType() { if (typeof(T) == typeof(bool)) diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs index 8e13ef9aa0..4576870707 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameIDataViewTests.cs @@ -11,7 +11,6 @@ using Xunit; using Microsoft.ML.Trainers; - namespace Microsoft.Data.Analysis.Tests { public partial class DataFrameIDataViewTests @@ -23,7 +22,7 @@ public void TestIDataView() DataDebuggerPreview preview = 
dataView.Preview(); Assert.Equal(10, preview.RowView.Length); - Assert.Equal(16, preview.ColumnView.Length); + Assert.Equal(17, preview.ColumnView.Length); Assert.Equal("Byte", preview.ColumnView[0].Column.Name); Assert.Equal((byte)0, preview.ColumnView[0].Values[0]); @@ -88,6 +87,10 @@ public void TestIDataView() Assert.Equal("ArrowString", preview.ColumnView[15].Column.Name); Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[0].ToString()); Assert.Equal("foo".ToString(), preview.ColumnView[15].Values[1].ToString()); + + Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name); + Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[0].ToString()); + Assert.Equal("Dense vector of size 5", preview.ColumnView[16].Values[1].ToString()); } [Fact] @@ -125,7 +128,7 @@ public void TestIDataViewWithNulls() DataDebuggerPreview preview = dataView.Preview(); Assert.Equal(length, preview.RowView.Length); - Assert.Equal(16, preview.ColumnView.Length); + Assert.Equal(17, preview.ColumnView.Length); Assert.Equal("Byte", preview.ColumnView[0].Column.Name); Assert.Equal((byte)0, preview.ColumnView[0].Values[0]); @@ -238,12 +241,16 @@ public void TestIDataViewWithNulls() Assert.Equal("foo", preview.ColumnView[15].Values[4].ToString()); Assert.Equal("", preview.ColumnView[15].Values[5].ToString()); // null row Assert.Equal("foo", preview.ColumnView[15].Values[6].ToString()); + + Assert.Equal("VBuffer", preview.ColumnView[16].Column.Name); + Assert.True(preview.ColumnView[16].Values[0] is VBuffer); + Assert.True(preview.ColumnView[16].Values[6] is VBuffer); } [Fact] public void TestDataFrameFromIDataView() { - DataFrame df = DataFrameTests.MakeDataFrameWithAllColumnTypes(10, withNulls: false); + DataFrame df = DataFrameTests.MakeDataFrameWithAllMutableAndArrowColumnTypes(10, withNulls: false); df.Columns.Remove("Char"); // Because chars are returned as uint16 by IDataView, so end up comparing CharDataFrameColumn to UInt16DataFrameColumn and fail asserts 
IDataView dfAsIDataView = df; DataFrame newDf = dfAsIDataView.ToDataFrame(); diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 08bf71e4b9..235e1282cb 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -75,11 +75,28 @@ public static ArrowStringDataFrameColumn CreateArrowStringColumn(int length, boo return new ArrowStringDataFrameColumn("ArrowString", dataMemory, offsetMemory, nullMemory, length, nullCount); } + public static VBufferDataFrameColumn CreateVBufferDataFrame(int length) + { + var buffers = Enumerable.Repeat(new VBuffer(5, new[] { 0, 1, 2, 3, 4 }), length).ToArray(); + return new VBufferDataFrameColumn("VBuffer", buffers); + } + public static DataFrame MakeDataFrameWithAllColumnTypes(int length, bool withNulls = true) + { + DataFrame df = MakeDataFrameWithAllMutableAndArrowColumnTypes(length, withNulls); + + var vBufferColumn = CreateVBufferDataFrame(length); + df.Columns.Insert(df.Columns.Count, vBufferColumn); + + return df; + } + + public static DataFrame MakeDataFrameWithAllMutableAndArrowColumnTypes(int length, bool withNulls = true) { DataFrame df = MakeDataFrameWithAllMutableColumnTypes(length, withNulls); DataFrameColumn arrowStringColumn = CreateArrowStringColumn(length, withNulls); df.Columns.Insert(df.Columns.Count, arrowStringColumn); + return df; } From 9674032883692da40e9a0473f2badafa604f94d5 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Tue, 25 Oct 2022 16:27:55 -0500 Subject: [PATCH 09/13] change vbuffertest --- test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs index 235e1282cb..06286b3ae4 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs +++ 
b/test/Microsoft.Data.Analysis.Tests/DataFrameTests.cs @@ -230,14 +230,13 @@ public DataFrame SplitTrainTest(DataFrame input, float testRatio, out DataFrame } [Fact] - public void VBufferDataFrameTest() + public void TestVBufferColumn() { - var vbuf1 = new VBuffer(); - var vbuf2 = new VBuffer(); - var l1 = new List>() { vbuf1, vbuf2 }; - var column = new VBufferDataFrameColumn("vbuff", l1); + var vBufferColumn = CreateVBufferDataFrame(10); - Assert.Equal(2, column.Length); + Assert.Equal(10, vBufferColumn.Length); + Assert.Equal(5, vBufferColumn[0].GetValues().Length); + Assert.Equal(0, vBufferColumn[0].GetValues()[0]); } [Fact] From 2b85451a0dcfa34229e3a4a897aa76118e79477d Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Tue, 25 Oct 2022 17:15:38 -0500 Subject: [PATCH 10/13] update string --- src/Microsoft.Data.Analysis/IDataView.Extension.cs | 2 +- src/Microsoft.Data.Analysis/Strings.Designer.cs | 13 +++++++++++-- src/Microsoft.Data.Analysis/Strings.resx | 4 ++++ .../API/AutoMLExperimentExtension.cs | 3 +-- .../Microsoft.Data.Analysis.Tests.csproj | 4 +++- 5 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.Data.Analysis/IDataView.Extension.cs b/src/Microsoft.Data.Analysis/IDataView.Extension.cs index 58eb1ddacb..303eaf3b76 100644 --- a/src/Microsoft.Data.Analysis/IDataView.Extension.cs +++ b/src/Microsoft.Data.Analysis/IDataView.Extension.cs @@ -205,7 +205,7 @@ private static DataFrameColumn GetVectorDataFrame(VectorDataViewType vectorType, return new VBufferDataFrameColumn(name); } - throw new NotSupportedException("Specified vector subtype " + itemType.ToString() + " is not supported."); + throw new NotSupportedException(String.Format(Microsoft.Data.Strings.VectorSubTypeNotSupported, itemType.ToString())); } } diff --git a/src/Microsoft.Data.Analysis/Strings.Designer.cs b/src/Microsoft.Data.Analysis/Strings.Designer.cs index 9cbf90f38e..dd97d837ed 100644 --- a/src/Microsoft.Data.Analysis/Strings.Designer.cs +++ 
b/src/Microsoft.Data.Analysis/Strings.Designer.cs @@ -19,7 +19,7 @@ namespace Microsoft.Data { // class via a tool like ResGen or Visual Studio. // To add or remove a member, edit your .ResX file then rerun ResGen // with the /str option, or rebuild your VS project. - [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "16.0.0.0")] + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "17.0.0.0")] [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] internal class Strings { @@ -39,7 +39,7 @@ internal Strings() { internal static global::System.Resources.ResourceManager ResourceManager { get { if (object.ReferenceEquals(resourceMan, null)) { - global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.Data.Analysis.Strings", typeof(Strings).Assembly); + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.Data.Analysis.Tests.Strings", typeof(Strings).Assembly); resourceMan = temp; } return resourceMan; @@ -437,5 +437,14 @@ internal static string StreamDoesntSupportReading { return ResourceManager.GetString("StreamDoesntSupportReading", resourceCulture); } } + + /// + /// Looks up a localized string similar to Specified vector subtype {0} is not supported.. + /// + internal static string VectorSubTypeNotSupported { + get { + return ResourceManager.GetString("VectorSubTypeNotSupported", resourceCulture); + } + } } } diff --git a/src/Microsoft.Data.Analysis/Strings.resx b/src/Microsoft.Data.Analysis/Strings.resx index 79764037cc..53120a4a73 100644 --- a/src/Microsoft.Data.Analysis/Strings.resx +++ b/src/Microsoft.Data.Analysis/Strings.resx @@ -243,4 +243,8 @@ Stream doesn't support reading + + Specified vector subtype {0} is not supported. 
+ {0} vector subtype + \ No newline at end of file diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs index f16afac17e..3322d11d39 100644 --- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs +++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs @@ -254,14 +254,13 @@ public static AutoMLExperiment SetGridSearchTuner(this AutoMLExperiment experime return tuner; }); - + return experiment; } /// Set checkpoint folder for . The checkpoint folder will be used to save /// temporary output, run history and many other stuff which will be used for restoring training process /// from last checkpoint and continue training. - /// /// . /// checkpoint folder. This folder will be created if not exist. /// diff --git a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj index 07c3ea2c33..71767d3ebe 100644 --- a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj +++ b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj @@ -32,7 +32,9 @@ - + + + From da83f4bb81aa78cb61fa0cc814c2e2c48825901a Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Fri, 6 Jan 2023 11:44:20 -0600 Subject: [PATCH 11/13] revert csproj string issue --- .../Microsoft.Data.Analysis.Tests.csproj | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj index 71767d3ebe..872769ef74 100644 --- a/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj +++ b/test/Microsoft.Data.Analysis.Tests/Microsoft.Data.Analysis.Tests.csproj @@ -32,9 +32,7 @@ - - - + From 2f76952bd5c6b8dfc095c3f096da97b5a45229e0 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Fri, 6 Jan 2023 17:01:36 -0600 Subject: [PATCH 12/13] update string
namespace --- src/Microsoft.Data.Analysis/Strings.Designer.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.Data.Analysis/Strings.Designer.cs b/src/Microsoft.Data.Analysis/Strings.Designer.cs index dd97d837ed..38cabf7ecb 100644 --- a/src/Microsoft.Data.Analysis/Strings.Designer.cs +++ b/src/Microsoft.Data.Analysis/Strings.Designer.cs @@ -39,7 +39,7 @@ internal Strings() { internal static global::System.Resources.ResourceManager ResourceManager { get { if (object.ReferenceEquals(resourceMan, null)) { - global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.Data.Analysis.Tests.Strings", typeof(Strings).Assembly); + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.Data.Analysis.Strings", typeof(Strings).Assembly); resourceMan = temp; } return resourceMan; From 496681e6ac70f7444a9a557a72f4e7da225ffd29 Mon Sep 17 00:00:00 2001 From: Becca McHenry Date: Mon, 9 Jan 2023 14:24:10 -0600 Subject: [PATCH 13/13] update summary tag --- src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs index 3322d11d39..b2437588d0 100644 --- a/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs +++ b/src/Microsoft.ML.AutoML/API/AutoMLExperimentExtension.cs @@ -258,9 +258,11 @@ public static AutoMLExperiment SetGridSearchTuner(this AutoMLExperiment experime return experiment; } + /// /// Set checkpoint folder for . The checkpoint folder will be used to save /// temporary output, run history and many other stuff which will be used for restoring training process /// from last checkpoint and continue training. + /// /// . /// checkpoint folder. This folder will be created if not exist. ///