Commit 1ac830a

HyukjinKwon authored and rxin committed
[SPARK-16044][SQL] Backport input_file_name() for data source based on NewHadoopRDD to branch 1.6
## What changes were proposed in this pull request?

This PR backports #13759. (`SqlNewHadoopRDDState` was renamed to `InputFileNameHolder` in master, and the `spark` API does not exist in branch 1.6.)

## How was this patch tested?

Unit tests in `ColumnExpressionSuite`.

Author: hyukjinkwon <[email protected]>

Closes #13806 from HyukjinKwon/backport-SPARK-16044.
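For context, `input_file_name()` returns, for each row, the name of the file the row was read from, or an empty string when the row does not come from a file-based scan. A minimal usage sketch against the branch-1.6 API (the path is illustrative, and `sqlContext` is assumed to be the shell's `SQLContext`):

```scala
import org.apache.spark.sql.functions.input_file_name

// Read a file-backed DataFrame; each row then reports its originating file.
// "/tmp/ifn-demo" is an illustrative path, not one used by this patch.
val df = sqlContext.read.parquet("/tmp/ifn-demo")
df.select(input_file_name()).show(false)
```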
1 parent 0cb06c9 commit 1ac830a

File tree

2 files changed: +42 -4 lines


core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -134,6 +134,12 @@ class NewHadoopRDD[K, V](
       val inputMetrics = context.taskMetrics
         .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
 
+      // Sets the thread local variable for the file's name
+      split.serializableHadoopSplit.value match {
+        case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString)
+        case _ => SqlNewHadoopRDDState.unsetInputFileName()
+      }
+
       // Find a function that will return the FileSystem bytes read by this thread. Do this before
       // creating RecordReader, because RecordReader's constructor might read some bytes
       val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
@@ -190,6 +196,7 @@ class NewHadoopRDD[K, V](
 
       private def close() {
        if (reader != null) {
+          SqlNewHadoopRDDState.unsetInputFileName()
          // Close the reader and release it. Note: it's very important that we don't close the
          // reader more than once, since that exposes us to MAPREDUCE-5918 when running against
          // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic
```
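This patch works because `SqlNewHadoopRDDState` is essentially a thread-local holder: the task thread sets the current split's file name before reading and clears it again in `close()`, and the `input_file_name()` expression reads it back from the same thread. A simplified sketch of the pattern (a plain `String` where Spark's implementation uses its internal `UTF8String`; the object name here is illustrative):

```scala
// Simplified sketch of a thread-local "current input file" holder,
// mirroring the role SqlNewHadoopRDDState plays in the diff above.
// Spark's real implementation stores a UTF8String, not a String.
object InputFileHolder {
  private val inputFileName = new ThreadLocal[String] {
    override protected def initialValue(): String = ""
  }

  def getInputFileName(): String = inputFileName.get()
  def setInputFileName(file: String): Unit = inputFileName.set(file)
  def unsetInputFileName(): Unit = inputFileName.remove()
}
```

Because the holder is per-thread, clearing it in `close()` matters: a task thread that later reads a non-file source would otherwise report a stale file name.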

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 35 additions & 4 deletions
```diff
@@ -17,9 +17,11 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.NamedExpression
+import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
 import org.apache.spark.sql.execution.Project
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -591,15 +593,44 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
     )
   }
 
-  test("InputFileName") {
+  test("InputFileName - SqlNewHadoopRDD") {
     withTempPath { dir =>
       val data = sparkContext.parallelize(0 to 10).toDF("id")
       data.write.parquet(dir.getCanonicalPath)
-      val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName())
+      val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name())
         .head.getString(0)
       assert(answer.contains(dir.getCanonicalPath))
 
-      checkAnswer(data.select(inputFileName()).limit(1), Row(""))
+      checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+    }
+  }
+
+  test("input_file_name - HadoopRDD") {
+    withTempPath { dir =>
+      val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
+      data.write.text(dir.getCanonicalPath)
+      val df = sparkContext.textFile(dir.getCanonicalPath).toDF()
+      val answer = df.select(input_file_name()).head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(input_file_name()).limit(1), Row(""))
+    }
+  }
+
+  test("input_file_name - NewHadoopRDD") {
+    withTempPath { dir =>
+      val data = sparkContext.parallelize((0 to 10).map(_.toString)).toDF()
+      data.write.text(dir.getCanonicalPath)
+      val rdd = sparkContext.newAPIHadoopFile(
+        dir.getCanonicalPath,
+        classOf[NewTextInputFormat],
+        classOf[LongWritable],
+        classOf[Text])
+      val df = rdd.map(pair => pair._2.toString).toDF()
+      val answer = df.select(input_file_name()).head.getString(0)
+      assert(answer.contains(dir.getCanonicalPath))
+
+      checkAnswer(data.select(input_file_name()).limit(1), Row(""))
     }
   }
```
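The three tests cover the three read paths (`SqlNewHadoopRDD` for Parquet, `HadoopRDD` via `textFile`, and `NewHadoopRDD` via `newAPIHadoopFile`) and all assert the same contract: rows scanned from files report a path containing the output directory, while rows from a non-file source (the `parallelize`d DataFrame) yield an empty string. A condensed sketch of that expectation, runnable in a branch-1.6 spark-shell (the path is illustrative):

```scala
import org.apache.spark.sql.functions.input_file_name

val data = sqlContext.range(0, 10).toDF("id")  // no backing file
data.write.parquet("/tmp/ifn-check")           // illustrative path

// Rows read back from disk carry their source file...
val fromFile = sqlContext.read.parquet("/tmp/ifn-check")
assert(fromFile.select(input_file_name()).head.getString(0).contains("/tmp/ifn-check"))

// ...while rows that never came from a file report "".
assert(data.select(input_file_name()).head.getString(0) == "")
```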
