diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f8261c293782d..891e9860c73cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -744,8 +744,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging selectClause.hints.asScala.foldRight(withWindow)(withHints) } + // Script Transform's input/output format. + type ScriptIOFormat = + (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = { + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + + val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", ctx.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) + + (entries, None, Seq.empty, None) + } + /** - * Create a (Hive based) [[ScriptInputOutputSchema]]. + * Create a [[ScriptInputOutputSchema]]. */ protected def withScriptIOSchema( ctx: ParserRuleContext, @@ -754,7 +776,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging outRowFormat: RowFormatContext, recordReader: Token, schemaLess: Boolean): ScriptInputOutputSchema = { - throw new ParseException("Script Transform is not supported", ctx) + + def format(fmt: RowFormatContext): ScriptIOFormat = fmt match { + case c: RowFormatDelimitedContext => + getRowFormatDelimited(c) + + case c: RowFormatSerdeContext => + throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + + // SPARK-32106: When there is no definition about format, we return empty result + // to use a built-in default Serde in SparkScriptTransformationExec. 
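// A rough illustration (assumed example, not taken from the patch itself) of what
// getRowFormatDelimited above produces: for a clause like
//   ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' NULL DEFINED AS 'null'
// the resulting ScriptIOFormat would be roughly
//   (Seq("TOK_TABLEROWFORMATFIELD" -> "'\t'", "TOK_TABLEROWFORMATNULL" -> "'null'"),
//    None, Seq.empty, None)
// i.e. only token -> delimiter pairs (the token text keeps its surrounding quotes), with no
// serde class, serde properties or record reader/writer, while the null branch below simply
// returns (Nil, None, Seq.empty, None).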
+ case null => + (Nil, None, Seq.empty, None) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat) + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 88afcb10d9c20..e4790e2dfa634 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType, StringType} /** * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]]. @@ -1031,4 +1031,115 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b)) assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b)) } + + test("SPARK-32106: TRANSFORM plan") { + // verify schema less + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, true)) + ) + + // verify without output schema + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' AS (a, b, c) + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, false))) + + // verify with output schema + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |USING 'cat' AS (a int, b string, c long) + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", LongType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema(List.empty, List.empty, None, None, + List.empty, List.empty, None, None, false))) + + // verify with ROW FORMAT DELIMETED + assertEqual( + """ + |SELECT TRANSFORM(a, b, c) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '\t' + |COLLECTION ITEMS TERMINATED BY '\u0002' + |MAP KEYS TERMINATED BY '\u0003' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'null' + |USING 'cat' AS (a, b, c) + |ROW FORMAT DELIMITED + |FIELDS TERMINATED BY '\t' + |COLLECTION ITEMS TERMINATED BY '\u0004' + |MAP KEYS TERMINATED BY '\u0005' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'NULL' + |FROM testData + """.stripMargin, + ScriptTransformation( + Seq('a, 'b, 'c), + "cat", + Seq(AttributeReference("a", StringType)(), + AttributeReference("b", StringType)(), + 
AttributeReference("c", StringType)()), + UnresolvedRelation(TableIdentifier("testData")), + ScriptInputOutputSchema( + Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"), + ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0002'"), + ("TOK_TABLEROWFORMATMAPKEYS", "'\u0003'"), + ("TOK_TABLEROWFORMATLINES", "'\\n'"), + ("TOK_TABLEROWFORMATNULL", "'null'")), + Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"), + ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0004'"), + ("TOK_TABLEROWFORMATMAPKEYS", "'\u0005'"), + ("TOK_TABLEROWFORMATLINES", "'\\n'"), + ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None, + List.empty, List.empty, None, None, false))) + + // verify with ROW FORMAT SERDE + intercept( + """ + |SELECT TRANSFORM(a, b, c) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES( + | "separatorChar" = "\t", + | "quoteChar" = "'", + | "escapeChar" = "\\") + |USING 'cat' AS (a, b, c) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' + |WITH SERDEPROPERTIES( + | "separatorChar" = "\t", + | "quoteChar" = "'", + | "escapeChar" = "\\") + |FROM testData + """.stripMargin, + "TRANSFORM with serde is only supported in hive mode") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 22bf6df58b040..7760a3797eb49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution -import java.io.OutputStream +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -28,14 +29,26 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { + def input: Seq[Expression] + def script: String + def output: Seq[Attribute] + def child: SparkPlan + def ioschema: ScriptTransformationIOSchema + + protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -56,10 +69,91 @@ trait BaseScriptTransformationExec extends UnaryExecNode { 
     }
   }
-  def processIterator(
+  protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
+    val cmd = List("/bin/bash", "-c", script)
+    val builder = new ProcessBuilder(cmd.asJava)
+
+    val proc = builder.start()
+    val inputStream = proc.getInputStream
+    val outputStream = proc.getOutputStream
+    val errorStream = proc.getErrorStream
+
+    // In order to avoid deadlocks, we need to consume the error output of the child process.
+    // To avoid issues caused by large error output, we use a circular buffer to limit the amount
+    // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
+    // that motivates this.
+    val stderrBuffer = new CircularBuffer(2048)
+    new RedirectThread(
+      errorStream,
+      stderrBuffer,
+      s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start()
+    (outputStream, proc, inputStream, stderrBuffer)
+  }
+
+  protected def processIterator(
       inputIterator: Iterator[InternalRow],
       hadoopConf: Configuration): Iterator[InternalRow]
 
+  protected def createOutputIteratorWithoutSerde(
+      writerThread: BaseScriptTransformationWriterThread,
+      inputStream: InputStream,
+      proc: Process,
+      stderrBuffer: CircularBuffer): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      var curLine: String = null
+      val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+
+      val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")
+      val processRowWithoutSerde = if (!ioschema.schemaLess) {
+        prevLine: String =>
+          new GenericInternalRow(
+            prevLine.split(outputRowFormat)
+              .zip(outputFieldWriters)
+              .map { case (data, writer) => writer(data) })
+      } else {
+        // In schema-less mode, Hive's default serde keeps only the first two output columns;
+        // if there are fewer than two output columns it throws ArrayIndexOutOfBoundsException.
+        // Here we change Spark's behavior to match Hive's default serde.
+ // But in hive, TRANSFORM with schema less behavior like origin spark, we will fix this + // to keep spark and hive behavior same in SPARK-32388 + val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType) + prevLine: String => + new GenericInternalRow( + prevLine.split(outputRowFormat).slice(0, 2) + .map(kvWriter)) + } + + override def hasNext: Boolean = { + try { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false + } + } + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + val prevLine = curLine + curLine = reader.readLine() + processRowWithoutSerde(prevLine) + } + } + } + protected def checkFailureAndPropagate( writerThread: BaseScriptTransformationWriterThread, cause: Throwable = null, @@ -87,17 +181,72 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } } + + private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => + val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) + attr.dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case dt => + throw new SparkException(s"${nodeName} without serde does not support " + + 
s"${dt.getClass.getSimpleName} as output data type") + } + } + + // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case NonFatal(_) => null + } + } } -abstract class BaseScriptTransformationWriterThread( - iter: Iterator[InternalRow], - inputSchema: Seq[DataType], - ioSchema: BaseScriptTransformIOSchema, - outputStream: OutputStream, - proc: Process, - stderrBuffer: CircularBuffer, - taskContext: TaskContext, - conf: Configuration) extends Thread with Logging { +abstract class BaseScriptTransformationWriterThread extends Thread with Logging { + + def iter: Iterator[InternalRow] + def inputSchema: Seq[DataType] + def ioSchema: ScriptTransformationIOSchema + def outputStream: OutputStream + def proc: Process + def stderrBuffer: CircularBuffer + def taskContext: TaskContext + def conf: Configuration setDaemon(true) @@ -169,34 +318,50 @@ abstract class BaseScriptTransformationWriterThread( /** * The wrapper class of input and output schema properties */ -abstract class BaseScriptTransformIOSchema extends Serializable { - import ScriptIOSchema._ - - def inputRowFormat: Seq[(String, String)] - - def outputRowFormat: Seq[(String, String)] - - def inputSerdeClass: Option[String] - - def outputSerdeClass: Option[String] - - def inputSerdeProps: Seq[(String, String)] - - def outputSerdeProps: Seq[(String, String)] - - def recordReaderClass: Option[String] - - def recordWriterClass: Option[String] - - def schemaLess: Boolean +case class ScriptTransformationIOSchema( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], + schemaLess: Boolean) extends Serializable { + import ScriptTransformationIOSchema._ val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) } -object ScriptIOSchema { +object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), ("TOK_TABLEROWFORMATLINES", "\n") ) + + val defaultIOSchema = ScriptTransformationIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, + schemaLess = false + ) + + def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = { + ScriptTransformationIOSchema( + input.inputRowFormat, + input.outputRowFormat, + input.inputSerdeClass, + input.outputSerdeClass, + input.inputSerdeProps, + input.outputSerdeProps, + input.recordReaderClass, + input.recordWriterClass, + input.schemaLess) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 895eeedd86b8b..b96a861196897 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -46,6 +46,7 @@ class SparkPlanner( Window :: JoinSelection :: InMemoryScans :: + 
SparkScripts:: BasicOperators :: Nil) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala new file mode 100644 index 0000000000000..b87c20e6a5656 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.io._ + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.TaskContext +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.CircularBuffer + +/** + * Transforms the input by forking and running the specified script. + * + * @param input the set of expression that should be passed to the script. + * @param script the command that should be executed. + * @param output the attributes that are produced by the script. + */ +case class SparkScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema) + extends BaseScriptTransformationExec { + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (outputStream, proc, inputStream, stderrBuffer) = initProc + + val outputProjection = new InterpretedProjection(inputExpressionsWithoutSerde, child.output) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. 
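// A minimal, self-contained sketch of the producer/consumer wiring the comment above describes,
// using only JDK process APIs and a toy 'cat' command (the names below are illustrative
// assumptions, not part of the patch): one thread feeds rows to the child process's stdin while
// the current thread reads its stdout, so neither side blocks on a full pipe buffer.
import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter}
import java.nio.charset.StandardCharsets
import scala.collection.JavaConverters._

object PipeWiringSketch {
  def main(args: Array[String]): Unit = {
    val proc = new ProcessBuilder(List("/bin/bash", "-c", "cat").asJava).start()

    // Producer: write input rows to the script's stdin on a separate daemon thread.
    val writerThread = new Thread {
      override def run(): Unit = {
        val out = new OutputStreamWriter(proc.getOutputStream, StandardCharsets.UTF_8)
        try Seq("1\ta", "2\tb", "3\tc").foreach(row => out.write(row + "\n"))
        finally out.close() // closing stdin lets 'cat' terminate
      }
    }
    writerThread.setDaemon(true)
    writerThread.start()

    // Consumer: the current thread reads the script's stdout line by line.
    val reader = new BufferedReader(
      new InputStreamReader(proc.getInputStream, StandardCharsets.UTF_8))
    Iterator.continually(reader.readLine()).takeWhile(_ != null).foreach(println)
    proc.waitFor()
  }
}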
+ val writerThread = SparkScriptTransformationWriterThread( + inputIterator.map(outputProjection), + inputExpressionsWithoutSerde.map(_.dataType), + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + val outputIterator = + createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) + + writerThread.start() + + outputIterator + } +} + +case class SparkScriptTransformationWriterThread( + iter: Iterator[InternalRow], + inputSchema: Seq[DataType], + ioSchema: ScriptTransformationIOSchema, + outputStream: OutputStream, + proc: Process, + stderrBuffer: CircularBuffer, + taskContext: TaskContext, + conf: Configuration) + extends BaseScriptTransformationWriterThread { + + override def processRows(): Unit = { + processRowsWithoutSerde() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3a2c673229c20..7ef46c949db6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution} +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType /** @@ -664,7 +665,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create a [[ScriptInputOutputSchema]]. + * Create a hive serde [[ScriptInputOutputSchema]]. */ override protected def withScriptIOSchema( ctx: ParserRuleContext, @@ -679,64 +680,60 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "Unsupported operation: Used defined record reader/writer classes.", ctx) } - // Decode and input/output format. - type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) - def format( - fmt: RowFormatContext, - configKey: String, - defaultConfigValue: String): Format = fmt match { - case c: RowFormatDelimitedContext => - // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema - // expects a seq of pairs in which the old parsers' token names are used as keys. - // Transforming the result of visitRowFormatDelimited would be quite a bit messier than - // retrieving the key value pairs ourselves. - def entry(key: String, value: Token): Seq[(String, String)] = { - Option(value).map(t => key -> t.getText).toSeq - } - val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ - entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ - entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ - entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) - - (entries, None, Seq.empty, None) - - case c: RowFormatSerdeContext => - // Use a serde format. 
- val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) + if (!conf.getConf(CATALOG_IMPLEMENTATION).equals("hive")) { + super.withScriptIOSchema( + ctx, + inRowFormat, + recordWriter, + outRowFormat, + recordReader, + schemaLess) + } else { + def format( + fmt: RowFormatContext, + configKey: String, + defaultConfigValue: String): ScriptIOFormat = fmt match { + case c: RowFormatDelimitedContext => + getRowFormatDelimited(c) + + case c: RowFormatSerdeContext => + // Use a serde format. + val CatalogStorageFormat(None, None, None, Some(name), _, props) = visitRowFormatSerde(c) + + // SPARK-10310: Special cases LazySimpleSerDe + val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { + Option(conf.getConfString(configKey, defaultConfigValue)) + } else { + None + } + (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null => + // Use default (serde) format. + val name = conf.getConfString("hive.script.serde", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + val props = Seq("field.delim" -> "\t") + val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) + (Nil, Option(name), props, recordHandler) + } - // SPARK-10310: Special cases LazySimpleSerDe - val recordHandler = if (name == "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") { - Option(conf.getConfString(configKey, defaultConfigValue)) - } else { - None - } - (Seq.empty, Option(name), props.toSeq, recordHandler) - - case null => - // Use default (serde) format. - val name = conf.getConfString("hive.script.serde", - "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") - val props = Seq("field.delim" -> "\t") - val recordHandler = Option(conf.getConfString(configKey, defaultConfigValue)) - (Nil, Option(name), props, recordHandler) + val (inFormat, inSerdeClass, inSerdeProps, reader) = + format( + inRowFormat, "hive.script.recordreader", + "org.apache.hadoop.hive.ql.exec.TextRecordReader") + + val (outFormat, outSerdeClass, outSerdeProps, writer) = + format( + outRowFormat, "hive.script.recordwriter", + "org.apache.hadoop.hive.ql.exec.TextRecordWriter") + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) } - - val (inFormat, inSerdeClass, inSerdeProps, reader) = - format( - inRowFormat, "hive.script.recordreader", "org.apache.hadoop.hive.ql.exec.TextRecordReader") - - val (outFormat, outSerdeClass, outSerdeProps, writer) = - format( - outRowFormat, "hive.script.recordwriter", - "org.apache.hadoop.hive.ql.exec.TextRecordWriter") - - ScriptInputOutputSchema( - inFormat, outFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - reader, writer, - schemaLess) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 583e5a2c5c57e..1e0d0c346731a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -532,6 +532,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object SparkScripts extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.ScriptTransformation(input, script, output, child, ioschema) => + SparkScriptTransformationExec( + input, + script, + output, + planLater(child), + ScriptTransformationIOSchema(ioschema) + ) 
:: Nil + case _ => Nil + } + } + /** * This strategy is just for explaining `Dataset/DataFrame` created by `spark.readStream`. * It won't affect the execution, because `StreamingRelation` will be replaced with diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql new file mode 100644 index 0000000000000..8610e384d6fab --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -0,0 +1,114 @@ +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES +('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +AS t(a, b, c, d, e, f, g, h, i, j, k, l); + +SELECT TRANSFORM(a) +USING 'cat' AS (a) +FROM t; + +-- with non-exist command +SELECT TRANSFORM(a) +USING 'some_non_existent_command' AS (a) +FROM t; + +-- with non-exist file +SELECT TRANSFORM(a) +USING 'python some_non_existent_file' AS (a) +FROM t; + +-- common supported data types between no serde and serde transform +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) + USING 'cat' AS ( + a string, + b boolean, + c binary, + d tinyint, + e int, + f smallint, + g long, + h float, + i double, + j decimal(38, 18), + k timestamp, + l date) + FROM t +) tmp; + +-- common supported data types between no serde and serde transform +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) + USING 'cat' AS ( + a string, + b string, + c string, + d string, + e string, + f string, + g string, + h string, + i string, + j string, + k string, + l string) + FROM t +) tmp; + +-- SPARK-32388 handle schema less +SELECT TRANSFORM(a) +USING 'cat' +FROM t; + +SELECT TRANSFORM(a, b) +USING 'cat' +FROM t; + +SELECT TRANSFORM(a, b, c) +USING 'cat' +FROM t; + +-- return null when return string incompatible (no serde) +SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) +USING 'cat' AS (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +FROM VALUES +('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i); + +-- SPARK-28227: transform can't run with aggregation +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t +GROUP BY b; + +-- transform use MAP +MAP a, b USING 'cat' AS (a, b) FROM t; + +-- transform use REDUCE +REDUCE a, b USING 'cat' AS (a, b) FROM t; + +-- transform with defined row format delimit +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (a, b, c, d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t; + +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '||' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t; diff --git 
a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out new file mode 100644 index 0000000000000..744d6384f9c45 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -0,0 +1,224 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 15 + + +-- !query +CREATE OR REPLACE TEMPORARY VIEW t AS SELECT * FROM VALUES +('1', true, unhex('537061726B2053514C'), tinyint(1), 1, smallint(100), bigint(1), float(1.0), 1.0, Decimal(1.0), timestamp('1997-01-02'), date('2000-04-01')), +('2', false, unhex('537061726B2053514C'), tinyint(2), 2, smallint(200), bigint(2), float(2.0), 2.0, Decimal(2.0), timestamp('1997-01-02 03:04:05'), date('2000-04-02')), +('3', true, unhex('537061726B2053514C'), tinyint(3), 3, smallint(300), bigint(3), float(3.0), 3.0, Decimal(3.0), timestamp('1997-02-10 17:32:01-08'), date('2000-04-03')) +AS t(a, b, c, d, e, f, g, h, i, j, k, l) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT TRANSFORM(a) +USING 'cat' AS (a) +FROM t +-- !query schema +struct +-- !query output +1 +2 +3 + + +-- !query +SELECT TRANSFORM(a) +USING 'some_non_existent_command' AS (a) +FROM t +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +Subprocess exited with status 127. Error: /bin/bash: some_non_existent_command: command not found + + +-- !query +SELECT TRANSFORM(a) +USING 'python some_non_existent_file' AS (a) +FROM t +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkException +Subprocess exited with status 2. Error: python: can't open file 'some_non_existent_file': [Errno 2] No such file or directory + + +-- !query +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) + USING 'cat' AS ( + a string, + b boolean, + c binary, + d tinyint, + e int, + f smallint, + g long, + h float, + i double, + j decimal(38, 18), + k timestamp, + l date) + FROM t +) tmp +-- !query schema +struct +-- !query output +1 true Spark SQL 1 1 100 1 1.0 1.0 1.000000000000000000 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 2 200 2 2.0 2.0 2.000000000000000000 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 3 300 3 3.0 3.0 3.000000000000000000 1997-02-10 17:32:01 2000-04-03 + + +-- !query +SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM ( + SELECT TRANSFORM(a, b, c, d, e, f, g, h, i, j, k, l) + USING 'cat' AS ( + a string, + b string, + c string, + d string, + e string, + f string, + g string, + h string, + i string, + j string, + k string, + l string) + FROM t +) tmp +-- !query schema +struct +-- !query output +1 true Spark SQL 1 1 100 1 1.0 1.0 1 1997-01-02 00:00:00 2000-04-01 +2 false Spark SQL 2 2 200 2 2.0 2.0 2 1997-01-02 03:04:05 2000-04-02 +3 true Spark SQL 3 3 300 3 3.0 3.0 3 1997-02-10 17:32:01 2000-04-03 + + +-- !query +SELECT TRANSFORM(a) +USING 'cat' +FROM t +-- !query schema +struct<> +-- !query output +java.lang.ArrayIndexOutOfBoundsException +1 + + +-- !query +SELECT TRANSFORM(a, b) +USING 'cat' +FROM t +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + +-- !query +SELECT TRANSFORM(a, b, c) +USING 'cat' +FROM t +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + +-- !query +SELECT TRANSFORM(a, b, c, d, e, f, g, h, i) +USING 'cat' AS (a int, b short, c long, d byte, e float, f double, g decimal(38, 18), h date, i timestamp) +FROM VALUES 
+('a','','1231a','a','213.21a','213.21a','0a.21d','2000-04-01123','1997-0102 00:00:') tmp(a, b, c, d, e, f, g, h, i) +-- !query schema +struct +-- !query output +NULL NULL NULL NULL NULL NULL NULL NULL NULL + + +-- !query +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t +GROUP BY b +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'GROUP' expecting {, ';'}(line 4, pos 0) + +== SQL == +SELECT TRANSFORM(b, max(a), sum(f)) +USING 'cat' AS (a, b) +FROM t +GROUP BY b +^^^ + + +-- !query +MAP a, b USING 'cat' AS (a, b) FROM t +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + +-- !query +REDUCE a, b USING 'cat' AS (a, b) FROM t +-- !query schema +struct +-- !query output +1 true +2 false +3 true + + +-- !query +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (a, b, c, d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t +-- !query schema +struct +-- !query output +1 | true | +2 | false | + + +-- !query +SELECT TRANSFORM(a, b, c, null) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '|' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +USING 'cat' AS (d) +ROW FORMAT DELIMITED +FIELDS TERMINATED BY '||' +LINES TERMINATED BY '\n' +NULL DEFINED AS 'NULL' +FROM t +-- !query schema +struct +-- !query output +1 +2 diff --git a/sql/hive/src/test/resources/test_script.py b/sql/core/src/test/resources/test_script.py similarity index 100% rename from sql/hive/src/test/resources/test_script.py rename to sql/core/src/test/resources/test_script.py diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index f0522dfeafaac..8f18468e36bcb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -23,7 +23,7 @@ import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.{SparkConf, SparkException, TestUtils} import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.SQLHelper @@ -258,6 +258,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER") } + // SPARK-32106 Since we add SQL test 'transform.sql' will use `cat` command, + // here we need to check command available + assume(TestUtils.testCommandAvailable("/bin/bash")) val input = fileToString(new File(testCase.inputFile)) val (comments, code) = splitCommentsAndCodes(input) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala new file mode 100644 index 0000000000000..101c9a5c899db --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.sql.{Date, Timestamp} + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ +import org.scalatest.Assertions._ +import org.scalatest.BeforeAndAfterEach +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.{SparkException, TaskContext, TestUtils} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils + with BeforeAndAfterEach { + import testImplicits._ + import ScriptTransformationIOSchema._ + + protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler + + private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ + + protected override def beforeAll(): Unit = { + super.beforeAll() + defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) + } + + protected override def afterAll(): Unit = { + super.afterAll() + Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler) + } + + override protected def afterEach(): Unit = { + super.afterEach() + uncaughtExceptionHandler.cleanStatus() + } + + def isHive23OrSpark: Boolean + + def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec + + test("cat without SerDe") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + checkAnswer( + rowsDf, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = child, + ioschema = defaultIOSchema + ), + rowsDf.collect()) + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("script transformation should not swallow errors from upstream operators (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[TestFailedException] { + checkAnswer( + rowsDf, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "cat", + output = Seq(AttributeReference("a", StringType)()), + child = ExceptionInjectingOperator(child), + ioschema = defaultIOSchema + 
), + rowsDf.collect()) + } + assert(e.getMessage().contains("intentional exception")) + // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("SPARK-25990: TRANSFORM should handle different data types correctly") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + df.createTempView("v") + + val query = sql( + s""" + |SELECT + |TRANSFORM(a, b, c, d, e) + |USING 'python $scriptFilePath' AS (a, b, c, d, e) + |FROM v + """.stripMargin) + + // In Hive 1.2, the string representation of a decimal omits trailing zeroes. + // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. + val decimalToString: Column => Column = if (isHive23OrSpark) { + c => c.cast("string") + } else { + c => c.cast("decimal(1, 0)").cast("string") + } + checkAnswer(query, identity, df.select( + 'a.cast("string"), + 'b.cast("string"), + 'c.cast("string"), + decimalToString('d), + 'e.cast("string")).collect()) + } + } + + test("SPARK-25990: TRANSFORM should handle schema less correctly (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + val scriptFilePath = getTestResourcePath("test_script.py") + + withTempView("v") { + val df = Seq( + (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)), + (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)), + (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3)) + ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr), + script = s"python $scriptFilePath", + output = Seq( + AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), + child = child, + ioschema = defaultIOSchema.copy(schemaLess = true) + ), + df.select( + 'a.cast("string").as("key"), + 'b.cast("string").as("value")).collect()) + } + } + + test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[SparkException] { + val plan = + createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = defaultIOSchema) + SparkPlanTest.executePlan(plan, spark.sqlContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + def testBasicInputDataTypesWith(serde: ScriptTransformationIOSchema, testName: String): Unit = { + test(s"SPARK-32106: TRANSFORM should support basic data types as input ($testName)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (1, "1", 1.0f, 1.0, 11.toByte, BigDecimal(1.0), new Timestamp(1), + new Date(2020, 7, 1), true), + (2, "2", 2.0f, 2.0, 22.toByte, BigDecimal(2.0), new Timestamp(2), + new Date(2020, 7, 2), true), + (3, "3", 3.0f, 3.0, 33.toByte, BigDecimal(3.0), 
new Timestamp(3), + new Date(2020, 7, 3), false) + ).toDF("a", "b", "c", "d", "e", "f", "g", "h", "i") + .withColumn("j", lit("abc").cast("binary")) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr, + df.col("h").expr, + df.col("i").expr, + df.col("j").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", FloatType)(), + AttributeReference("d", DoubleType)(), + AttributeReference("e", ByteType)(), + AttributeReference("f", DecimalType(38, 18))(), + AttributeReference("g", TimestampType)(), + AttributeReference("h", DateType)(), + AttributeReference("i", BooleanType)(), + AttributeReference("j", BinaryType)()), + child = child, + ioschema = serde + ), + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect()) + } + } + } + + testBasicInputDataTypesWith(defaultIOSchema, "no serde") + + test("SPARK-32106: TRANSFORM should support more data types (interval, array, map, struct " + + "and udt) as input (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (new CalendarInterval(7, 1, 1000), Array(0, 1, 2), Map("a" -> 1), (1, 2), + new SimpleTuple(1, 1L)), + (new CalendarInterval(7, 2, 2000), Array(3, 4, 5), Map("b" -> 2), (3, 4), + new SimpleTuple(1, 1L)), + (new CalendarInterval(7, 3, 3000), Array(6, 7, 8), Map("c" -> 3), (5, 6), + new SimpleTuple(1, 1L)) + ).toDF("a", "b", "c", "d", "e") + + // Can't support convert script output data to ArrayType/MapType/StructType now, + // return these column still as string. + // For UserDefinedType, if user defined deserialize method to support convert string + // to UserType like [[SimpleTupleUDT]], we can support convert to this UDT, else we + // will return null value as column. + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr), + script = "cat", + output = Seq( + AttributeReference("a", CalendarIntervalType)(), + AttributeReference("b", StringType)(), + AttributeReference("c", StringType)(), + AttributeReference("d", StringType)(), + AttributeReference("e", new SimpleTupleUDT)()), + child = child, + ioschema = defaultIOSchema + ), + df.select('a, 'b.cast("string"), 'c.cast("string"), 'd.cast("string"), 'e).collect()) + } + } + + test("SPARK-32106: TRANSFORM should respect DATETIME_JAVA8API_ENABLED (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + Array(false, true).foreach { java8AapiEnable => + withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8AapiEnable.toString) { + withTempView("v") { + val df = Seq( + (new Timestamp(1), new Date(2020, 7, 1)), + (new Timestamp(2), new Date(2020, 7, 2)), + (new Timestamp(3), new Date(2020, 7, 3)) + ).toDF("a", "b") + df.createTempView("v") + + val query = sql( + """ + |SELECT TRANSFORM (a, b) + |USING 'cat' AS (a timestamp, b date) + |FROM v + """.stripMargin) + checkAnswer(query, identity, df.select('a, 'b).collect()) + } + } + } + } +} + +case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = { + child.execute().map { x => + assert(TaskContext.get() != null) // Make sure that TaskContext is defined. 
+ Thread.sleep(1000) // This sleep gives the external process time to start. + throw new IllegalArgumentException("intentional exception") + } + } + + override def output: Seq[Attribute] = child.output + + override def outputPartitioning: Partitioning = child.outputPartitioning +} + +@SQLUserDefinedType(udt = classOf[SimpleTupleUDT]) +private class SimpleTuple(val id: Int, val size: Long) extends Serializable { + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = other match { + case v: SimpleTuple => this.id == v.id && this.size == v.size + case _ => false + } + + override def toString: String = + compact(render( + ("id" -> id) ~ + ("size" -> size) + )) +} + +private class SimpleTupleUDT extends UserDefinedType[SimpleTuple] { + + override def sqlType: DataType = StructType( + StructField("id", IntegerType, false) :: + StructField("size", LongType, false) :: + Nil) + + override def serialize(sql: SimpleTuple): Any = { + val row = new GenericInternalRow(2) + row.setInt(0, sql.id) + row.setLong(1, sql.size) + row + } + + override def deserialize(datum: Any): SimpleTuple = { + datum match { + case str: String => + implicit val format = DefaultFormats + val json = parse(str) + new SimpleTuple((json \ "id").extract[Int], (json \ "size").extract[Long]) + case data: InternalRow if data.numFields == 2 => + new SimpleTuple(data.getInt(0), data.getLong(1)) + case _ => null + } + } + + override def userClass: Class[SimpleTuple] = classOf[SimpleTuple] + + override def asNullable: SimpleTupleUDT = this + + override def hashCode(): Int = getClass.hashCode() + + override def equals(other: Any): Boolean = { + other.isInstanceOf[SimpleTupleUDT] + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala new file mode 100644 index 0000000000000..6b20f4cf88645 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.test.SharedSparkSession + +class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession { + import testImplicits._ + + override def isHive23OrSpark: Boolean = true + + override def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { + SparkScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema + ) + } + + test("SPARK-32106: TRANSFORM with serde without hive should throw exception") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + df.createTempView("v") + + val e = intercept[ParseException] { + sql( + """ + |SELECT TRANSFORM (a) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |USING 'cat' AS (a) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' + |FROM v + """.stripMargin) + }.getMessage + assert(e.contains("TRANSFORM with serde is only supported in hive mode")) + } + } + + test("SPARK-32106: TRANSFORM doesn't support ArrayType/MapType/StructType " + + "as output data type (no serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + // check for ArrayType + val e1 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(a) + |USING 'cat' AS (a array) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e1.contains("SparkScriptTransformation without serde does not support" + + " ArrayType as output data type")) + + // check for MapType + val e2 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(b) + |USING 'cat' AS (b map) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e2.contains("SparkScriptTransformation without serde does not support" + + " MapType as output data type")) + + // check for StructType + val e3 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(c) + |USING 'cat' AS (c struct) + |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) + """.stripMargin).collect() + }.getMessage + assert(e3.contains("SparkScriptTransformation without serde does not support" + + " StructType as output data type")) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala similarity index 96% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala index 681eb4e255dbc..360f4658345e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.execution +package org.apache.spark.sql.execution class TestUncaughtExceptionHandler extends Thread.UncaughtExceptionHandler { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 16e9014340244..d075b69d976cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1063,6 +1063,9 @@ private[hive] trait HiveInspectors { case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo + case dt => + throw new AnalysisException( + s"${dt.getClass.getSimpleName.replace("$", "")} cannot be converted to Hive TypeInfo") } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index dae68df08f32e..97e1dee5913a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec} +import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -244,7 +244,7 @@ private[hive] trait HiveStrategies { object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => - val hiveIoSchema = HiveScriptIOSchema(ioschema) + val hiveIoSchema = ScriptTransformationIOSchema(ioschema) HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala similarity index 61% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 96fe646d39fde..535eae5e47adb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution import java.io._ -import java.nio.charset.StandardCharsets import java.util.Properties import javax.annotation.Nullable @@ -33,14 +32,13 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, 
RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, Utils} /** * Transforms the input by forking and running the specified script. @@ -54,71 +52,27 @@ case class HiveScriptTransformationExec( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema) + ioschema: ScriptTransformationIOSchema) extends BaseScriptTransformationExec { + import HiveScriptIOSchema._ - override def processIterator( - inputIterator: Iterator[InternalRow], + private def createOutputIteratorWithSerde( + writerThread: BaseScriptTransformationWriterThread, + inputStream: InputStream, + proc: Process, + stderrBuffer: CircularBuffer, + outputSerde: AbstractSerDe, + outputSoi: StructObjectInspector, hadoopConf: Configuration): Iterator[InternalRow] = { - val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd.asJava) - - val proc = builder.start() - val inputStream = proc.getInputStream - val outputStream = proc.getOutputStream - val errorStream = proc.getErrorStream - - // In order to avoid deadlocks, we need to consume the error output of the child process. - // To avoid issues caused by large error output, we use a circular buffer to limit the amount - // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang - // that motivates this. - val stderrBuffer = new CircularBuffer(2048) - new RedirectThread( - errorStream, - stderrBuffer, - "Thread-ScriptTransformation-STDERR-Consumer").start() - - val outputProjection = new InterpretedProjection(input, child.output) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) - - // This new thread will consume the ScriptTransformation's input rows and write them to the - // external process. That process's output will be read by this current thread. 
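// The removed lines around this point show the child-process plumbing that this refactor
// presumably hoists into the shared initProc helper the new processIterator calls later in
// this hunk (`val (outputStream, proc, inputStream, stderrBuffer) = initProc`). A minimal
// sketch of that plumbing, assuming only classes already visible in this diff
// (ProcessBuilder, CircularBuffer, RedirectThread); the method name and return shape here
// are illustrative, not the actual helper's signature.
import java.io.{InputStream, OutputStream}
import scala.collection.JavaConverters._
import org.apache.spark.util.{CircularBuffer, RedirectThread}

def forkScriptProcess(script: String): (OutputStream, Process, InputStream, CircularBuffer) = {
  val proc = new ProcessBuilder(List("/bin/bash", "-c", script).asJava).start()
  // Drain stderr into a bounded buffer so a chatty or failing script cannot fill the pipe
  // and hang the task (SPARK-7862); only the most recent 2048 bytes are kept for error
  // reporting.
  val stderrBuffer = new CircularBuffer(2048)
  new RedirectThread(
    proc.getErrorStream,
    stderrBuffer,
    "Thread-ScriptTransformation-STDERR-Consumer").start()
  (proc.getOutputStream, proc, proc.getInputStream, stderrBuffer)
}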
- val writerThread = new HiveScriptTransformationWriterThread( - inputIterator.map(outputProjection), - input.map(_.dataType), - inputSerde, - inputSoi, - ioschema, - outputStream, - proc, - stderrBuffer, - TaskContext.get(), - hadoopConf - ) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (outputSerde, outputSoi) = { - ioschema.initOutputSerDe(output).getOrElse((null, null)) - } - - val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) @Nullable val scriptOutputReader = - ioschema.recordReader(scriptOutputStream, hadoopConf).orNull + recordReader(ioschema, scriptOutputStream, hadoopConf).orNull var scriptOutputWritable: Writable = null - val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().getConstructor().newInstance() - } else { - null - } + val reusedWritableObject = outputSerde.getSerializedClass.getConstructor().newInstance() val mutableRow = new SpecificInternalRow(output.map(_.dataType)) @transient @@ -126,15 +80,7 @@ case class HiveScriptTransformationExec( override def hasNext: Boolean = { try { - if (outputSerde == null) { - if (curLine == null) { - curLine = reader.readLine() - if (curLine == null) { - checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) - return false - } - } - } else if (scriptOutputWritable == null) { + if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject if (scriptOutputReader != null) { @@ -172,35 +118,70 @@ case class HiveScriptTransformationExec( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { - val prevLine = curLine - curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .map(CatalystTypeConverters.convertToCatalyst)) + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .map(CatalystTypeConverters.convertToCatalyst)) + unwrappers(i)(dataList.get(i), mutableRow, i) } - } else { - val raw = outputSerde.deserialize(scriptOutputWritable) - scriptOutputWritable = null - val dataList = outputSoi.getStructFieldsDataAsList(raw) - var i = 0 - while (i < dataList.size()) { - if (dataList.get(i) == null) { - mutableRow.setNullAt(i) - } else { - unwrappers(i)(dataList.get(i), mutableRow, i) - } - i += 1 - } - mutableRow + i += 1 } + mutableRow } } + } + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (outputStream, proc, inputStream, stderrBuffer) = initProc + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (inputSerde, inputSoi) = initInputSerDe(ioschema, input).getOrElse((null, null)) + + // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null + // We will use StringBuffer to pass data, in this 
case, we should cast data as string too. + val finalInput = if (inputSerde == null) { + inputExpressionsWithoutSerde + } else { + input + } + + val outputProjection = new InterpretedProjection(finalInput, child.output) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = HiveScriptTransformationWriterThread( + inputIterator.map(outputProjection), + finalInput.map(_.dataType), + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + // This nullability is a performance optimization in order to avoid an Option.foreach() call + // inside of a loop + @Nullable val (outputSerde, outputSoi) = { + initOutputSerDe(ioschema, output).getOrElse((null, null)) + } + + val outputIterator = if (outputSerde == null) { + createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) + } else { + createOutputIteratorWithSerde( + writerThread, inputStream, proc, stderrBuffer, outputSerde, outputSoi, hadoopConf) + } writerThread.start() @@ -208,30 +189,23 @@ case class HiveScriptTransformationExec( } } -private class HiveScriptTransformationWriterThread( +case class HiveScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], @Nullable inputSerde: AbstractSerDe, @Nullable inputSoi: StructObjectInspector, - ioSchema: HiveScriptIOSchema, + ioSchema: ScriptTransformationIOSchema, outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration) - extends BaseScriptTransformationWriterThread( - iter, - inputSchema, - ioSchema, - outputStream, - proc, - stderrBuffer, - taskContext, - conf) with HiveInspectors { + extends BaseScriptTransformationWriterThread with HiveInspectors { + import HiveScriptIOSchema._ override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) - @Nullable val scriptInputWriter = ioSchema.recordWriter(dataOutputStream, conf).orNull + @Nullable val scriptInputWriter = recordWriter(ioSchema, dataOutputStream, conf).orNull if (inputSerde == null) { processRowsWithoutSerde() @@ -259,40 +233,14 @@ private class HiveScriptTransformationWriterThread( } } -object HiveScriptIOSchema { - def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = { - HiveScriptIOSchema( - input.inputRowFormat, - input.outputRowFormat, - input.inputSerdeClass, - input.outputSerdeClass, - input.inputSerdeProps, - input.outputSerdeProps, - input.recordReaderClass, - input.recordWriterClass, - input.schemaLess) - } -} +object HiveScriptIOSchema extends HiveInspectors { -/** - * The wrapper class of Hive input and output schema properties - */ -case class HiveScriptIOSchema ( - inputRowFormat: Seq[(String, String)], - outputRowFormat: Seq[(String, String)], - inputSerdeClass: Option[String], - outputSerdeClass: Option[String], - inputSerdeProps: Seq[(String, String)], - outputSerdeProps: Seq[(String, String)], - recordReaderClass: Option[String], - recordWriterClass: Option[String], - schemaLess: Boolean) - extends BaseScriptTransformIOSchema with HiveInspectors { - - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { - inputSerdeClass.map { serdeClass => + def initInputSerDe( + ioschema: ScriptTransformationIOSchema, + input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.inputSerdeClass.map { 
serdeClass => val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(serdeClass, columns, columnTypes, inputSerdeProps) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) @@ -300,10 +248,12 @@ case class HiveScriptIOSchema ( } } - def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { - outputSerdeClass.map { serdeClass => + def initOutputSerDe( + ioschema: ScriptTransformationIOSchema, + output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.outputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] (serde, structObjectInspector) } @@ -315,7 +265,7 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - private def initSerDe( + def initSerDe( serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], @@ -339,22 +289,26 @@ case class HiveScriptIOSchema ( } def recordReader( + ioschema: ScriptTransformationIOSchema, inputStream: InputStream, conf: Configuration): Option[RecordReader] = { - recordReaderClass.map { klass => + ioschema.recordReaderClass.map { klass => val instance = Utils.classForName[RecordReader](klass).getConstructor(). newInstance() val props = new Properties() // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 - outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } + ioschema.outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } instance.initialize(inputStream, conf, props) instance } } - def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { - recordWriterClass.map { klass => + def recordWriter( + ioschema: ScriptTransformationIOSchema, + outputStream: OutputStream, + conf: Configuration): Option[RecordWriter] = { + ioschema.recordWriterClass.map { klass => val instance = Utils.classForName[RecordWriter](klass).getConstructor(). 
newInstance() instance.initialize(outputStream, conf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 35252fc47f49f..e89e20c2c723e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -20,78 +20,44 @@ package org.apache.spark.sql.hive.execution import java.sql.Timestamp import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.scalatest.Assertions._ -import org.scalatest.BeforeAndAfterEach import org.scalatest.exceptions.TestFailedException -import org.apache.spark.{SparkException, TaskContext, TestUtils} -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.types.StringType - -class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton - with BeforeAndAfterEach { - import spark.implicits._ - - private val noSerdeIOSchema = HiveScriptIOSchema( - inputRowFormat = Seq.empty, - outputRowFormat = Seq.empty, - inputSerdeClass = None, - outputSerdeClass = None, - inputSerdeProps = Seq.empty, - outputSerdeProps = Seq.empty, - recordReaderClass = None, - recordWriterClass = None, - schemaLess = false - ) - - private val serdeIOSchema = noSerdeIOSchema.copy( - inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), - outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) - ) - - private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _ - - private val uncaughtExceptionHandler = new TestUncaughtExceptionHandler - - protected override def beforeAll(): Unit = { - super.beforeAll() - defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler - Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton { + import testImplicits._ + + import ScriptTransformationIOSchema._ + + override def isHive23OrSpark: Boolean = HiveUtils.isHive23 + + override def createScriptTransformationExec( + input: Seq[Expression], + script: String, + output: Seq[Attribute], + child: SparkPlan, + ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = { + HiveScriptTransformationExec( + input = input, + script = script, + output = output, + child = child, + ioschema = ioschema + ) } - protected override def afterAll(): Unit = { - super.afterAll() - Thread.setDefaultUncaughtExceptionHandler(defaultUncaughtExceptionHandler) - } - - override protected def afterEach(): Unit = { - super.afterEach() - 
uncaughtExceptionHandler.cleanStatus() - } - - test("cat without SerDe") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - checkAnswer( - rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "cat", - output = Seq(AttributeReference("a", StringType)()), - child = child, - ioschema = noSerdeIOSchema - ), - rowsDf.collect()) - assert(uncaughtExceptionHandler.exception.isEmpty) + private val hiveIOSchema: ScriptTransformationIOSchema = { + defaultIOSchema.copy( + inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), + outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) + ) } test("cat with LazySimpleSerDe") { @@ -100,51 +66,30 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), child = child, - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.collect()) assert(uncaughtExceptionHandler.exception.isEmpty) } - test("script transformation should not swallow errors from upstream operators (no serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - val e = intercept[TestFailedException] { - checkAnswer( - rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "cat", - output = Seq(AttributeReference("a", StringType)()), - child = ExceptionInjectingOperator(child), - ioschema = noSerdeIOSchema - ), - rowsDf.collect()) - } - assert(e.getMessage().contains("intentional exception")) - // Before SPARK-25158, uncaughtExceptionHandler will catch IllegalArgumentException - assert(uncaughtExceptionHandler.exception.isEmpty) - } - - test("script transformation should not swallow errors from upstream operators (with serde)") { + test("script transformation should not swallow errors from upstream operators (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[TestFailedException] { checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "cat", output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.collect()) } @@ -153,26 +98,26 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-14400 script transformation should fail for bad script command") { + test("SPARK-14400 script transformation should fail for bad script command (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") val e = intercept[SparkException] { val plan = - new HiveScriptTransformationExec( + createScriptTransformationExec( input = Seq(rowsDf.col("a").expr), script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), child = rowsDf.queryExecution.sparkPlan, - ioschema = serdeIOSchema) + ioschema = 
hiveIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e.getMessage.contains("Subprocess exited with status")) assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-24339 verify the result after pruning the unused columns") { + test("SPARK-24339 verify the result after pruning the unused columns (hive serde)") { val rowsDf = Seq( ("Bob", 16, 176), ("Alice", 32, 164), @@ -181,18 +126,36 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with checkAnswer( rowsDf, - (child: SparkPlan) => new HiveScriptTransformationExec( + (child: SparkPlan) => createScriptTransformationExec( input = Seq(rowsDf.col("name").expr), script = "cat", output = Seq(AttributeReference("name", StringType)()), child = child, - ioschema = serdeIOSchema + ioschema = hiveIOSchema ), rowsDf.select("name").collect()) assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-25990: TRANSFORM should handle different data types correctly") { + test("SPARK-30973: TRANSFORM should wait for the termination of the script (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + + val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") + val e = intercept[SparkException] { + val plan = + createScriptTransformationExec( + input = Seq(rowsDf.col("a").expr), + script = "some_non_existent_command", + output = Seq(AttributeReference("a", StringType)()), + child = rowsDf.queryExecution.sparkPlan, + ioschema = hiveIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + } + assert(e.getMessage.contains("Subprocess exited with status")) + assert(uncaughtExceptionHandler.exception.isEmpty) + } + + test("SPARK-25990: TRANSFORM should handle schema less correctly (hive serde)") { assume(TestUtils.testCommandAvailable("python")) val scriptFilePath = getTestResourcePath("test_script.py") @@ -206,75 +169,147 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with val query = sql( s""" - |SELECT - |TRANSFORM(a, b, c, d, e) - |USING 'python $scriptFilePath' AS (a, b, c, d, e) - |FROM v + |SELECT TRANSFORM(a, b, c, d, e) + |USING 'python ${scriptFilePath}' + |FROM v """.stripMargin) - // In Hive 1.2, the string representation of a decimal omits trailing zeroes. - // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary. 
- val decimalToString: Column => Column = if (HiveUtils.isHive23) { - c => c.cast("string") - } else { - c => c.cast("decimal(1, 0)").cast("string") - } - checkAnswer(query, identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - decimalToString('d), - 'e.cast("string")).collect()) + // In hive default serde mode, if we don't define output schema, it will choose first + // two column as output schema (key: String, value: String) + checkAnswer( + query, + identity, + df.select( + 'a.cast("string").as("key"), + 'b.cast("string").as("value")).collect()) } } - test("SPARK-30973: TRANSFORM should wait for the termination of the script (no serde)") { + testBasicInputDataTypesWith(hiveIOSchema, "hive serde") + + test("SPARK-32106: TRANSFORM supports complex data types type (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, "1", Array(0, 1, 2), Map("a" -> 1)), + (2, "2", Array(3, 4, 5), Map("b" -> 2))) + .toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + df.createTempView("v") - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - val e = intercept[SparkException] { - val plan = - new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "some_non_existent_command", - output = Seq(AttributeReference("a", StringType)()), - child = rowsDf.queryExecution.sparkPlan, - ioschema = noSerdeIOSchema) - SparkPlanTest.executePlan(plan, hiveContext) + // Hive serde support ArrayType/MapType/StructType as input and output data type + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("c").expr, + df.col("d").expr, + df.col("e").expr), + script = "cat", + output = Seq( + AttributeReference("c", ArrayType(IntegerType))(), + AttributeReference("d", MapType(StringType, IntegerType))(), + AttributeReference("e", StructType( + Seq( + StructField("col1", IntegerType, false), + StructField("col2", StringType, true))))()), + child = child, + ioschema = hiveIOSchema + ), + df.select('c, 'd, 'e).collect()) } - assert(e.getMessage.contains("Subprocess exited with status")) - assert(uncaughtExceptionHandler.exception.isEmpty) } - test("SPARK-30973: TRANSFORM should wait for the termination of the script (with serde)") { + test("SPARK-32106: TRANSFORM supports complex data types end to end (hive serde)") { assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, "1", Array(0, 1, 2), Map("a" -> 1)), + (2, "2", Array(3, 4, 5), Map("b" -> 2))) + .toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + df.createTempView("v") - val rowsDf = Seq("a", "b", "c").map(Tuple1.apply).toDF("a") - val e = intercept[SparkException] { - val plan = - new HiveScriptTransformationExec( - input = Seq(rowsDf.col("a").expr), - script = "some_non_existent_command", - output = Seq(AttributeReference("a", StringType)()), - child = rowsDf.queryExecution.sparkPlan, - ioschema = serdeIOSchema) - SparkPlanTest.executePlan(plan, hiveContext) + // Hive serde support ArrayType/MapType/StructType as input and output data type + val query = sql( + """ + |SELECT TRANSFORM (c, d, e) + |USING 'cat' AS (c array, d map, e struct) + |FROM v + """.stripMargin) + checkAnswer(query, identity, df.select('c, 'd, 'e).collect()) } - assert(e.getMessage.contains("Subprocess exited with status")) - assert(uncaughtExceptionHandler.exception.isEmpty) } -} -private case class ExceptionInjectingOperator(child: SparkPlan) 
extends UnaryExecNode { - override protected def doExecute(): RDD[InternalRow] = { - child.execute().map { x => - assert(TaskContext.get() != null) // Make sure that TaskContext is defined. - Thread.sleep(1000) // This sleep gives the external process time to start. - throw new IllegalArgumentException("intentional exception") + test("SPARK-32106: TRANSFORM doesn't support CalenderIntervalType/UserDefinedType (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3))), + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3)))) + .toDF("a", "b", "c") + df.createTempView("v") + + val e1 = intercept[SparkException] { + val plan = createScriptTransformationExec( + input = Seq(df.col("a").expr, df.col("b").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("b", CalendarIntervalType)()), + child = df.queryExecution.sparkPlan, + ioschema = hiveIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + }.getMessage + assert(e1.contains( + "CalendarIntervalType cannot be converted to Hive TypeInfo")) + + val e2 = intercept[SparkException] { + val plan = createScriptTransformationExec( + input = Seq(df.col("a").expr, df.col("c").expr), + script = "cat", + output = Seq( + AttributeReference("a", IntegerType)(), + AttributeReference("c", new TestUDT.MyDenseVectorUDT)()), + child = df.queryExecution.sparkPlan, + ioschema = hiveIOSchema) + SparkPlanTest.executePlan(plan, hiveContext) + }.getMessage + assert(e2.contains( + "MyDenseVectorUDT cannot be converted to Hive TypeInfo")) } } - override def output: Seq[Attribute] = child.output + test("SPARK-32106: TRANSFORM doesn't support" + + " CalenderIntervalType/UserDefinedType end to end (hive serde)") { + assume(TestUtils.testCommandAvailable("/bin/bash")) + withTempView("v") { + val df = Seq( + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3))), + (1, new CalendarInterval(7, 1, 1000), new TestUDT.MyDenseVector(Array(1, 2, 3)))) + .toDF("a", "b", "c") + df.createTempView("v") - override def outputPartitioning: Partitioning = child.outputPartitioning + val e1 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(a, b) USING 'cat' AS (a, b) + |FROM v + """.stripMargin).collect() + }.getMessage + assert(e1.contains( + "CalendarIntervalType cannot be converted to Hive TypeInfo")) + + val e2 = intercept[SparkException] { + sql( + """ + |SELECT TRANSFORM(a, c) USING 'cat' AS (a, c) + |FROM v + """.stripMargin).collect() + }.getMessage + assert(e2.contains( + "MyDenseVectorUDT cannot be converted to Hive TypeInfo")) + } + } } + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 920f6385f8e19..24b9d25ed94f2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Functio import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.execution.TestUncaughtExceptionHandler import 
org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} import org.apache.spark.sql.execution.command.{FunctionsCommand, LoadDataCommand} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
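// Taken together, the suites above pin down the split this change introduces: the plain
// (no-serde) TRANSFORM path now runs without Hive support via SparkScriptTransformationExec,
// a ROW FORMAT SERDE clause is still rejected at parse time outside hive mode, and a missing
// output schema falls back to the default (key string, value string) pair. A hedged
// spark-shell sketch of that behaviour; the view name is illustrative and 'cat' assumes a
// POSIX shell is available.
spark.range(3).selectExpr("CAST(id AS STRING) AS a").createOrReplaceTempView("v")

// Built-in (no serde) TRANSFORM: works in a session without Hive support.
spark.sql("SELECT TRANSFORM(a) USING 'cat' AS (a) FROM v").show()

// Schema-less form: the output schema defaults to key/value string columns.
spark.sql("SELECT TRANSFORM(a) USING 'cat' FROM v").printSchema()

// Serde form: outside hive mode this fails while parsing with
// "TRANSFORM with serde is only supported in hive mode".
spark.sql(
  """SELECT TRANSFORM(a)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |USING 'cat' AS (a)
    |FROM v""".stripMargin).show()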