Commit 92d7b8f

Update code
1 parent 7c343a7 commit 92d7b8f

2 files changed (+37, -19 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala

Lines changed: 17 additions & 4 deletions
@@ -22,14 +22,14 @@ import java.net.URI
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.io.WritableComparable
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.{OrcUtils => _, _}
 import org.apache.orc.OrcConf.COMPRESS
 import org.apache.orc.mapred.OrcStruct
-import org.apache.orc.mapreduce._
 
 import org.apache.spark.TaskContext
 import org.apache.spark.sql.SparkSession
@@ -155,7 +155,7 @@ class OrcFileFormat
       if (orcFilterPushDown && filters.nonEmpty) {
         OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).foreach { fileSchema =>
           OrcFilters.createFilter(fileSchema, filters).foreach { f =>
-            OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+            mapreduce.OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
           }
         }
       }
@@ -193,8 +193,8 @@ class OrcFileFormat
 
         iter.asInstanceOf[Iterator[InternalRow]]
       } else {
-        val orcRecordReader = new OrcInputFormat[OrcStruct]
-          .createRecordReader(fileSplit, taskAttemptContext)
+        val orcRecordReader: mapreduce.OrcMapreduceRecordReader[OrcStruct] =
+          createRecordReader[OrcStruct](fileSplit, taskAttemptContext)
         val iter = new RecordReaderIterator[OrcStruct](orcRecordReader)
         Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
 
@@ -214,6 +214,19 @@ class OrcFileFormat
     }
   }
 
+  private def createRecordReader[V <: WritableComparable[_]](
+      inputSplit: InputSplit,
+      taskAttemptContext: TaskAttemptContext): mapreduce.OrcMapreduceRecordReader[V] = {
+    val split = inputSplit.asInstanceOf[FileSplit]
+    val conf = taskAttemptContext.getConfiguration()
+    val readOptions = OrcFile.readerOptions(conf)
+      .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf)).useUTCTimestamp(true)
+    val file = OrcFile.createReader(split.getPath(), readOptions)
+    val options = org.apache.orc.mapred.OrcInputFormat.buildOptions(
+      conf, file, split.getStart(), split.getLength()).useSelected(true)
+    new mapreduce.OrcMapreduceRecordReader(file, options)
+  }
+
   override def supportDataType(dataType: DataType): Boolean = dataType match {
     case _: AtomicType => true
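
Note on the change above: the private createRecordReader helper replaces the stock `new OrcInputFormat[OrcStruct].createRecordReader(...)` call so that the ORC reader options can be built explicitly, in particular useUTCTimestamp(true) and useSelected(true). A minimal standalone sketch of the same ORC reader API, assuming a hypothetical local file /tmp/example.orc (the path and the object name are assumptions; the option-building mirrors the helper above):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.{OrcConf, OrcFile}
import org.apache.orc.mapred.{OrcInputFormat, OrcStruct}
import org.apache.orc.mapreduce.OrcMapreduceRecordReader

object OrcUtcReadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // Hypothetical input path; any ORC file works here.
    val path = new Path("/tmp/example.orc")

    // Same options the helper builds: cap the read length and
    // decode timestamps as UTC rather than the JVM default zone.
    val readOptions = OrcFile.readerOptions(conf)
      .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
      .useUTCTimestamp(true)
    val reader = OrcFile.createReader(path, readOptions)

    // Read the whole file (offset 0, full length) with selected-column
    // handling enabled, as in the helper.
    val options = OrcInputFormat.buildOptions(conf, reader, 0L, reader.getContentLength)
      .useSelected(true)

    val recordReader = new OrcMapreduceRecordReader[OrcStruct](reader, options)
    try {
      while (recordReader.nextKeyValue()) {
        println(recordReader.getCurrentValue)
      }
    } finally {
      recordReader.close()
    }
  }
}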

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala

Lines changed: 20 additions & 15 deletions
@@ -21,6 +21,8 @@ import java.io.File
 import java.nio.charset.StandardCharsets
 import java.sql.Timestamp
 import java.time.{LocalDateTime, ZoneOffset}
+import java.util.TimeZone
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
@@ -30,6 +32,7 @@ import org.apache.orc.{OrcConf, OrcFile}
 import org.apache.orc.OrcConf.COMPRESS
 import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce.OrcInputFormat
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
@@ -41,8 +44,6 @@ import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
-import java.util.TimeZone
-
 case class AllDataTypesWithNonPrimitiveType(
     stringField: String,
     intField: Int,
@@ -832,23 +833,27 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
   }
 
   test("SPARK-37463: read/write Timestamp ntz or ltz to Orc uses UTC timestamp") {
-    TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
-    sql("set spark.sql.session.timeZone = America/Los_Angeles")
+    val localTimeZone = TimeZone.getDefault
+    try {
+      TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
 
-    val df =
-      sql("select timestamp_ntz '2021-06-01 00:00:00' ts_ntz, timestamp_ltz '2021-06-01 00:00:00' ts_ltz")
+      val df = sql("""
+        |select timestamp_ntz '2021-06-01 00:00:00' ts_ntz,
+        |timestamp_ltz '2021-06-01 00:00:00' ts_ltz
+        |""".stripMargin)
 
-    df.write.mode("overwrite").orc("ts_ntz_orc")
-    df.write.mode("overwrite").parquet("ts_ntz_parquet")
+      df.write.mode("overwrite").orc("ts_ntz_orc")
 
-    val queryOrc = "select * from `orc`.`ts_ntz_orc`"
-    val queryParquet = "select * from `parquet`.`ts_ntz_parquet`"
+      val query = "select * from `orc`.`ts_ntz_orc`"
 
-    val tzs = Seq("America/Los_Angeles", "UTC", "Europe/Amsterdam")
-    for (tz <- tzs) {
-      TimeZone.setDefault(TimeZone.getTimeZone(tz))
-      sql(s"set spark.sql.session.timeZone = $tz")
-      sql(queryOrc).collect().equals(sql(queryParquet).collect())
+      Seq("America/Los_Angeles", "UTC", "Europe/Amsterdam").foreach { tz =>
+        TimeZone.setDefault(TimeZone.getTimeZone(tz))
+        withAllOrcReaders {
+          checkAnswer(sql(query), df)
+        }
+      }
+    } finally {
+      TimeZone.setDefault(localTimeZone)
     }
   }
 }
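
Two things worth noting in the rewritten test. First, the old version compared ORC output to Parquet output with a bare `.collect().equals(...)` whose Boolean result was discarded, so it asserted nothing; the new version checks the round-trip directly with checkAnswer(sql(query), df) under each time zone, and restores the default TimeZone in a finally block. Second, the check is wrapped in withAllOrcReaders, which presumably runs the body under both ORC reader code paths. A minimal sketch of what such a helper could look like, assuming Spark's withSQLConf test utility and the vectorized-reader flag (this body is an assumption, not the commit's actual definition):

import org.apache.spark.sql.internal.SQLConf

// Assumed shape of the helper: run `body` once with the vectorized ORC
// reader enabled and once with it disabled, using withSQLConf from
// Spark's SQL test utilities (e.g. org.apache.spark.sql.test.SQLTestUtils).
def withAllOrcReaders(body: => Unit): Unit = {
  Seq("true", "false").foreach { vectorized =>
    withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized)(body)
  }
}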
