diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 9696c3a0b6e1..fb6f379db640 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1889,4 +1889,8 @@ object QueryExecutionErrors {
   def hiveTableWithAnsiIntervalsError(tableName: String): Throwable = {
     new UnsupportedOperationException(s"Hive table $tableName with ANSI intervals is not supported")
   }
+
+  def cannotConvertOrcTimestampToTimestampNTZError(): Throwable = {
+    new RuntimeException("Unable to convert timestamp of Orc to data type 'timestamp_ntz'")
+  }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
index c2d8334d928c..b4f7b9924715 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
@@ -27,6 +27,7 @@
 import org.apache.spark.sql.types.DateType;
 import org.apache.spark.sql.types.Decimal;
 import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.types.TimestampNTZType;
 import org.apache.spark.sql.vectorized.ColumnarArray;
 import org.apache.spark.sql.vectorized.ColumnarMap;
 import org.apache.spark.unsafe.types.UTF8String;
@@ -36,6 +37,7 @@
  */
 public class OrcAtomicColumnVector extends OrcColumnVector {
   private final boolean isTimestamp;
+  private final boolean isTimestampNTZ;
   private final boolean isDate;

   // Column vector for each type. Only 1 is populated for any type.
@@ -54,6 +56,12 @@ public class OrcAtomicColumnVector extends OrcColumnVector {
       isTimestamp = false;
     }

+    if (type instanceof TimestampNTZType) {
+      isTimestampNTZ = true;
+    } else {
+      isTimestampNTZ = false;
+    }
+
     if (type instanceof DateType) {
       isDate = true;
     } else {
@@ -105,6 +113,8 @@ public long getLong(int rowId) {
     int index = getRowIndex(rowId);
     if (isTimestamp) {
       return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index));
+    } else if (isTimestampNTZ) {
+      return OrcUtils.fromOrcNTZ(timestampData.asScratchTimestamp(index));
     } else {
       return longData.vector[index];
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index 91408332b862..7ab556e330e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -129,6 +129,9 @@ class OrcDeserializer(
     case TimestampType => (ordinal, value) =>
       updater.setLong(ordinal, DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[OrcTimestamp]))

+    case TimestampNTZType => (ordinal, value) =>
+      updater.setLong(ordinal, OrcUtils.fromOrcNTZ(value.asInstanceOf[OrcTimestamp]))
+
     case DecimalType.Fixed(precision, scale) => (ordinal, value) =>
       val v = OrcShimUtils.getDecimal(value)
       v.changePrecision(precision, scale)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 26af2c39b408..ce851c58cc4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -42,24 +42,6 @@ import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{SerializableConfiguration, Utils}

-private[sql] object OrcFileFormat {
-
-  def getQuotedSchemaString(dataType: DataType): String = dataType match {
-    case _: DayTimeIntervalType => LongType.catalogString
-    case _: YearMonthIntervalType => IntegerType.catalogString
-    case _: AtomicType => dataType.catalogString
-    case StructType(fields) =>
-      fields.map(f => s"`${f.name}`:${getQuotedSchemaString(f.dataType)}")
-        .mkString("struct<", ",", ">")
-    case ArrayType(elementType, _) =>
-      s"array<${getQuotedSchemaString(elementType)}>"
-    case MapType(keyType, valueType, _) =>
-      s"map<${getQuotedSchemaString(keyType)},${getQuotedSchemaString(valueType)}>"
-    case _ => // UDT and others
-      dataType.catalogString
-  }
-}
-
 /**
  * New ORC File Format based on Apache ORC.
 */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
index 9a1eb8a553c3..edd505273963 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcSerializer.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.datasources.orc

 import org.apache.hadoop.io._
-import org.apache.orc.TypeDescription
 import org.apache.orc.mapred.{OrcList, OrcMap, OrcStruct, OrcTimestamp}

 import org.apache.spark.sql.catalyst.InternalRow
@@ -148,6 +147,8 @@ class OrcSerializer(dataSchema: StructType) {
       result.setNanos(ts.getNanos)
       result

+    case TimestampNTZType => (getter, ordinal) => OrcUtils.toOrcNTZ(getter.getLong(ordinal))
+
     case DecimalType.Fixed(precision, scale) =>
       OrcShimUtils.getHiveDecimalWritable(precision, scale)
@@ -214,6 +215,6 @@ class OrcSerializer(dataSchema: StructType) {
    * Return a Orc value object for the given Spark schema.
    */
   private def createOrcValue(dataType: DataType) = {
-    OrcStruct.createValue(TypeDescription.fromString(OrcFileFormat.getQuotedSchemaString(dataType)))
+    OrcStruct.createValue(OrcUtils.orcTypeDescription(dataType))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index b2624150a915..ec161e9e55dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.orc

 import java.nio.charset.StandardCharsets.UTF_8
+import java.sql.Timestamp
 import java.util.Locale

 import scala.collection.JavaConverters._
@@ -28,6 +29,7 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.hive.serde2.io.DateWritable
 import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable}
 import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer}
+import org.apache.orc.mapred.OrcTimestamp

 import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -36,7 +38,8 @@ import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
+import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
@@ -196,7 +199,18 @@ object OrcUtils extends Logging {
       requiredSchema: StructType,
       reader: Reader,
       conf: Configuration): Option[(Array[Int], Boolean)] = {
-    val orcFieldNames = reader.getSchema.getFieldNames.asScala
+    def checkTimestampCompatibility(orcCatalystSchema: StructType, dataSchema: StructType): Unit = {
+      orcCatalystSchema.fields.map(_.dataType).zip(dataSchema.fields.map(_.dataType)).foreach {
+        case (TimestampType, TimestampNTZType) =>
+          throw QueryExecutionErrors.cannotConvertOrcTimestampToTimestampNTZError()
+        case (t1: StructType, t2: StructType) => checkTimestampCompatibility(t1, t2)
+        case _ =>
+      }
+    }
+
+    val orcSchema = reader.getSchema
+    checkTimestampCompatibility(toCatalystSchema(orcSchema), dataSchema)
+    val orcFieldNames = orcSchema.getFieldNames.asScala
     val forcePositionalEvolution = OrcConf.FORCE_POSITIONAL_EVOLUTION.getBoolean(conf)
     if (orcFieldNames.isEmpty) {
       // SPARK-8501: Some old empty ORC files always have an empty schema stored in their footer.
@@ -277,6 +291,7 @@ object OrcUtils extends Logging {
         s"array<${orcTypeDescriptionString(a.elementType)}>"
       case m: MapType =>
         s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
+      case TimestampNTZType => TypeDescription.Category.TIMESTAMP.getName
       case _: DayTimeIntervalType => LongType.catalogString
       case _: YearMonthIntervalType => IntegerType.catalogString
       case _ => dt.catalogString
@@ -286,15 +301,23 @@ object OrcUtils extends Logging {
     def getInnerTypeDecription(dt: DataType): Option[TypeDescription] = {
       dt match {
         case y: YearMonthIntervalType =>
-          val typeDesc = orcTypeDescription(IntegerType)
+          val typeDesc = new TypeDescription(TypeDescription.Category.INT)
           typeDesc.setAttribute(
             CATALYST_TYPE_ATTRIBUTE_NAME, y.typeName)
           Some(typeDesc)
         case d: DayTimeIntervalType =>
-          val typeDesc = orcTypeDescription(LongType)
+          val typeDesc = new TypeDescription(TypeDescription.Category.LONG)
           typeDesc.setAttribute(
             CATALYST_TYPE_ATTRIBUTE_NAME, d.typeName)
           Some(typeDesc)
+        case n: TimestampNTZType =>
+          val typeDesc = new TypeDescription(TypeDescription.Category.TIMESTAMP)
+          typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, n.typeName)
+          Some(typeDesc)
+        case t: TimestampType =>
+          val typeDesc = new TypeDescription(TypeDescription.Category.TIMESTAMP)
+          typeDesc.setAttribute(CATALYST_TYPE_ATTRIBUTE_NAME, t.typeName)
+          Some(typeDesc)
         case _ => None
       }
     }
@@ -493,4 +516,17 @@ object OrcUtils extends Logging {
     val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray)
     orcValuesDeserializer.deserializeFromValues(aggORCValues)
   }
+
+  def fromOrcNTZ(ts: Timestamp): Long = {
+    DateTimeUtils.millisToMicros(ts.getTime) +
+      (ts.getNanos / NANOS_PER_MICROS) % MICROS_PER_MILLIS
+  }
+
+  def toOrcNTZ(micros: Long): OrcTimestamp = {
+    val seconds = Math.floorDiv(micros, MICROS_PER_SECOND)
+    val nanos = (micros - seconds * MICROS_PER_SECOND) * NANOS_PER_MICROS
+    val result = new OrcTimestamp(seconds * MILLIS_PER_SECOND)
+    result.setNanos(nanos.toInt)
+    result
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala
index 286e87108053..1ac9266e8d5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala
@@ -23,7 +23,7 @@ import org.apache.orc.mapred.OrcStruct

 import org.apache.spark.sql.connector.write.LogicalWriteInfo
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
-import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils}
+import org.apache.spark.sql.execution.datasources.orc.{OrcOptions, OrcOutputWriter, OrcUtils}
 import org.apache.spark.sql.execution.datasources.v2.FileWrite
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -43,7 +43,7 @@ case class OrcWrite(

     val conf = job.getConfiguration

-    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcFileFormat.getQuotedSchemaString(dataSchema))
+    conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, OrcUtils.orcTypeDescriptionString(dataSchema))

     conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 3f2f12d9d719..518090877e63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -697,7 +697,7 @@ class FileBasedDataSourceSuite extends QueryTest
   test("SPARK-22790,SPARK-27668: spark.sql.sources.compressionFactor takes effect") {
     Seq(1.0, 0.5).foreach { compressionFactor =>
       withSQLConf(SQLConf.FILE_COMPRESSION_FACTOR.key -> compressionFactor.toString,
-        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "350") {
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "457") {
         withTempPath { workDir =>
           // the file size is 504 bytes
           val workDirPath = workDir.getAbsolutePath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index e4c33e96faa1..2d6978a81024 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.io.File
 import java.nio.charset.StandardCharsets
 import java.sql.Timestamp
+import java.time.{LocalDateTime, ZoneOffset}

 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -768,6 +769,67 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("Read/write all timestamp types") {
+    val data = (0 to 255).map { i =>
+      (new Timestamp(i), LocalDateTime.of(2019, 3, 21, 0, 2, 3, 456000000 + i))
+    } :+ (null, null)
+
+    withOrcFile(data) { file =>
+      withAllOrcReaders {
+        checkAnswer(spark.read.orc(file), data.toDF().collect())
+      }
+    }
+  }
+
+  test("SPARK-36346: can't read TimestampLTZ as TimestampNTZ") {
+    val data = (1 to 10).map { i =>
+      val ts = new Timestamp(i)
+      Row(ts)
+    }
+    val answer = (1 to 10).map { i =>
+      // The second parameter is `nanoOfSecond`, while java.sql.Timestamp accepts milliseconds
+      // as input. So here we multiply the milliseconds by NANOS_PER_MILLIS
+      val ts = LocalDateTime.ofEpochSecond(0, i * 1000000, ZoneOffset.UTC)
+      Row(ts)
+    }
+    val actualSchema = StructType(Seq(StructField("time", TimestampType, false)))
+    val providedSchema = StructType(Seq(StructField("time", TimestampNTZType, false)))
+
+    withTempPath { file =>
+      val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema)
+      df.write.orc(file.getCanonicalPath)
+      withAllOrcReaders {
+        val msg = intercept[SparkException] {
+          spark.read.schema(providedSchema).orc(file.getCanonicalPath).collect()
+        }.getMessage
+        assert(msg.contains("Unable to convert timestamp of Orc to data type 'timestamp_ntz'"))
+      }
+    }
+  }
+
+  test("SPARK-36346: read TimestampNTZ as TimestampLTZ") {
+    val data = (1 to 10).map { i =>
+      // The second parameter is `nanoOfSecond`, while java.sql.Timestamp accepts milliseconds
+      // as input. So here we multiply the milliseconds by NANOS_PER_MILLIS
+      val ts = LocalDateTime.ofEpochSecond(0, i * 1000000, ZoneOffset.UTC)
+      Row(ts)
+    }
+    val answer = (1 to 10).map { i =>
+      val ts = new java.sql.Timestamp(i)
+      Row(ts)
+    }
+    val actualSchema = StructType(Seq(StructField("time", TimestampNTZType, false)))
+    val providedSchema = StructType(Seq(StructField("time", TimestampType, false)))
+
+    withTempPath { file =>
+      val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema)
+      df.write.orc(file.getCanonicalPath)
+      withAllOrcReaders {
+        checkAnswer(spark.read.schema(providedSchema).orc(file.getCanonicalPath), answer)
+      }
+    }
+  }
 }

 class OrcV1QuerySuite extends OrcQuerySuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index 4243318ac1dd..cd87374e8574 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -143,6 +143,13 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
     spark.read.orc(file.getAbsolutePath)
   }

+  def withAllOrcReaders(code: => Unit): Unit = {
+    // test the row-based reader
+    withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false")(code)
+    // test the vectorized reader
+    withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "true")(code)
+  }
+
   /**
    * Takes a sequence of products `data` to generate multi-level nested
    * dataframes as new test data. It tests both non-nested and nested dataframes
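
Note (reviewer aid, not part of the patch): the new `OrcUtils.toOrcNTZ` splits epoch microseconds into whole seconds plus nanos-of-second when writing, and `OrcUtils.fromOrcNTZ` rebuilds microseconds from the timestamp's millis plus its sub-millisecond nanos when reading. The standalone sketch below mirrors that arithmetic against plain `java.sql.Timestamp`, with the `DateTimeConstants` values written out as literals, so the round trip can be checked outside Spark; the object and method names are illustrative only.

```scala
import java.sql.Timestamp

// Standalone sketch of the micros <-> Timestamp arithmetic used by
// OrcUtils.fromOrcNTZ / toOrcNTZ in this patch (illustration only).
object NtzConversionSketch {
  private val MicrosPerMillis = 1000L
  private val MicrosPerSecond = 1000000L
  private val MillisPerSecond = 1000L
  private val NanosPerMicros = 1000L

  // Mirrors toOrcNTZ: split micros-since-epoch into whole seconds and nanos-of-second.
  def microsToTimestamp(micros: Long): Timestamp = {
    val seconds = Math.floorDiv(micros, MicrosPerSecond)
    val nanos = (micros - seconds * MicrosPerSecond) * NanosPerMicros
    val ts = new Timestamp(seconds * MillisPerSecond)
    ts.setNanos(nanos.toInt)
    ts
  }

  // Mirrors fromOrcNTZ: millis -> micros, then add back the sub-millisecond micros.
  def timestampToMicros(ts: Timestamp): Long =
    Math.multiplyExact(ts.getTime, MicrosPerMillis) +
      (ts.getNanos / NanosPerMicros) % MicrosPerMillis

  def main(args: Array[String]): Unit = {
    // Round trip, including negative (pre-epoch) values that exercise Math.floorDiv.
    Seq(-987654321L, -1L, 0L, 1L, 123456789L).foreach { micros =>
      assert(timestampToMicros(microsToTimestamp(micros)) == micros, s"round trip failed for $micros")
    }
    println("micros <-> Timestamp round trip OK")
  }
}
```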
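Also not part of the patch: a minimal end-to-end sketch of the behavior the new tests exercise. It assumes a local SparkSession, a Spark build where `java.time.LocalDateTime` values are encoded as TIMESTAMP_NTZ, and that `spark.sql.orc.enableVectorizedReader` is the config key behind `SQLConf.ORC_VECTORIZED_READER_ENABLED` (the same toggle the new `withAllOrcReaders` helper applies).

```scala
import java.time.LocalDateTime

import org.apache.spark.sql.SparkSession

// Hedged sketch: write a TIMESTAMP_NTZ column to ORC, then read it back
// under both the row-based and the vectorized ORC reader.
object OrcNtzRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("orc-ntz-sketch").getOrCreate()
    import spark.implicits._

    val path = java.nio.file.Files.createTempDirectory("orc-ntz-").resolve("data").toString
    // LocalDateTime columns carry no time zone, so the values read back should be identical.
    Seq(LocalDateTime.of(2019, 3, 21, 0, 2, 3), LocalDateTime.of(2021, 8, 1, 12, 0, 0))
      .toDF("ts")
      .write.orc(path)

    // Same toggle the new withAllOrcReaders test helper applies: row-based, then vectorized.
    Seq("false", "true").foreach { vectorized =>
      spark.conf.set("spark.sql.orc.enableVectorizedReader", vectorized)
      spark.read.orc(path).show(false)
    }
    spark.stop()
  }
}
```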