@@ -21,16 +21,88 @@ import org.apache.spark.sql.types._
2121import org .apache .spark .unsafe .PlatformDependent
2222import org .apache .spark .unsafe .array .ByteArrayMethods
2323
24- /** Write a column into an UnsafeRow */
/**
 * Encodes [[Row]] objects into the UnsafeRow binary layout.
 *
 * A single [[UnsafeRow]] pointer and one column writer per field are created up front and
 * re-used by every call, so an instance must NOT be shared between threads.
 *
 * @param fieldTypes the data types of the row's columns, in column order.
 */
class UnsafeRowConverter(fieldTypes: Array[DataType]) {

  /** Builds a converter from the data types of the fields of `schema`. */
  def this(schema: StructType) = this(schema.fields.map(_.dataType))

  /** Pointer that is re-targeted at the destination memory on every `writeRow` call. */
  private[this] val unsafeRow = new UnsafeRow()

  /** One encoder per column, looked up once from the column's data type. */
  private[this] val writers: Array[UnsafeColumnWriter[Any]] =
    fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])

  /** Bytes taken by the fixed-length region: the null bitmap plus 8 bytes per field. */
  private[this] val fixedLengthSize: Int =
    UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length) + (8 * fieldTypes.length)

  /**
   * Compute how many bytes are needed to encode `row`.
   *
   * The result is the fixed-length size plus whatever each non-null column's writer reports
   * through `getSize`; null columns contribute nothing beyond their fixed-length slot.
   *
   * @param row the row that is going to be converted
   * @return the required size, in bytes
   */
  def getSizeRequirement(row: Row): Int = {
    val numFields = writers.length
    // Start from the fixed-length portion and add each column's variable-length part.
    var totalSize: Int = fixedLengthSize
    var ordinal = 0
    while (ordinal < numFields) {
      if (!row.isNullAt(ordinal)) {
        totalSize += writers(ordinal).getSize(row(ordinal))
      }
      ordinal += 1
    }
    totalSize
  }

  /**
   * Encode `row` in UnsafeRow format at the given destination address.
   *
   * @param row the row to convert
   * @param baseObject the base object of the destination address
   * @param baseOffset the base offset of the destination address
   * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
   */
  def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
    val numFields = writers.length
    unsafeRow.pointTo(baseObject, baseOffset, numFields, null)
    // Offset (from the row start) where the next variable-length value is appended; it begins
    // right after the fixed-length portion.
    var cursor: Int = fixedLengthSize
    var ordinal = 0
    while (ordinal < numFields) {
      if (!row.isNullAt(ordinal)) {
        // Each writer reports how many variable-length bytes it appended at `cursor`.
        val bytesAppended = writers(ordinal).write(
          row(ordinal),
          ordinal,
          unsafeRow,
          baseObject,
          baseOffset,
          cursor)
        cursor += bytesAppended
      } else {
        // TODO: type-specific null value writing?
        unsafeRow.setNullAt(ordinal)
      }
      ordinal += 1
    }
    cursor
  }

}
93+
94+ /**
95+ * Function for writing a column into an UnsafeRow.
96+ */
2597private abstract class UnsafeColumnWriter [T ] {
2698 /**
2799 * Write a value into an UnsafeRow.
28100 *
29101 * @param value the value to write
30102 * @param columnNumber what column to write it to
31103 * @param row a pointer to the unsafe row
32- * @param baseObject
33- * @param baseOffset
104+ * @param baseObject the base object of the target row's address
105+ * @param baseOffset the base offset of the target row's address
34106 * @param appendCursor the offset from the start of the unsafe row to the end of the row;
35107 * used for calculating where variable-length data should be written
36108 * @return the number of variable-length bytes written
@@ -50,6 +122,12 @@ private abstract class UnsafeColumnWriter[T] {
50122}
51123
52124private object UnsafeColumnWriter {
125+ private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
126+ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
127+ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
128+ private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
129+ private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
130+
53131 def forType (dataType : DataType ): UnsafeColumnWriter [_] = {
54132 dataType match {
55133 case IntegerType => IntUnsafeColumnWriter
@@ -63,34 +141,7 @@ private object UnsafeColumnWriter {
63141 }
64142}
65143
66- private class StringUnsafeColumnWriter private () extends UnsafeColumnWriter [UTF8String ] {
67- def getSize (value : UTF8String ): Int = {
68- // round to nearest word
69- val numBytes = value.getBytes.length
70- 8 + ByteArrayMethods .roundNumberOfBytesToNearestWord(numBytes)
71- }
72-
73- override def write (
74- value : UTF8String ,
75- columnNumber : Int ,
76- row : UnsafeRow ,
77- baseObject : Object ,
78- baseOffset : Long ,
79- appendCursor : Int ): Int = {
80- val numBytes = value.getBytes.length
81- PlatformDependent .UNSAFE .putLong(baseObject, baseOffset + appendCursor, numBytes)
82- PlatformDependent .copyMemory(
83- value.getBytes,
84- PlatformDependent .BYTE_ARRAY_OFFSET ,
85- baseObject,
86- baseOffset + appendCursor + 8 ,
87- numBytes
88- )
89- row.setLong(columnNumber, appendCursor)
90- 8 + ByteArrayMethods .roundNumberOfBytesToNearestWord(numBytes)
91- }
92- }
93- private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
144+ // ------------------------------------------------------------------------------------------------
94145
95146private abstract class PrimitiveUnsafeColumnWriter [T ] extends UnsafeColumnWriter [T ] {
96147 def getSize (value : T ): Int = 0
@@ -108,7 +159,6 @@ private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrite
108159 0
109160 }
110161}
111- private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
112162
113163private class LongUnsafeColumnWriter private () extends PrimitiveUnsafeColumnWriter [Long ] {
114164 override def write (
@@ -122,7 +172,6 @@ private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWrit
122172 0
123173 }
124174}
125- private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
126175
127176private class FloatUnsafeColumnWriter private () extends PrimitiveUnsafeColumnWriter [Float ] {
128177 override def write (
@@ -136,7 +185,6 @@ private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWri
136185 0
137186 }
138187}
139- private case object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
140188
141189private class DoubleUnsafeColumnWriter private () extends PrimitiveUnsafeColumnWriter [Double ] {
142190 override def write (
@@ -150,55 +198,29 @@ private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWr
150198 0
151199 }
152200}
153- private case object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
154201
155- class UnsafeRowConverter (fieldTypes : Array [DataType ]) {
156-
157- def this (schema : StructType ) {
158- this (schema.fields.map(_.dataType))
159- }
160-
161- private [this ] val unsafeRow = new UnsafeRow ()
162-
163- private [this ] val writers : Array [UnsafeColumnWriter [Any ]] = {
164- fieldTypes.map(t => UnsafeColumnWriter .forType(t).asInstanceOf [UnsafeColumnWriter [Any ]])
165- }
166-
167- private [this ] val fixedLengthSize : Int =
168- (8 * fieldTypes.length) + UnsafeRow .calculateBitSetWidthInBytes(fieldTypes.length)
169-
170- def getSizeRequirement (row : Row ): Int = {
171- var fieldNumber = 0
172- var variableLengthFieldSize : Int = 0
173- while (fieldNumber < writers.length) {
174- if (! row.isNullAt(fieldNumber)) {
175- variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
176- }
177- fieldNumber += 1
178- }
179- fixedLengthSize + variableLengthFieldSize
/**
 * Column writer for [[UTF8String]] values.
 *
 * A string is stored in the variable-length region as an 8-byte length followed by the string's
 * bytes; the column's fixed-length slot receives the offset at which that data starts.
 */
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
  /**
   * @param value the string to measure
   * @return 8 bytes for the length prefix plus the byte length of `value` as rounded by
   *         `ByteArrayMethods.roundNumberOfBytesToNearestWord`
   */
  def getSize(value: UTF8String): Int = {
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
  }

  /**
   * Write a string into the variable-length region of an UnsafeRow.
   *
   * @param value the string to write
   * @param columnNumber what column to write it to
   * @param row a pointer to the unsafe row
   * @param baseObject the base object of the target row's address
   * @param baseOffset the base offset of the target row's address
   * @param appendCursor the offset from the start of the unsafe row at which the string's
   *                     length prefix and bytes are written
   * @return the number of variable-length bytes written (same value as `getSize(value)`)
   */
  override def write(
      value: UTF8String,
      columnNumber: Int,
      row: UnsafeRow,
      baseObject: Object,
      baseOffset: Long,
      appendCursor: Int): Int = {
    // Fetch the backing bytes once; they are needed both for the length and as the copy source.
    val bytes = value.getBytes
    val numBytes = bytes.length
    // Length prefix goes first, at the append position.
    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
    // The string's bytes follow immediately after the 8-byte length prefix.
    PlatformDependent.copyMemory(
      bytes,
      PlatformDependent.BYTE_ARRAY_OFFSET,
      baseObject,
      baseOffset + appendCursor + 8,
      numBytes
    )
    // The fixed-length slot of this column records where its variable-length data begins.
    row.setLong(columnNumber, appendCursor)
    8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
  }
}
0 commit comments