
Commit 8364f1f

[SPARK-32105][SQL][FOLLOWUP] Refactor current ScriptTransformationExec code
Parent: 6d49964

4 files changed (+304 / -184 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 201 additions & 35 deletions
@@ -17,26 +17,40 @@
 package org.apache.spark.sql.execution
 
-import java.io.OutputStream
+import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
+import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}
 
 trait BaseScriptTransformationExec extends UnaryExecNode {
 
+  def input: Seq[Expression]
+  def script: String
+  def output: Seq[Attribute]
+  def child: SparkPlan
+  def ioschema: ScriptTransformationIOSchema
+
+  protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
+    input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+  }
+
   override def producedAttributes: AttributeSet = outputSet -- inputSet
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
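
The hunk above moves the plan's members onto the base trait and pre-builds `inputExpressionsWithoutSerde` by casting every input expression to a string with the session-local time zone. A minimal illustrative sketch of that cast, not part of the commit (for example pasted into a spark-shell session; the attribute name "ts" and the "UTC" zone are assumptions):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression}
import org.apache.spark.sql.types.{StringType, TimestampType}

// Hypothetical input column; every expression in `input` is treated the same way.
val ts = AttributeReference("ts", TimestampType)()

// Cast is a time-zone-aware expression, so the session-local time zone must be
// attached before the cast to StringType can be evaluated. This mirrors what
// inputExpressionsWithoutSerde does for each input expression.
val asString: Expression = Cast(ts, StringType).withTimeZone("UTC")
```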
@@ -56,10 +70,91 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
     }
   }
 
+  protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
+    val cmd = List("/bin/bash", "-c", script)
+    val builder = new ProcessBuilder(cmd.asJava)
+
+    val proc = builder.start()
+    val inputStream = proc.getInputStream
+    val outputStream = proc.getOutputStream
+    val errorStream = proc.getErrorStream
+
+    // In order to avoid deadlocks, we need to consume the error output of the child process.
+    // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+    // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+    // that motivates this.
+    val stderrBuffer = new CircularBuffer(2048)
+    new RedirectThread(
+      errorStream,
+      stderrBuffer,
+      s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start()
+    (outputStream, proc, inputStream, stderrBuffer)
+  }
+
   def processIterator(
       inputIterator: Iterator[InternalRow],
       hadoopConf: Configuration): Iterator[InternalRow]
 
+  protected def createOutputIteratorWithoutSerde(
+      writerThread: BaseScriptTransformationWriterThread,
+      inputStream: InputStream,
+      proc: Process,
+      stderrBuffer: CircularBuffer): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      var curLine: String = null
+      val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+
+      val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")
+      val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType)
+      val processRowWithoutSerde = if (!ioschema.schemaLess) {
+        prevLine: String =>
+          new GenericInternalRow(
+            prevLine.split(outputRowFormat)
+              .zip(outputFieldWriters)
+              .map { case (data, writer) => writer(data) })
+      } else {
+        // In schema-less mode, Hive's default serde keeps only the first two output columns;
+        // if there are fewer than two output columns it throws ArrayIndexOutOfBoundsException.
+        // Here we change Spark's behavior to match Hive's default serde. Note that Hive's own
+        // TRANSFORM in schema-less mode still behaves like the original Spark behavior; we will
+        // make Spark and Hive consistent in SPARK-32388.
+        prevLine: String =>
+          new GenericInternalRow(
+            prevLine.split(outputRowFormat, 2)
+              .map(kvWriter))
+      }
+
+      override def hasNext: Boolean = {
+        try {
+          if (curLine == null) {
+            curLine = reader.readLine()
+            if (curLine == null) {
+              checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
+              return false
+            }
+          }
+          true
+        } catch {
+          case NonFatal(e) =>
+            // If this exception is due to abrupt / unclean termination of `proc`,
+            // then detect it and propagate a better exception message for end users
+            checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)
+
+            throw e
+        }
+      }
+
+      override def next(): InternalRow = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        val prevLine = curLine
+        curLine = reader.readLine()
+        processRowWithoutSerde(prevLine)
+      }
+    }
+  }
+
   protected def checkFailureAndPropagate(
       writerThread: BaseScriptTransformationWriterThread,
       cause: Throwable = null,
@@ -87,17 +182,72 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
       }
     }
   }
+
+  private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
+    val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
+    attr.dataType match {
+      case StringType => wrapperConvertException(data => data, converter)
+      case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
+      case ByteType => wrapperConvertException(data => data.toByte, converter)
+      case BinaryType =>
+        wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
+      case IntegerType => wrapperConvertException(data => data.toInt, converter)
+      case ShortType => wrapperConvertException(data => data.toShort, converter)
+      case LongType => wrapperConvertException(data => data.toLong, converter)
+      case FloatType => wrapperConvertException(data => data.toFloat, converter)
+      case DoubleType => wrapperConvertException(data => data.toDouble, converter)
+      case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
+      case DateType if conf.datetimeJava8ApiEnabled =>
+        wrapperConvertException(data => DateTimeUtils.stringToDate(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.daysToLocalDate).orNull, converter)
+      case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
+        UTF8String.fromString(data),
+        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+        .map(DateTimeUtils.toJavaDate).orNull, converter)
+      case TimestampType if conf.datetimeJava8ApiEnabled =>
+        wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.microsToInstant).orNull, converter)
+      case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
+        UTF8String.fromString(data),
+        DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+        .map(DateTimeUtils.toJavaTimestamp).orNull, converter)
+      case CalendarIntervalType => wrapperConvertException(
+        data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
+        converter)
+      case udt: UserDefinedType[_] =>
+        wrapperConvertException(data => udt.deserialize(data), converter)
+      case dt =>
+        throw new SparkException(s"${nodeName} without serde does not support " +
+          s"${dt.getClass.getSimpleName} as output data type")
+    }
+  }
+
+  // Keep consistent with Hive's `LazySimpleSerDe`: when there is a type cast error, return null.
+  private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
+    (f: String => Any, converter: Any => Any) =>
+      (data: String) => converter {
+        try {
+          f(data)
+        } catch {
+          case NonFatal(_) => null
+        }
+      }
 }
 
-abstract class BaseScriptTransformationWriterThread(
-    iter: Iterator[InternalRow],
-    inputSchema: Seq[DataType],
-    ioSchema: BaseScriptTransformIOSchema,
-    outputStream: OutputStream,
-    proc: Process,
-    stderrBuffer: CircularBuffer,
-    taskContext: TaskContext,
-    conf: Configuration) extends Thread with Logging {
+abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
+
+  def iter: Iterator[InternalRow]
+  def inputSchema: Seq[DataType]
+  def ioSchema: ScriptTransformationIOSchema
+  def outputStream: OutputStream
+  def proc: Process
+  def stderrBuffer: CircularBuffer
+  def taskContext: TaskContext
+  def conf: Configuration
 
   setDaemon(true)
 
@@ -169,34 +319,50 @@ abstract class BaseScriptTransformationWriterThread(
 /**
  * The wrapper class of input and output schema properties
  */
-abstract class BaseScriptTransformIOSchema extends Serializable {
-  import ScriptIOSchema._
-
-  def inputRowFormat: Seq[(String, String)]
-
-  def outputRowFormat: Seq[(String, String)]
-
-  def inputSerdeClass: Option[String]
-
-  def outputSerdeClass: Option[String]
-
-  def inputSerdeProps: Seq[(String, String)]
-
-  def outputSerdeProps: Seq[(String, String)]
-
-  def recordReaderClass: Option[String]
-
-  def recordWriterClass: Option[String]
-
-  def schemaLess: Boolean
+case class ScriptTransformationIOSchema(
+    inputRowFormat: Seq[(String, String)],
+    outputRowFormat: Seq[(String, String)],
+    inputSerdeClass: Option[String],
+    outputSerdeClass: Option[String],
+    inputSerdeProps: Seq[(String, String)],
+    outputSerdeProps: Seq[(String, String)],
+    recordReaderClass: Option[String],
+    recordWriterClass: Option[String],
+    schemaLess: Boolean) extends Serializable {
+  import ScriptTransformationIOSchema._
 
   val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
   val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
 }
 
-object ScriptIOSchema {
+object ScriptTransformationIOSchema {
   val defaultFormat = Map(
     ("TOK_TABLEROWFORMATFIELD", "\t"),
     ("TOK_TABLEROWFORMATLINES", "\n")
   )
+
+  val defaultIOSchema = ScriptTransformationIOSchema(
+    inputRowFormat = Seq.empty,
+    outputRowFormat = Seq.empty,
+    inputSerdeClass = None,
+    outputSerdeClass = None,
+    inputSerdeProps = Seq.empty,
+    outputSerdeProps = Seq.empty,
+    recordReaderClass = None,
+    recordWriterClass = None,
+    schemaLess = false
+  )
+
+  def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = {
+    ScriptTransformationIOSchema(
+      input.inputRowFormat,
+      input.outputRowFormat,
+      input.inputSerdeClass,
+      input.outputSerdeClass,
+      input.inputSerdeProps,
+      input.outputSerdeProps,
+      input.recordReaderClass,
+      input.recordWriterClass,
+      input.schemaLess)
+  }
 }
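
For readers following the no-serde path, here is a small standalone sketch (not part of the commit) of what `createOutputIteratorWithoutSerde`, `outputFieldWriters`, and `wrapperConvertException` do together: each output line from the script is split on the field delimiter (tab by default), each field is converted to its target type, and a value that fails conversion becomes null, matching Hive's `LazySimpleSerDe`. The object name `NoSerdeRowSketch` and the helper `nullOnError` are illustrative only.

```scala
import scala.util.control.NonFatal

object NoSerdeRowSketch {
  // Wrap a field conversion so a bad value yields null instead of throwing,
  // mirroring wrapperConvertException in the diff above.
  def nullOnError(f: String => Any): String => Any =
    data => try f(data) catch { case NonFatal(_) => null }

  def main(args: Array[String]): Unit = {
    // One writer per output attribute, analogous to outputFieldWriters.
    val fieldWriters: Seq[String => Any] =
      Seq(nullOnError(_.toInt), nullOnError(s => s), nullOnError(_.toDouble))

    val line = "42\thello\tnot-a-double" // a line emitted by the transform script
    val row = line.split("\t").zip(fieldWriters).map { case (v, w) => w(v) }
    println(row.mkString(", ")) // 42, hello, null

    // Schema-less mode keeps only the first two fields (split with limit 2).
    println(line.split("\t", 2).length) // 2
  }
}
```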

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
 import org.apache.spark.sql.execution.datasources.CreateTable
 import org.apache.spark.sql.hive.execution._
-import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec}
+import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
 
 
@@ -244,7 +244,7 @@ private[hive] trait HiveStrategies {
   object HiveScripts extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ScriptTransformation(input, script, output, child, ioschema) =>
-        val hiveIoSchema = HiveScriptIOSchema(ioschema)
+        val hiveIoSchema = ScriptTransformationIOSchema(ioschema)
         HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
       case _ => Nil
     }
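
The strategy change above plugs the unified schema into planning: the Hive-specific `HiveScriptIOSchema(ioschema)` call is replaced by `ScriptTransformationIOSchema(ioschema)`. A minimal sketch (not from the commit, for example in a spark-shell session) of the two ways the refactored schema can be obtained; the empty field values are illustrative:

```scala
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution.ScriptTransformationIOSchema

// No-serde default: row-format lookups fall back to defaultFormat,
// i.e. "\t" between fields and "\n" between rows, with schemaLess = false.
val noSerde = ScriptTransformationIOSchema.defaultIOSchema

// Derived from a parsed logical-plan node, as the HiveScripts strategy now does.
val logicalSchema = ScriptInputOutputSchema(
  Seq.empty, Seq.empty, None, None, Seq.empty, Seq.empty, None, None, schemaLess = false)
val fromLogical = ScriptTransformationIOSchema(logicalSchema)
```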
