Skip to content

Commit 51fb84a

Browse files
nastra authored and cloud-fan committed
[SPARK-50624][SQL] Add TimestampNTZType to ColumnarRow/MutableColumnarRow
### What changes were proposed in this pull request?

Noticed that this was missing when using this in Iceberg. See additional details in apache/iceberg#11815 (comment).

### Why are the changes needed?

To be able to read `TimestampNTZType` when using `ColumnarRow`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added some unit tests that failed without the fix.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49437 from nastra/SPARK-50624.

Authored-by: Eduard Tudenhoefner <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit d7545d0)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 00b3833 commit 51fb84a

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ public Object get(int ordinal, DataType dataType) {
183183
return getInt(ordinal);
184184
} else if (dataType instanceof TimestampType) {
185185
return getLong(ordinal);
186+
} else if (dataType instanceof TimestampNTZType) {
187+
return getLong(ordinal);
186188
} else if (dataType instanceof ArrayType) {
187189
return getArray(ordinal);
188190
} else if (dataType instanceof StructType) {

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ public InternalRow copy() {
8080
row.setInt(i, getInt(i));
8181
} else if (dt instanceof TimestampType) {
8282
row.setLong(i, getLong(i));
83+
} else if (dt instanceof TimestampNTZType) {
84+
row.setLong(i, getLong(i));
8385
} else if (dt instanceof StructType) {
8486
row.update(i, getStruct(i, ((StructType) dt).fields().length).copy());
8587
} else if (dt instanceof ArrayType) {
@@ -185,6 +187,8 @@ public Object get(int ordinal, DataType dataType) {
185187
return getInt(ordinal);
186188
} else if (dataType instanceof TimestampType) {
187189
return getLong(ordinal);
190+
} else if (dataType instanceof TimestampNTZType) {
191+
return getLong(ordinal);
188192
} else if (dataType instanceof ArrayType) {
189193
return getArray(ordinal);
190194
} else if (dataType instanceof StructType) {

sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,19 @@ class ColumnVectorSuite extends SparkFunSuite {
271271
}
272272
}
273273

274+
// Round-trip check for TimestampNTZType through MutableColumnarRow: writes `10 - i` into
// ordinal 0 of rows 0..9 with setLong, then reads each row back with
// get(0, TimestampNTZType), both directly and through copy().
testVectors("mutable ColumnarRow with TimestampNTZType", 10, TimestampNTZType) { testVector =>
275+
val mutableRow = new MutableColumnarRow(Array(testVector))
276+
(0 until 10).foreach { i =>
277+
mutableRow.rowId = i
278+
mutableRow.setLong(0, 10 - i)
279+
}
280+
(0 until 10).foreach { i =>
281+
mutableRow.rowId = i
282+
assert(mutableRow.get(0, TimestampNTZType) === (10 - i))
283+
assert(mutableRow.copy().get(0, TimestampNTZType) === (10 - i))
284+
}
285+
}
286+
274287
val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true)
275288
testVectors("array", 10, arrayType) { testVector =>
276289

@@ -381,18 +394,24 @@ class ColumnVectorSuite extends SparkFunSuite {
381394
}
382395

383396
// Schema for the "struct" test: an int field, a double field and a TimestampNTZ field.
// The test fills rows 0 and 1 of the three child vectors (putInt / putDouble / putLong)
// and checks that getStruct(row).get(ordinal, type) returns the stored values, including
// the TimestampNTZType column at ordinal 2.
val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType)
397+
.add("ts", TimestampNTZType)
384398
testVectors("struct", 10, structType) { testVector =>
385399
val c1 = testVector.getChild(0)
386400
val c2 = testVector.getChild(1)
401+
val c3 = testVector.getChild(2)
387402
c1.putInt(0, 123)
388403
c2.putDouble(0, 3.45)
404+
c3.putLong(0, 1000L)
389405
c1.putInt(1, 456)
390406
c2.putDouble(1, 5.67)
407+
c3.putLong(1, 2000L)
391408

392409
assert(testVector.getStruct(0).get(0, IntegerType) === 123)
393410
assert(testVector.getStruct(0).get(1, DoubleType) === 3.45)
411+
assert(testVector.getStruct(0).get(2, TimestampNTZType) === 1000L)
394412
assert(testVector.getStruct(1).get(0, IntegerType) === 456)
395413
assert(testVector.getStruct(1).get(1, DoubleType) === 5.67)
414+
assert(testVector.getStruct(1).get(2, TimestampNTZType) === 2000L)
396415
}
397416

398417
testVectors("SPARK-44805: getInts with dictionary", 3, IntegerType) { testVector =>

sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,4 +515,28 @@ class ArrowColumnVectorSuite extends SparkFunSuite {
515515
columnVector.close()
516516
allocator.close()
517517
}
518+
519+
// Builds an Arrow StructVector with a single "ts" TimestampNTZType child, stores 1000L at
// index 0, wraps it in an ArrowColumnVector and checks that the reported dataType equals
// the schema and that getStruct(0).get(0, TimestampNTZType) returns 1000L.
// The column vector and the child allocator are closed at the end of the test.
test("struct with TimestampNTZType") {
520+
val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue)
521+
val schema = new StructType().add("ts", TimestampNTZType)
522+
val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null)
523+
.createVector(allocator).asInstanceOf[StructVector]
524+
vector.allocateNew()
525+
val timestampVector = vector.getChildByOrdinal(0).asInstanceOf[TimeStampMicroVector]
526+
527+
vector.setIndexDefined(0)
528+
timestampVector.setSafe(0, 1000L)
529+
530+
timestampVector.setValueCount(1)
531+
vector.setValueCount(1)
532+
533+
val columnVector = new ArrowColumnVector(vector)
534+
assert(columnVector.dataType === schema)
535+
536+
val row0 = columnVector.getStruct(0)
537+
assert(row0.get(0, TimestampNTZType) === 1000L)
538+
539+
columnVector.close()
540+
allocator.close()
541+
}
518542
}

0 commit comments

Comments
 (0)