52 changes: 43 additions & 9 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe

import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
@@ -884,16 +885,22 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
AttributeReference("value", StringType)()), true)
}

def matchSerDe(clause: Seq[ASTNode])
: (Seq[(String, String)], Option[String], Seq[(String, String)]) = clause match {
type SerDeInfo = (
Seq[(String, String)], // Input row format information
Option[String], // Optional input SerDe class
Seq[(String, String)], // Input SerDe properties
Boolean // Whether to use default record reader/writer
)

def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match {
case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
val rowFormat = propsClause.map {
case Token(name, Token(value, Nil) :: Nil) => (name, value)
}
(rowFormat, None, Nil)
(rowFormat, None, Nil, false)

case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil =>
(Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
(Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false)

case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
Token("TOK_TABLEPROPERTIES",
@@ -903,20 +910,47 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
(BaseSemanticAnalyzer.unescapeSQLString(name),
BaseSemanticAnalyzer.unescapeSQLString(value))
}
(Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)

case Nil => (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
// SPARK-10310: Special cases LazySimpleSerDe
// TODO Fully supports user-defined record reader/writer classes
val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass)
val useDefaultRecordReaderWriter =
unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName
(Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter)

case Nil =>
// Uses default TextRecordReader/TextRecordWriter, sets field delimiter here
val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t")
(Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true)
}

val (inRowFormat, inSerdeClass, inSerdeProps) = matchSerDe(inputSerdeClause)
val (outRowFormat, outSerdeClass, outSerdeProps) = matchSerDe(outputSerdeClause)
val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) =
matchSerDe(inputSerdeClause)

val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) =
matchSerDe(outputSerdeClause)

val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script)

// TODO Adds support for user-defined record reader/writer classes
val recordReaderClass = if (useDefaultRecordReader) {
Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER))
} else {
None
}

val recordWriterClass = if (useDefaultRecordWriter) {
Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER))
} else {
None
}

val schema = HiveScriptIOSchema(
inRowFormat, outRowFormat,
inSerdeClass, outSerdeClass,
inSerdeProps, outSerdeProps, schemaLess)
inSerdeProps, outSerdeProps,
recordReaderClass, recordWriterClass,
schemaLess)

Some(
logical.ScriptTransformation(
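For reference, the case Nil branch above falls back to the script SerDe and default record reader/writer that Hive exposes through ConfVars. A minimal standalone sketch of those lookups (REPL-style, assuming hive-exec is on the classpath; the class names in the comments are the stock Hive defaults and can be overridden per session):

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars

// The three conf vars the parser consults above, with their stock Hive defaults.
val conf = new HiveConf()
val defaultSerde  = conf.getVar(ConfVars.HIVESCRIPTSERDE)         // org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe
val defaultReader = conf.getVar(ConfVars.HIVESCRIPTRECORDREADER)  // org.apache.hadoop.hive.ql.exec.TextRecordReader
val defaultWriter = conf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)  // org.apache.hadoop.hive.ql.exec.TextRecordWriter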
@@ -24,20 +24,22 @@ import javax.annotation.Nullable
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
import org.apache.hadoop.io.Writable

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
import org.apache.spark.{Logging, TaskContext}

/**
@@ -58,6 +60,8 @@ case class ScriptTransformation(

override def otherCopyArgs: Seq[HiveContext] = sc :: Nil

private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)

protected override def doExecute(): RDD[InternalRow] = {
def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
@@ -67,6 +71,7 @@
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
val localHiveConf = serializedHiveConf.value

// In order to avoid deadlocks, we need to consume the error output of the child process.
// To avoid issues caused by large error output, we use a circular buffer to limit the amount
@@ -96,7 +101,8 @@
outputStream,
proc,
stderrBuffer,
TaskContext.get()
TaskContext.get(),
localHiveConf
)

// This nullability is a performance optimization in order to avoid an Option.foreach() call
@@ -109,6 +115,10 @@
val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var curLine: String = null
val scriptOutputStream = new DataInputStream(inputStream)

@Nullable val scriptOutputReader =
ioschema.recordReader(scriptOutputStream, localHiveConf).orNull

var scriptOutputWritable: Writable = null
val reusedWritableObject: Writable = if (null != outputSerde) {
outputSerde.getSerializedClass().newInstance
@@ -134,15 +144,25 @@
}
} else if (scriptOutputWritable == null) {
scriptOutputWritable = reusedWritableObject
try {
scriptOutputWritable.readFields(scriptOutputStream)
true
} catch {
case _: EOFException =>
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}

if (scriptOutputReader != null) {
if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
writerThread.exception.foreach(throw _)
false
} else {
true
}
} else {
try {
scriptOutputWritable.readFields(scriptOutputStream)
true
} catch {
case _: EOFException =>
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
false
}
}
} else {
true
@@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread(
outputStream: OutputStream,
proc: Process,
stderrBuffer: CircularBuffer,
taskContext: TaskContext
taskContext: TaskContext,
conf: Configuration
) extends Thread("Thread-ScriptTransformation-Feed") with Logging {

setDaemon(true)
@@ -224,6 +245,7 @@
TaskContext.setTaskContext(taskContext)

val dataOutputStream = new DataOutputStream(outputStream)
@Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, conf).orNull

// We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
// let's use a variable to record whether the `finally` block was hit due to an exception
@@ -250,7 +272,12 @@
} else {
val writable = inputSerde.serialize(
row.asInstanceOf[GenericInternalRow].values, inputSoi)
prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)

if (scriptInputWriter != null) {
scriptInputWriter.write(writable)
} else {
prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
}
}
}
outputStream.close()
@@ -290,6 +317,8 @@ case class HiveScriptIOSchema (
outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
recordReaderClass: Option[String],
recordWriterClass: Option[String],
schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {

private val defaultFormat = Map(
@@ -347,4 +376,24 @@

serde
}

def recordReader(
inputStream: InputStream,
conf: Configuration): Option[RecordReader] = {
recordReaderClass.map { klass =>
val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
val props = new Properties()
props.putAll(outputSerdeProps.toMap.asJava)
instance.initialize(inputStream, conf, props)
instance
}
}

def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = {
recordWriterClass.map { klass =>
val instance = Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
instance.initialize(outputStream, conf)
instance
}
}
}
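The recordReader / recordWriter helpers above instantiate the configured classes reflectively and hand them the raw process streams. A minimal standalone sketch of the writer side, constructing Hive's TextRecordWriter (the HIVESCRIPTRECORDWRITER default) directly instead of through reflection; it assumes hive-exec is on the classpath and is not part of the diff:

import java.io.ByteArrayOutputStream

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.ql.exec.{RecordWriter, TextRecordWriter}
import org.apache.hadoop.io.Text

val out = new ByteArrayOutputStream()
val writer: RecordWriter = new TextRecordWriter()
writer.initialize(out, new Configuration())

// Each write() emits the row bytes followed by a newline, which is what the
// transformation script ends up reading on its stdin.
writer.write(new Text("0\t0"))
writer.write(new Text("1\t1"))
writer.close()

print(out.toString("UTF-8")) // expected: "0\t0\n1\t1\n", one tab-delimited row per line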
6 changes: 6 additions & 0 deletions sql/hive/src/test/resources/data/scripts/test_transform.py
@@ -0,0 +1,6 @@
import sys

delim = sys.argv[1]

for row in sys.stdin:
print(delim.join([w + '#' for w in row[:-1].split(delim)]))
@@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {

checkAnswer(df, Row("text inside layer 2") :: Nil)
}

test("SPARK-10310: " +
"script transformation using default input/output SerDe and record reader/writer") {
sqlContext
.range(5)
.selectExpr("id AS a", "id AS b")
.registerTempTable("test")

checkAnswer(
sql(
"""FROM(
| FROM test SELECT TRANSFORM(a, b)
| USING 'python src/test/resources/data/scripts/test_transform.py "\t"'
| AS (c STRING, d STRING)
|) t
|SELECT c
""".stripMargin),
(0 until 5).map(i => Row(i + "#")))
}

test("SPARK-10310: script transformation using LazySimpleSerDe") {
sqlContext
.range(5)
.selectExpr("id AS a", "id AS b")
.registerTempTable("test")

val df = sql(
"""FROM test
|SELECT TRANSFORM(a, b)
|ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
|WITH SERDEPROPERTIES('field.delim' = '|')
|USING 'python src/test/resources/data/scripts/test_transform.py "|"'
|AS (c STRING, d STRING)
|ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
|WITH SERDEPROPERTIES('field.delim' = '|')
""".stripMargin)

checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#")))
}
}
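Both tests feed rows through test_transform.py, which appends '#' to every field and echoes the row back with the same delimiter. A small Scala sketch of that per-line transformation, useful for checking the expected answers above by hand (the helper name is illustrative only):

// Mirrors test_transform.py: split on the literal delimiter, append '#' to each
// field, and rejoin with the same delimiter.
def transformLine(line: String, delim: String): String =
  line.split(java.util.regex.Pattern.quote(delim)).map(_ + "#").mkString(delim)

transformLine("3\t3", "\t")  // "3#\t3#" -> Row("3#") after the outer SELECT c
transformLine("3|3", "|")    // "3#|3#"  -> Row("3#", "3#")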
@@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with TestHiveSingleton {
outputSerdeClass = None,
inputSerdeProps = Seq.empty,
outputSerdeProps = Seq.empty,
recordReaderClass = None,
recordWriterClass = None,
schemaLess = false
)
