@@ -21,9 +21,13 @@ import java.io._
import java.util.Properties
import javax.annotation.Nullable

import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.exec.{RecordWriter, RecordReader}

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.AbstractSerDe
import org.apache.hadoop.hive.serde2.objectinspector._
@@ -34,10 +38,9 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
import org.apache.spark.util.{SerializableConfiguration, CircularBuffer, RedirectThread, Utils}
import org.apache.spark.{Logging, TaskContext}

/**
@@ -58,6 +61,8 @@ case class ScriptTransformation(

override def otherCopyArgs: Seq[HiveContext] = sc :: Nil

private val _broadcastedHiveConf = new SerializableConfiguration(sc.hiveconf)
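Side note on this field: Hadoop's Configuration (and therefore HiveConf) is not java.io.Serializable, so it is wrapped before a task closure captures it, and `.value` recovers the Configuration on the executor side. A minimal sketch of that contract (SerializableConfiguration is a private[spark] helper, so this only compiles inside Spark itself):

  import org.apache.hadoop.conf.Configuration
  import org.apache.spark.util.SerializableConfiguration

  // Wrap a Configuration so it can travel inside a serialized task closure.
  val wrapped = new SerializableConfiguration(new Configuration())
  // On the executor, unwrap it again -- this is what `localHconf` does below.
  val conf: Configuration = wrapped.value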

protected override def doExecute(): RDD[InternalRow] = {
def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
val cmd = List("/bin/bash", "-c", script)
@@ -67,6 +72,7 @@
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream
val localHconf = _broadcastedHiveConf.value

// In order to avoid deadlocks, we need to consume the error output of the child process.
// To avoid issues caused by large error output, we use a circular buffer to limit the amount
@@ -96,19 +102,23 @@
outputStream,
proc,
stderrBuffer,
TaskContext.get()
TaskContext.get(),
localHconf
)

// This nullability is a performance optimization in order to avoid an Option.foreach() call
// inside of a loop
@Nullable val (outputSerde, outputSoi) = {
ioschema.initOutputSerDe(output).getOrElse((null, null))
@Nullable val (outputSerde, outputSoi, tableProperties) = {
ioschema.initOutputSerDe(output).getOrElse((null, null, null))
}

val reader = new BufferedReader(new InputStreamReader(inputStream))
val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
var curLine: String = null
val scriptOutputStream = new DataInputStream(inputStream)
val scriptOutputReader: RecordReader = ioschema.getRecordReader(localHconf)

scriptOutputReader.initialize(
new DataInputStream(inputStream), _broadcastedHiveConf.value, tableProperties)
var scriptOutputWritable: Writable = null
val reusedWritableObject: Writable = if (null != outputSerde) {
outputSerde.getSerializedClass().newInstance
Expand All @@ -134,15 +144,13 @@ case class ScriptTransformation(
}
} else if (scriptOutputWritable == null) {
scriptOutputWritable = reusedWritableObject
try {
scriptOutputWritable.readFields(scriptOutputStream)
if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
false
} else {
true
} catch {
case _: EOFException =>
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
false
}
} else {
true
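Note on the EOF handling above: the old code caught EOFException from Writable.readFields(), while the new code relies on the return value of RecordReader.next(), which (at least for Hive's TextRecordReader) reports the number of bytes read and drops to <= 0 once the script's output stream is exhausted; any pending exception from the writer thread is still re-thrown before signalling end of iteration.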
@@ -210,7 +218,8 @@ private class ScriptTransformationWriterThread(
outputStream: OutputStream,
proc: Process,
stderrBuffer: CircularBuffer,
taskContext: TaskContext
taskContext: TaskContext,
localHconf: Configuration
) extends Thread("Thread-ScriptTransformation-Feed") with Logging {

setDaemon(true)
@@ -222,9 +231,9 @@

override def run(): Unit = Utils.logUncaughtExceptions {
TaskContext.setTaskContext(taskContext)

val dataOutputStream = new DataOutputStream(outputStream)

val scriptInWriter: RecordWriter = ioschema.getRecordWriter(localHconf)
scriptInWriter.initialize(dataOutputStream, localHconf)
// We can't use Utils.tryWithSafeFinally here because we also need a `catch` block, so
// let's use a variable to record whether the `finally` block was hit due to an exception
var threwException: Boolean = true
@@ -250,7 +259,7 @@
} else {
val writable = inputSerde.serialize(
row.asInstanceOf[GenericInternalRow].values, inputSoi)
prepareWritable(writable, ioschema.outputSerdeProps).write(dataOutputStream)
scriptInWriter.write(writable)
}
}
outputStream.close()
@@ -300,6 +309,19 @@ case class HiveScriptIOSchema (
val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))

def getRecordReader(conf: Configuration): RecordReader = {
// TODO: add support for getting the reader from the SQL RECORDREADER clause
val readerName =
HiveConf.getVar(conf, HiveConf.ConfVars.HIVESCRIPTRECORDREADER)
Utils.classForName(readerName).newInstance.asInstanceOf[RecordReader]
}

def getRecordWriter(conf: Configuration): RecordWriter = {
// TODO: add support for getting the writer from the SQL RECORDWRITER clause
val writerName =
HiveConf.getVar(conf, HiveConf.ConfVars.HIVESCRIPTRECORDWRITER)
Utils.classForName(writerName).newInstance.asInstanceOf[RecordWriter]
}
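Usage note (an assumption based on Hive's standard configuration keys, not something this diff changes): HIVESCRIPTRECORDREADER and HIVESCRIPTRECORDWRITER correspond to hive.script.recordreader and hive.script.recordwriter, which default to Hive's TextRecordReader and TextRecordWriter; until the TODOs above are resolved they can only be overridden through the conf, for example:

  // Hedged sketch; `hiveContext` is an assumed HiveContext instance.
  hiveContext.setConf("hive.script.recordreader",
    "org.apache.hadoop.hive.ql.exec.TextRecordReader")
  hiveContext.setConf("hive.script.recordwriter",
    "org.apache.hadoop.hive.ql.exec.TextRecordWriter")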

def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, ObjectInspector)] = {
inputSerdeClass.map { serdeClass =>
@@ -313,12 +335,14 @@
}
}

def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = {
def initOutputSerDe(output: Seq[Attribute])
: Option[(AbstractSerDe, StructObjectInspector, Properties)] = {
outputSerdeClass.map { serdeClass =>
val (columns, columnTypes) = parseAttrs(output)
val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps)
val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector]
(serde, structObjectInspector)
(serde, structObjectInspector,
createTableProperties(serdeClass, columns, columnTypes, outputSerdeProps))
}
}

@@ -328,23 +352,27 @@
(columns, columnTypes)
}

private def initSerDe(
private def createTableProperties(
serdeClassName: String,
columns: Seq[String],
columnTypes: Seq[DataType],
serdeProps: Seq[(String, String)]): AbstractSerDe = {

val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]

serdeProps: Seq[(String, String)]) = {
val columnTypesNames = columnTypes.map(_.toTypeInfo.getTypeName()).mkString(",")

var propsMap = serdeProps.toMap + (serdeConstants.LIST_COLUMNS -> columns.mkString(","))
propsMap = propsMap + (serdeConstants.LIST_COLUMN_TYPES -> columnTypesNames)

propsMap = propsMap + (serdeConstants.FIELD_DELIM -> "\t")
Contributor:
Shouldn't we specify the line delimiter here?

Contributor (author):
The line delimiter is controlled by the RecordWriter; see TextRecordWriter as an example:

  public void write(Writable row) throws IOException {
    Text text = (Text) row;
    Text escapeText = text;

    if (HiveConf.getBoolVar(conf, HiveConf.ConfVars.HIVESCRIPTESCAPE)) {
      escapeText = HiveUtils.escapeText(text);
    }

    out.write(escapeText.getBytes(), 0, escapeText.getLength());
    out.write(Utilities.newLineCode);
  }

val properties = new Properties()
properties.putAll(propsMap.asJava)
serde.initialize(null, properties)
properties
}
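To illustrate what this builds (a rough sketch, not output captured from a run): for two string columns c1 and c2 with no extra SERDEPROPERTIES, the resulting Properties hold roughly columns=c1,c2, columns.types=string,string and field.delim=\t (the values behind serdeConstants.LIST_COLUMNS, LIST_COLUMN_TYPES and FIELD_DELIM). Equivalent properties are now fed both to serde.initialize() via initSerDe and to the RecordReader's initialize() via the tableProperties returned from initOutputSerDe.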

private def initSerDe(
serdeClassName: String,
columns: Seq[String],
columnTypes: Seq[DataType],
serdeProps: Seq[(String, String)]): AbstractSerDe = {
val serde = Utils.classForName(serdeClassName).newInstance.asInstanceOf[AbstractSerDe]
serde.initialize(null, createTableProperties(serdeClassName, columns, columnTypes, serdeProps))
serde
}
}
7 changes: 7 additions & 0 deletions sql/hive/src/test/resources/data/scripts/test_transript.py
@@ -0,0 +1,7 @@
import sys

for line in sys.stdin:
    arr = line.strip().split("\t")
    for i in range(len(arr)):
        arr[i] = arr[i] + "#"
    print("\t".join(arr))
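For example, given the input line 1\t1 the script emits 1#\t1#; the new SQLQuerySuite test further below selects only the first output column and therefore expects 1# through 5#.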
@@ -429,7 +429,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' FROM src;
""".stripMargin.replaceAll(System.lineSeparator(), " "))

test("transform with SerDe2") {
// TODO: Only SerDes compatible with TextRecordReader are supported at the moment.
ignore("transform with SerDe2") {
Contributor:
Why should this test case be ignored? The SQL query involved doesn't contain a RECORDREADER clause, so it should fall back to TextRecordReader, shouldn't it?

Contributor (author):
TextRecordReader tries to convert the writable to Text, which is not suitable for Avro.
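For reference, the clauses this exchange and the earlier TODOs refer to look roughly like this in Hive's TRANSFORM syntax (a hedged sketch of HiveQL, which Spark does not yet parse into getRecordReader/getRecordWriter; `hiveContext` is an assumed HiveContext):

  hiveContext.sql("""FROM src
    |SELECT TRANSFORM(key, value)
    |  ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |  RECORDWRITER 'org.apache.hadoop.hive.ql.exec.TextRecordWriter'
    |USING 'cat'
    |AS (tKey, tValue)
    |  ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |  RECORDREADER 'org.apache.hadoop.hive.ql.exec.TextRecordReader'
    """.stripMargin)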


sql("CREATE TABLE small_src(key INT, value STRING)")
sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10")
@@ -763,6 +763,18 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
""".stripMargin), (2 to 6).map(i => Row(i)))
}

test("test script transform script input and output format") {
val data = (1 to 5).map { i => (i, i) }
data.toDF("a", "b").registerTempTable("test")
checkAnswer(
sql("""FROM
|(FROM test SELECT TRANSFORM(a, b)
|USING 'python src/test/resources/data/scripts/test_transript.py'
|AS (thing1 string, thing2 string)) t
|SELECT thing1
""".stripMargin), (1 to 5).map(i => Row(i + "#")))
}

test("window function: udaf with aggregate expressin") {
val data = Seq(
WindowData(1, "a", 5),