@@ -17,25 +17,38 @@

package org.apache.spark.sql.execution

import java.io.OutputStream
import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

trait BaseScriptTransformationExec extends UnaryExecNode {
def input: Seq[Expression]
def script: String
def output: Seq[Attribute]
def child: SparkPlan
def ioschema: ScriptTransformationIOSchema

protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
}

override def producedAttributes: AttributeSet = outputSet -- inputSet

@@ -56,10 +69,91 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}

def processIterator(
protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd.asJava)

val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream

// In order to avoid deadlocks, we need to consume the error output of the child process.
// To avoid issues caused by large error output, we use a circular buffer to limit the amount
// of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
// that motivates this.
val stderrBuffer = new CircularBuffer(2048)
new RedirectThread(
errorStream,
stderrBuffer,
s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start()
(outputStream, proc, inputStream, stderrBuffer)
}
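
The stderr-draining setup above is the classic fix for pipe-buffer deadlock, and it is self-contained enough to illustrate on its own. A minimal sketch, assuming a throwaway script (`RedirectThread` and `CircularBuffer` are the same Spark utilities imported above):

import scala.collection.JavaConverters._
import org.apache.spark.util.{CircularBuffer, RedirectThread}

// Drain the child's stderr on a separate thread so a full stderr pipe buffer
// can never block the child; retain only the last 2 KB of error output.
val proc = new ProcessBuilder(List("/bin/bash", "-c", "echo oops >&2").asJava).start()
val stderrBuffer = new CircularBuffer(2048)
new RedirectThread(proc.getErrorStream, stderrBuffer, "stderr-drainer").start()
// ... write to proc.getOutputStream and read from proc.getInputStream as usual ...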

protected def processIterator(
inputIterator: Iterator[InternalRow],
hadoopConf: Configuration): Iterator[InternalRow]

protected def createOutputIteratorWithoutSerde(
writerThread: BaseScriptTransformationWriterThread,
inputStream: InputStream,
proc: Process,
stderrBuffer: CircularBuffer): Iterator[InternalRow] = {
new Iterator[InternalRow] {
var curLine: String = null
val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))

val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")
val processRowWithoutSerde = if (!ioschema.schemaLess) {
prevLine: String =>
new GenericInternalRow(
prevLine.split(outputRowFormat)
.zip(outputFieldWriters)
.map { case (data, writer) => writer(data) })
} else {
// In schema-less mode, Hive's default SerDe uses the first two output columns
// as the output; if there are fewer than two output columns, it throws
// ArrayIndexOutOfBoundsException. Here we change Spark's behavior to match
// Hive's default SerDe.
// In Hive itself, however, TRANSFORM in schema-less mode behaves like the
// original Spark behavior; SPARK-32388 will make Spark and Hive consistent.
val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType)
prevLine: String =>
new GenericInternalRow(
prevLine.split(outputRowFormat).slice(0, 2)
.map(kvWriter))
}
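
For a hypothetical tab-delimited output line, the two branches behave as follows (illustrative values only; the real delimiter comes from `outputRowFormatMap`):

val line = "1\ta\t2.0"        // hypothetical line emitted by the script
line.split("\t")              // Array("1", "a", "2.0"): one cell per declared output column
line.split("\t").slice(0, 2)  // Array("1", "a"): schema-less mode keeps only the first two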

override def hasNext: Boolean = {
try {
if (curLine == null) {
curLine = reader.readLine()
if (curLine == null) {
checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
return false
}
}
true
} catch {
case NonFatal(e) =>
// If this exception is due to abrupt / unclean termination of `proc`,
// then detect it and propagate a better exception message for end users
checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)

throw e
}
}

override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException
}
val prevLine = curLine
curLine = reader.readLine()
processRowWithoutSerde(prevLine)
}
}
}
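
The iterator above uses the usual look-ahead idiom: `hasNext` reads and caches a line only when none is buffered, so repeated calls are safe, while `next()` hands back the cached line and eagerly pre-reads the following one. A stripped-down sketch of the same idiom:

import java.io.BufferedReader

// Look-ahead line iterator: hasNext is idempotent because it only touches the
// reader when no line is currently buffered.
def lineIterator(reader: BufferedReader): Iterator[String] = new Iterator[String] {
  private var cur: String = null
  override def hasNext: Boolean = {
    if (cur == null) cur = reader.readLine()
    cur != null
  }
  override def next(): String = {
    if (!hasNext) throw new NoSuchElementException
    val line = cur
    cur = reader.readLine() // eager pre-read, matching the iterator above
    line
  }
}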

protected def checkFailureAndPropagate(
writerThread: BaseScriptTransformationWriterThread,
cause: Throwable = null,
@@ -87,17 +181,72 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}
}

private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
attr.dataType match {
case StringType => wrapperConvertException(data => data, converter)
case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
case ByteType => wrapperConvertException(data => data.toByte, converter)
case BinaryType =>
wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
case IntegerType => wrapperConvertException(data => data.toInt, converter)
case ShortType => wrapperConvertException(data => data.toShort, converter)
case LongType => wrapperConvertException(data => data.toLong, converter)
case FloatType => wrapperConvertException(data => data.toFloat, converter)
case DoubleType => wrapperConvertException(data => data.toDouble, converter)
case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
case DateType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.daysToLocalDate).orNull, converter)
case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaDate).orNull, converter)
case TimestampType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.microsToInstant).orNull, converter)
case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaTimestamp).orNull, converter)
case CalendarIntervalType => wrapperConvertException(
data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
converter)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case dt =>
throw new SparkException(s"${nodeName} without serde does not support " +
s"${dt.getClass.getSimpleName} as output data type")
}
}
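
As a concrete illustration of one branch, the non-Java-8 `DateType` writer parses a date string into a `java.sql.Date` or falls back to null (hypothetical inputs; "UTC" stands in for `conf.sessionLocalTimeZone`):

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String

val zoneId = DateTimeUtils.getZoneId("UTC")
DateTimeUtils.stringToDate(UTF8String.fromString("2020-07-15"), zoneId)
  .map(DateTimeUtils.toJavaDate).orNull // java.sql.Date for 2020-07-15
DateTimeUtils.stringToDate(UTF8String.fromString("not-a-date"), zoneId)
  .map(DateTimeUtils.toJavaDate).orNull // None => null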

// Keep consistent with Hive's `LazySimpleSerDe`: when there is a type conversion error, return null
private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
(f: String => Any, converter: Any => Any) =>
(data: String) => converter {
try {
f(data)
} catch {
case NonFatal(_) => null
}
}
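
The null-on-failure contract is easy to see in isolation. A minimal sketch of the same pattern for an integer column (hypothetical cell values):

import scala.util.control.NonFatal

// Parse the cell, falling back to null on any non-fatal conversion error
// instead of failing the whole query.
val toIntOrNull: String => Any =
  (data: String) => try data.toInt catch { case NonFatal(_) => null }

toIntOrNull("42")    // 42
toIntOrNull("forty") // null rather than a NumberFormatException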
}

abstract class BaseScriptTransformationWriterThread(
iter: Iterator[InternalRow],
inputSchema: Seq[DataType],
ioSchema: BaseScriptTransformIOSchema,
outputStream: OutputStream,
proc: Process,
stderrBuffer: CircularBuffer,
taskContext: TaskContext,
conf: Configuration) extends Thread with Logging {
abstract class BaseScriptTransformationWriterThread extends Thread with Logging {

def iter: Iterator[InternalRow]
def inputSchema: Seq[DataType]
def ioSchema: ScriptTransformationIOSchema
def outputStream: OutputStream
def proc: Process
def stderrBuffer: CircularBuffer
def taskContext: TaskContext
def conf: Configuration

setDaemon(true)

@@ -169,34 +318,50 @@
/**
* A wrapper class for input and output schema properties.
*/
abstract class BaseScriptTransformIOSchema extends Serializable {
import ScriptIOSchema._

def inputRowFormat: Seq[(String, String)]

def outputRowFormat: Seq[(String, String)]

def inputSerdeClass: Option[String]

def outputSerdeClass: Option[String]

def inputSerdeProps: Seq[(String, String)]

def outputSerdeProps: Seq[(String, String)]

def recordReaderClass: Option[String]

def recordWriterClass: Option[String]

def schemaLess: Boolean
case class ScriptTransformationIOSchema(
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
inputSerdeClass: Option[String],
outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
recordReaderClass: Option[String],
recordWriterClass: Option[String],
schemaLess: Boolean) extends Serializable {
import ScriptTransformationIOSchema._

val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
}

object ScriptIOSchema {
object ScriptTransformationIOSchema {
val defaultFormat = Map(
("TOK_TABLEROWFORMATFIELD", "\t"),
("TOK_TABLEROWFORMATLINES", "\n")
)

val defaultIOSchema = ScriptTransformationIOSchema(
inputRowFormat = Seq.empty,
outputRowFormat = Seq.empty,
inputSerdeClass = None,
outputSerdeClass = None,
inputSerdeProps = Seq.empty,
outputSerdeProps = Seq.empty,
recordReaderClass = None,
recordWriterClass = None,
schemaLess = false
)

def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = {
ScriptTransformationIOSchema(
input.inputRowFormat,
input.outputRowFormat,
input.inputSerdeClass,
input.outputSerdeClass,
input.inputSerdeProps,
input.outputSerdeProps,
input.recordReaderClass,
input.recordWriterClass,
input.schemaLess)
}
}
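
For example, with no ROW FORMAT clause the format maps fall back to Hive's defaults through `withDefault` (a hypothetical lookup):

val schema = ScriptTransformationIOSchema.defaultIOSchema
schema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD") // "\t": default field delimiter
schema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") // "\n": default line delimiter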
@@ -1046,8 +1046,8 @@ private[hive] trait HiveInspectors {
getListTypeInfo(elemType.toTypeInfo)
case StructType(fields) =>
getStructTypeInfo(
java.util.Arrays.asList(fields.map(_.name) : _*),
java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*))
java.util.Arrays.asList(fields.map(_.name): _*),
java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo): _*))
case MapType(keyType, valueType, _) =>
getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo)
case BinaryType => binaryTypeInfo
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec}
import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}


@@ -244,7 +244,7 @@ private[hive] trait HiveStrategies {
object HiveScripts extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ScriptTransformation(input, script, output, child, ioschema) =>
val hiveIoSchema = HiveScriptIOSchema(ioschema)
val hiveIoSchema = ScriptTransformationIOSchema(ioschema)
HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
case _ => Nil
}