Commit 5582f92

beliefer authored and cloud-fan committed
[SPARK-37463][SQL] Read/Write Timestamp ntz from/to Orc uses int64
### What changes were proposed in this pull request?

#33588 (comment) shows that Spark cannot read/write timestamp ntz and ltz correctly. Based on the discussion in #34741 (comment), this change makes reads and writes of timestamp ntz from/to ORC use int64.

### Why are the changes needed?

Fix the bug in reading/writing timestamp ntz from/to ORC under different time zones.

### Does this PR introduce _any_ user-facing change?

Yes. ORC timestamp ntz is a new feature.

### How was this patch tested?

New tests.

Closes #34984 from beliefer/SPARK-37463-int64.

Authored-by: Jiaan Geng <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit e410d98)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 2238b05 commit 5582f92
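
Background sketch (not part of the commit; names are illustrative): an int64 microsecond count works for TIMESTAMP_NTZ because the value is a pure wall-clock offset from 1970-01-01 00:00:00, so no time-zone conversion is involved in either direction. A minimal, self-contained Scala illustration:

import java.time.LocalDateTime
import java.time.temporal.ChronoUnit

object NtzMicrosSketch {
  // Wall-clock epoch; no zone is attached anywhere in this conversion.
  private val Epoch = LocalDateTime.of(1970, 1, 1, 0, 0, 0)

  def toMicros(ldt: LocalDateTime): Long = ChronoUnit.MICROS.between(Epoch, ldt)

  def fromMicros(micros: Long): LocalDateTime = Epoch.plus(micros, ChronoUnit.MICROS)

  def main(args: Array[String]): Unit = {
    val ts = LocalDateTime.parse("2021-03-14T02:15:00")
    val micros = toMicros(ts)
    // Round-trips to the same wall-clock value no matter what the JVM default zone is.
    assert(fromMicros(micros) == ts)
    println(s"$ts -> $micros")
  }
}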

File tree

10 files changed (+59, -74 lines)


sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java

Lines changed: 0 additions & 10 deletions
@@ -27,7 +27,6 @@
 import org.apache.spark.sql.types.DateType;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.TimestampType;
-import org.apache.spark.sql.types.TimestampNTZType;
 import org.apache.spark.sql.vectorized.ColumnarArray;
 import org.apache.spark.sql.vectorized.ColumnarMap;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -37,7 +36,6 @@
  */
 public class OrcAtomicColumnVector extends OrcColumnVector {
   private final boolean isTimestamp;
-  private final boolean isTimestampNTZ;
   private final boolean isDate;
 
   // Column vector for each type. Only 1 is populated for any type.
@@ -56,12 +54,6 @@ public class OrcAtomicColumnVector extends OrcColumnVector {
       isTimestamp = false;
     }
 
-    if (type instanceof TimestampNTZType) {
-      isTimestampNTZ = true;
-    } else {
-      isTimestampNTZ = false;
-    }
-
     if (type instanceof DateType) {
       isDate = true;
     } else {
@@ -113,8 +105,6 @@ public long getLong(int rowId) {
     int index = getRowIndex(rowId);
     if (isTimestamp) {
       return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index));
-    } else if (isTimestampNTZ) {
-      return OrcUtils.fromOrcNTZ(timestampData.asScratchTimestamp(index));
     } else {
       return longData.vector[index];
     }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala

Lines changed: 1 addition & 4 deletions
@@ -105,7 +105,7 @@ class OrcDeserializer(
     case IntegerType | _: YearMonthIntervalType => (ordinal, value) =>
       updater.setInt(ordinal, value.asInstanceOf[IntWritable].get)
 
-    case LongType | _: DayTimeIntervalType => (ordinal, value) =>
+    case LongType | _: DayTimeIntervalType | _: TimestampNTZType => (ordinal, value) =>
       updater.setLong(ordinal, value.asInstanceOf[LongWritable].get)
 
     case FloatType => (ordinal, value) =>
@@ -129,9 +129,6 @@
     case TimestampType => (ordinal, value) =>
       updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp]))
 
-    case TimestampNTZType => (ordinal, value) =>
-      updater.setLong(ordinal, OrcUtils.fromOrcNTZ(value.asInstanceOf[OrcTimestamp]))
-
     case DecimalType.Fixed(precision, scale) => (ordinal, value) =>
       val v = OrcShimUtils.getDecimal(value)
       v.changePrecision(precision, scale)
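
Together with the matching OrcSerializer change below, TIMESTAMP_NTZ values now travel as a plain LongWritable holding the catalyst microsecond count, with no java.sql.Timestamp hop in between. A hedged sketch of that round trip (made-up value; not the Spark code path itself):

import org.apache.hadoop.io.LongWritable

object NtzLongWritableSketch {
  def main(args: Array[String]): Unit = {
    val catalystMicros = 1622505600000000L          // 2021-06-01 00:00:00 as microseconds
    val written = new LongWritable(catalystMicros)  // what the serializer now emits for NTZ
    val readBack = written.get                      // what the deserializer now stores back
    assert(readBack == catalystMicros)              // no zone-dependent step can intervene
  }
}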

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala

Lines changed: 4 additions & 5 deletions
@@ -142,11 +142,10 @@ class OrcFileFormat
 
       val fs = filePath.getFileSystem(conf)
       val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
-      val resultedColPruneInfo =
-        Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader =>
-          OrcUtils.requestedColumnIds(
-            isCaseSensitive, dataSchema, requiredSchema, reader, conf)
-        }
+      val orcSchema =
+        Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions))(_.getSchema)
+      val resultedColPruneInfo = OrcUtils.requestedColumnIds(
+        isCaseSensitive, dataSchema, requiredSchema, orcSchema, conf)
 
       if (resultedColPruneInfo.isEmpty) {
         Iterator.empty

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala

Lines changed: 3 additions & 8 deletions
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.datasources.orc
 
-import java.sql.Timestamp
 import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period}
 
 import org.apache.hadoop.hive.common.`type`.HiveDecimal
@@ -143,11 +142,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
   def getPredicateLeafType(dataType: DataType): PredicateLeaf.Type = dataType match {
     case BooleanType => PredicateLeaf.Type.BOOLEAN
     case ByteType | ShortType | IntegerType | LongType |
-         _: AnsiIntervalType => PredicateLeaf.Type.LONG
+         _: AnsiIntervalType | TimestampNTZType => PredicateLeaf.Type.LONG
     case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
     case StringType => PredicateLeaf.Type.STRING
     case DateType => PredicateLeaf.Type.DATE
-    case TimestampType | TimestampNTZType => PredicateLeaf.Type.TIMESTAMP
+    case TimestampType => PredicateLeaf.Type.TIMESTAMP
     case _: DecimalType => PredicateLeaf.Type.DECIMAL
     case _ => throw QueryExecutionErrors.unsupportedOperationForDataTypeError(dataType)
   }
@@ -170,11 +169,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
     case _: TimestampType if value.isInstanceOf[Instant] =>
       toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant]))
     case _: TimestampNTZType if value.isInstanceOf[LocalDateTime] =>
-      val orcTimestamp = OrcUtils.toOrcNTZ(localDateTimeToMicros(value.asInstanceOf[LocalDateTime]))
-      // Hive meets OrcTimestamp will throw ClassNotFoundException, So convert it.
-      val timestamp = new Timestamp(orcTimestamp.getTime)
-      timestamp.setNanos(orcTimestamp.getNanos)
-      timestamp
+      localDateTimeToMicros(value.asInstanceOf[LocalDateTime])
     case _: YearMonthIntervalType =>
       IntervalUtils.periodToMonths(value.asInstanceOf[Period]).longValue()
     case _: DayTimeIntervalType =>
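
Because TimestampNTZType now maps to PredicateLeaf.Type.LONG and the LocalDateTime literal is converted to the same microsecond long that is stored in the file, a filter on an NTZ column can be pushed down directly. A usage sketch (hypothetical path and column name, assuming a SparkSession `spark` on a build with this patch; the table layout matches the new test further down):

// The timestamp_ntz literal becomes a LONG predicate over the stored microsecond values.
spark.sql(
  """select * from orc.`/tmp/ts_ntz_orc`
    |where ts_ntz1 > timestamp_ntz '2021-01-01 00:00:00'""".stripMargin).show(false)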

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala

Lines changed: 1 addition & 3 deletions
@@ -98,7 +98,7 @@ class OrcSerializer(dataSchema: StructType) {
       }
 
 
-    case LongType | _: DayTimeIntervalType =>
+    case LongType | _: DayTimeIntervalType | _: TimestampNTZType =>
       if (reuseObj) {
         val result = new LongWritable()
         (getter, ordinal) =>
@@ -147,8 +147,6 @@ class OrcSerializer(dataSchema: StructType) {
         result.setNanos(ts.getNanos)
         result
 
-    case TimestampNTZType => (getter, ordinal) => OrcUtils.toOrcNTZ(getter.getLong(ordinal))
-
     case DecimalType.Fixed(precision, scale) =>
       OrcShimUtils.getHiveDecimalWritable(precision, scale)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala

Lines changed: 13 additions & 33 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.datasources.orc
 
 import java.nio.charset.StandardCharsets.UTF_8
-import java.sql.Timestamp
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -29,7 +28,6 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.hive.serde2.io.DateWritable
 import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable}
 import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer}
-import org.apache.orc.mapred.OrcTimestamp
 
 import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -39,8 +37,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
 import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils, DateTimeUtils}
-import org.apache.spark.sql.catalyst.util.DateTimeConstants._
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, SchemaMergeUtils}
@@ -199,7 +197,7 @@
       isCaseSensitive: Boolean,
       dataSchema: StructType,
       requiredSchema: StructType,
-      reader: Reader,
+      orcSchema: TypeDescription,
       conf: Configuration): Option[(Array[Int], Boolean)] = {
     def checkTimestampCompatibility(orcCatalystSchema: StructType, dataSchema: StructType): Unit = {
       orcCatalystSchema.fields.map(_.dataType).zip(dataSchema.fields.map(_.dataType)).foreach {
@@ -212,7 +210,6 @@
       }
     }
 
-    val orcSchema = reader.getSchema
     checkTimestampCompatibility(toCatalystSchema(orcSchema), dataSchema)
     val orcFieldNames = orcSchema.getFieldNames.asScala
     val forcePositionalEvolution = OrcConf.FORCE_POSITIONAL_EVOLUTION.getBoolean(conf)
@@ -261,7 +258,6 @@
         if (matchedOrcFields.size > 1) {
           // Need to fail if there is ambiguity, i.e. more than one field is matched.
           val matchedOrcFieldsString = matchedOrcFields.mkString("[", ", ", "]")
-          reader.close()
           throw QueryExecutionErrors.foundDuplicateFieldInCaseInsensitiveModeError(
             requiredFieldName, matchedOrcFieldsString)
         } else {
@@ -285,18 +281,17 @@
   * Given a `StructType` object, this methods converts it to corresponding string representation
   * in ORC.
   */
-  def orcTypeDescriptionString(dt: DataType): String = dt match {
+  def getOrcSchemaString(dt: DataType): String = dt match {
     case s: StructType =>
       val fieldTypes = s.fields.map { f =>
-        s"${quoteIdentifier(f.name)}:${orcTypeDescriptionString(f.dataType)}"
+        s"${quoteIdentifier(f.name)}:${getOrcSchemaString(f.dataType)}"
       }
       s"struct<${fieldTypes.mkString(",")}>"
     case a: ArrayType =>
-      s"array<${orcTypeDescriptionString(a.elementType)}>"
+      s"array<${getOrcSchemaString(a.elementType)}>"
     case m: MapType =>
-      s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
-    case TimestampNTZType => TypeDescription.Category.TIMESTAMP.getName
-    case _: DayTimeIntervalType => LongType.catalogString
+      s"map<${getOrcSchemaString(m.keyType)},${getOrcSchemaString(m.valueType)}>"
+    case _: DayTimeIntervalType | _: TimestampNTZType => LongType.catalogString
    case _: YearMonthIntervalType => IntegerType.catalogString
    case _ => dt.catalogString
  }
@@ -306,16 +301,14 @@
     dt match {
       case y: YearMonthIntervalType =>
         val typeDesc = new TypeDescription(TypeDescription.Category.INT)
-        typeDesc.setAttribute(
-          CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
+        typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
         Some(typeDesc)
       case d: DayTimeIntervalType =>
         val typeDesc = new TypeDescription(TypeDescription.Category.LONG)
-        typeDesc.setAttribute(
-          CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
+        typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
         Some(typeDesc)
       case n: TimestampNTZType =>
-        val typeDesc = new TypeDescription(TypeDescription.Category.TIMESTAMP)
+        val typeDesc = new TypeDescription(TypeDescription.Category.LONG)
         typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, n.typeName)
         Some(typeDesc)
       case t: TimestampType =>
@@ -378,9 +371,9 @@
       partitionSchema: StructType,
       conf: Configuration): String = {
     val resultSchemaString = if (canPruneCols) {
-      OrcUtils.orcTypeDescriptionString(resultSchema)
+      OrcUtils.getOrcSchemaString(resultSchema)
     } else {
-      OrcUtils.orcTypeDescriptionString(StructType(dataSchema.fields ++ partitionSchema.fields))
+      OrcUtils.getOrcSchemaString(StructType(dataSchema.fields ++ partitionSchema.fields))
     }
     OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     resultSchemaString
@@ -532,17 +525,4 @@
       resultRow
     }
   }
-
-  def fromOrcNTZ(ts: Timestamp): Long = {
-    DateTimeUtils.millisToMicros(ts.getTime) +
-      (ts.getNanos / NANOS_PER_MICROS) % MICROS_PER_MILLIS
-  }
-
-  def toOrcNTZ(micros: Long): OrcTimestamp = {
-    val seconds = Math.floorDiv(micros, MICROS_PER_SECOND)
-    val nanos = (micros - seconds * MICROS_PER_SECOND) * NANOS_PER_MICROS
-    val result = new OrcTimestamp(seconds * MILLIS_PER_SECOND)
-    result.setNanos(nanos.toInt)
-    result
-  }
 }
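
For context, a standalone restatement of the two deleted helpers (constants mirror DateTimeConstants; illustration only, not Spark code). It shows what the old OrcTimestamp-based encoding computed and why the raw int64 micros make the extra Timestamp hop unnecessary:

import java.sql.Timestamp

object RemovedNtzHelpersSketch {
  private val MicrosPerMillis = 1000L
  private val MicrosPerSecond = 1000000L
  private val MillisPerSecond = 1000L
  private val NanosPerMicros  = 1000L

  // What toOrcNTZ did: split the micros into whole seconds (kept as millis) plus nanos.
  def toTimestamp(micros: Long): Timestamp = {
    val seconds = Math.floorDiv(micros, MicrosPerSecond)
    val nanos = (micros - seconds * MicrosPerSecond) * NanosPerMicros
    val result = new Timestamp(seconds * MillisPerSecond)
    result.setNanos(nanos.toInt)
    result
  }

  // What fromOrcNTZ did: recombine the millis with the sub-millisecond micros.
  def fromTimestamp(ts: Timestamp): Long =
    ts.getTime * MicrosPerMillis + (ts.getNanos / NanosPerMicros) % MicrosPerMillis

  def main(args: Array[String]): Unit = {
    val micros = 1622505600123456L
    // The arithmetic round-trips, but the Timestamp object still had to pass through
    // ORC's zone-sensitive TIMESTAMP handling, which is where the cross-time-zone
    // mismatch described in this PR came from; the int64 encoding skips it entirely.
    assert(fromTimestamp(toTimestamp(micros)) == micros)
  }
}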

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala

Lines changed: 4 additions & 0 deletions
@@ -358,6 +358,8 @@ private[parquet] class ParquetRowConverter(
     case StringType =>
       new ParquetStringConverter(updater)
 
+    // As long as the parquet type is INT64 timestamp, whether logical annotation
+    // `isAdjustedToUTC` is false or true, it will be read as Spark's TimestampLTZ type
     case TimestampType
       if parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
         parquetType.getLogicalTypeAnnotation
@@ -368,6 +370,8 @@
         }
       }
 
+    // As long as the parquet type is INT64 timestamp, whether logical annotation
+    // `isAdjustedToUTC` is false or true, it will be read as Spark's TimestampLTZ type
     case TimestampType
       if parquetType.getLogicalTypeAnnotation.isInstanceOf[TimestampLogicalTypeAnnotation] &&
         parquetType.getLogicalTypeAnnotation
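
The two comments added above summarize the Parquet counterpart: for the INT64 timestamp physical type, the `isAdjustedToUTC` flag distinguishes instant (LTZ) from wall-clock (NTZ) semantics, but when Spark's requested type is TimestampType the value is read as LTZ either way. A hedged sketch of inspecting that annotation (assumes parquet-column on the classpath; `describe` is an illustrative helper, not a Spark API):

import org.apache.parquet.schema.LogicalTypeAnnotation
import org.apache.parquet.schema.LogicalTypeAnnotation.{TimestampLogicalTypeAnnotation, TimeUnit}

object TimestampAnnotationSketch {
  // `annotation` would come from parquetType.getLogicalTypeAnnotation in a converter.
  def describe(annotation: LogicalTypeAnnotation): String = annotation match {
    case ts: TimestampLogicalTypeAnnotation if ts.getUnit == TimeUnit.MICROS =>
      if (ts.isAdjustedToUTC) "INT64 micros, instant semantics (adjusted to UTC)"
      else "INT64 micros, wall-clock semantics (not adjusted to UTC)"
    case _ => "not an INT64 microsecond timestamp"
  }

  def main(args: Array[String]): Unit = {
    println(describe(LogicalTypeAnnotation.timestampType(true, TimeUnit.MICROS)))
    println(describe(LogicalTypeAnnotation.timestampType(false, TimeUnit.MICROS)))
  }
}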

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala

Lines changed: 6 additions & 10 deletions
@@ -88,11 +88,9 @@ case class OrcPartitionReaderFactory(
     }
     val filePath = new Path(new URI(file.filePath))
 
-    val resultedColPruneInfo =
-      Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
-        OrcUtils.requestedColumnIds(
-          isCaseSensitive, dataSchema, readDataSchema, reader, conf)
-      }
+    val orcSchema = Utils.tryWithResource(createORCReader(filePath, conf))(_.getSchema)
+    val resultedColPruneInfo = OrcUtils.requestedColumnIds(
+      isCaseSensitive, dataSchema, readDataSchema, orcSchema, conf)
 
     if (resultedColPruneInfo.isEmpty) {
       new EmptyPartitionReader[InternalRow]
@@ -131,11 +129,9 @@
     }
     val filePath = new Path(new URI(file.filePath))
 
-    val resultedColPruneInfo =
-      Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
-        OrcUtils.requestedColumnIds(
-          isCaseSensitive, dataSchema, readDataSchema, reader, conf)
-      }
+    val orcSchema = Utils.tryWithResource(createORCReader(filePath, conf))(_.getSchema)
+    val resultedColPruneInfo = OrcUtils.requestedColumnIds(
+      isCaseSensitive, dataSchema, readDataSchema, orcSchema, conf)
 
     if (resultedColPruneInfo.isEmpty) {
       new EmptyPartitionReader

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ case class OrcWrite(
 
     val conf = job.getConfiguration
 
-    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.orcTypeDescriptionString(dataSchema))
+    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.getOrcSchemaString(dataSchema))
 
     conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala

Lines changed: 26 additions & 0 deletions
@@ -35,6 +35,7 @@ import org.apache.orc.mapreduce.OrcInputFormat
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
@@ -803,6 +804,31 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-37463: read/write Timestamp ntz to Orc with different time zone") {
+    DateTimeTestUtils.withDefaultTimeZone(DateTimeTestUtils.LA) {
+      val sqlText = """
+                      |select
+                      | timestamp_ntz '2021-06-01 00:00:00' ts_ntz1,
+                      | timestamp_ntz '1883-11-16 00:00:00.0' as ts_ntz2,
+                      | timestamp_ntz '2021-03-14 02:15:00.0' as ts_ntz3
+                      |""".stripMargin
+
+      val df = sql(sqlText)
+
+      df.write.mode("overwrite").orc("ts_ntz_orc")
+
+      val query = "select * from `orc`.`ts_ntz_orc`"
+
+      DateTimeTestUtils.outstandingZoneIds.foreach { zoneId =>
+        DateTimeTestUtils.withDefaultTimeZone(zoneId) {
+          withAllNativeOrcReaders {
+            checkAnswer(sql(query), df)
+          }
+        }
+      }
+    }
+  }
 }
 
 class OrcV1QuerySuite extends OrcQuerySuite {
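
A minimal spark-shell sketch of the behaviour the new test locks down (illustrative path; assumes a local build that includes this patch): because the NTZ value is stored as zone-free int64 micros, reading it back under a different JVM default time zone yields the same wall-clock value.

import java.util.TimeZone

// Write the NTZ value with one JVM default time zone...
TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
spark.sql("select timestamp_ntz '2021-03-14 02:15:00' as ts")
  .write.mode("overwrite").orc("/tmp/ts_ntz_orc")

// ...and read it back under another: it still shows 2021-03-14 02:15:00.
TimeZone.setDefault(TimeZone.getTimeZone("Asia/Tokyo"))
spark.read.orc("/tmp/ts_ntz_orc").show(false)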
