|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System; |
| 6 | +using System.Collections; |
| 7 | +using System.Collections.Generic; |
| 8 | +using System.Diagnostics; |
| 9 | +using System.Text; |
| 10 | +using System.Threading.Tasks; |
| 11 | +using Apache.Arrow; |
| 12 | +using Apache.Arrow.Types; |
| 13 | + |
| 14 | +namespace Microsoft.Data.Analysis |
| 15 | +{ |
| 16 | + public partial class DataFrame |
| 17 | + { |
/// <summary>
/// Wraps a <see cref="DataFrame"/> around an Arrow <see cref="RecordBatch"/> without copying data.
/// </summary>
/// <remarks>
/// The value and null-bitmap buffers of every Arrow array are passed to the new column as
/// <see cref="ReadOnlyMemory{T}"/>. Columns are appended in the order the arrays are enumerated,
/// and each one is named after the schema field at the same index.
/// </remarks>
/// <param name="recordBatch">The Arrow record batch to wrap.</param>
/// <returns>A <see cref="DataFrame"/> holding one column per array in <paramref name="recordBatch"/>.</returns>
/// <exception cref="ArgumentNullException"><paramref name="recordBatch"/> is null.</exception>
/// <exception cref="NotImplementedException">A field uses an Arrow type that has no column mapping in this method.</exception>
public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
{
    if (recordBatch == null)
    {
        throw new ArgumentNullException(nameof(recordBatch));
    }

    DataFrame ret = new DataFrame();
    Apache.Arrow.Schema arrowSchema = recordBatch.Schema;

    // Arrays are enumerated in schema order, so a running index pairs each array with its field.
    int fieldIndex = 0;
    foreach (IArrowArray arrowArray in recordBatch.Arrays)
    {
        Field field = arrowSchema.GetFieldByIndex(fieldIndex);
        IArrowType fieldType = field.DataType;
        DataFrameColumn dataFrameColumn = null;
        switch (fieldType.TypeId)
        {
            case ArrowTypeId.Boolean:
                // Boolean arrays are cast to BooleanArray rather than PrimitiveArray<T>, so they are handled inline.
                BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
                ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
                ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
                dataFrameColumn = new PrimitiveDataFrameColumn<bool>(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
                break;
            case ArrowTypeId.Double:
                dataFrameColumn = WrapPrimitiveArrowArray<double>(field, arrowArray);
                break;
            case ArrowTypeId.Float:
                dataFrameColumn = WrapPrimitiveArrowArray<float>(field, arrowArray);
                break;
            case ArrowTypeId.Int8:
                dataFrameColumn = WrapPrimitiveArrowArray<sbyte>(field, arrowArray);
                break;
            case ArrowTypeId.Int16:
                dataFrameColumn = WrapPrimitiveArrowArray<short>(field, arrowArray);
                break;
            case ArrowTypeId.Int32:
                dataFrameColumn = WrapPrimitiveArrowArray<int>(field, arrowArray);
                break;
            case ArrowTypeId.Int64:
                dataFrameColumn = WrapPrimitiveArrowArray<long>(field, arrowArray);
                break;
            case ArrowTypeId.String:
                // Strings carry an extra offsets buffer in addition to the data and null-bitmap buffers.
                StringArray stringArray = (StringArray)arrowArray;
                ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
                ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
                ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
                dataFrameColumn = new ArrowStringDataFrameColumn(field.Name, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
                break;
            case ArrowTypeId.UInt8:
                dataFrameColumn = WrapPrimitiveArrowArray<byte>(field, arrowArray);
                break;
            case ArrowTypeId.UInt16:
                dataFrameColumn = WrapPrimitiveArrowArray<ushort>(field, arrowArray);
                break;
            case ArrowTypeId.UInt32:
                dataFrameColumn = WrapPrimitiveArrowArray<uint>(field, arrowArray);
                break;
            case ArrowTypeId.UInt64:
                dataFrameColumn = WrapPrimitiveArrowArray<ulong>(field, arrowArray);
                break;
            // The types below are listed explicitly to document that they are known but not mapped yet.
            case ArrowTypeId.Decimal:
            case ArrowTypeId.Binary:
            case ArrowTypeId.Date32:
            case ArrowTypeId.Date64:
            case ArrowTypeId.Dictionary:
            case ArrowTypeId.FixedSizedBinary:
            case ArrowTypeId.HalfFloat:
            case ArrowTypeId.Interval:
            case ArrowTypeId.List:
            case ArrowTypeId.Map:
            case ArrowTypeId.Null:
            case ArrowTypeId.Struct:
            case ArrowTypeId.Time32:
            case ArrowTypeId.Time64:
            default:
                // Pass the runtime type name; nameof(fieldType.Name) would always produce the literal "Name".
                throw new NotImplementedException(fieldType.Name);
        }
        ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
        fieldIndex++;
    }
    return ret;
}

/// <summary>
/// Wraps a primitive Arrow array in a <see cref="PrimitiveDataFrameColumn{T}"/> built over the array's
/// value and null-bitmap buffers.
/// </summary>
/// <typeparam name="T">The element type of the Arrow array and of the resulting column.</typeparam>
/// <param name="field">The schema field that supplies the column name.</param>
/// <param name="arrowArray">The array to wrap; it must be a <see cref="PrimitiveArray{T}"/>.</param>
/// <returns>A column named after <paramref name="field"/> with the array's length and null count.</returns>
private static PrimitiveDataFrameColumn<T> WrapPrimitiveArrowArray<T>(Field field, IArrowArray arrowArray)
    where T : unmanaged
{
    PrimitiveArray<T> primitiveArray = (PrimitiveArray<T>)arrowArray;
    ReadOnlyMemory<byte> valueBuffer = primitiveArray.ValueBuffer.Memory;
    ReadOnlyMemory<byte> nullBitMapBuffer = primitiveArray.NullBitmapBuffer.Memory;
    return new PrimitiveDataFrameColumn<T>(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
}
| 131 | + |
/// <summary>
/// Returns an <see cref="IEnumerable{RecordBatch}"/> without copying data.
/// </summary>
/// <remarks>
/// This is an iterator: the schema is built and each batch is produced only as the result is enumerated.
/// At least one batch is always yielded, even when the frame has no rows.
/// </remarks>
/// <returns>
/// A sequence of record batches that share one schema built from the columns' Arrow fields.
/// Each batch holds one Arrow array per column.
/// </returns>
public IEnumerable<RecordBatch> ToArrowRecordBatches()
{
    Apache.Arrow.Schema.Builder schemaBuilder = new Apache.Arrow.Schema.Builder();

    int columnCount = Columns.Count;
    for (int i = 0; i < columnCount; i++)
    {
        DataFrameColumn column = Columns[i];
        Field field = column.GetArrowField();
        schemaBuilder.Field(field);
    }

    Schema schema = schemaBuilder.Build();

    int recordBatchLength = Int32.MaxValue;
    long numberOfRowsProcessed = 0;

    // Sometimes .NET for Spark passes in DataFrames with no rows. In those cases, we just return a RecordBatch with the right Schema and no rows
    do
    {
        // Start each batch from the rows still remaining, so a short earlier batch does not cap later ones.
        int numberOfRowsInThisRecordBatch = (int)Math.Min(recordBatchLength, RowCount - numberOfRowsProcessed);
        for (int i = 0; i < columnCount; i++)
        {
            DataFrameColumn column = Columns[i];
            // Every column must be able to supply the batch, so take the smallest limit across columns.
            numberOfRowsInThisRecordBatch = (int)Math.Min(numberOfRowsInThisRecordBatch, column.GetMaxRecordBatchLength(numberOfRowsProcessed));
        }

        // Use a fresh list per batch; reusing one list would carry earlier batches' arrays into later batches.
        List<Apache.Arrow.Array> arrays = new List<Apache.Arrow.Array>(columnCount);
        for (int i = 0; i < columnCount; i++)
        {
            DataFrameColumn column = Columns[i];
            arrays.Add(column.ToArrowArray(numberOfRowsProcessed, numberOfRowsInThisRecordBatch));
        }
        numberOfRowsProcessed += numberOfRowsInThisRecordBatch;
        yield return new RecordBatch(schema, arrays, numberOfRowsInThisRecordBatch);
    } while (numberOfRowsProcessed < RowCount);
}
| 171 | + |
| 172 | + } |
| 173 | +} |
0 commit comments