diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index a3779698a5ac3..f2dab941cb8b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -83,6 +83,11 @@ object ParserUtils {
     node.getText.slice(1, node.getText.size - 1)
   }
 
+  /** Collect the entries if any. */
+  def entry(key: String, value: Token): Seq[(String, String)] = {
+    Option(value).toSeq.map(x => key -> string(x))
+  }
+
   /** Get the origin (line and position) of the token. */
   def position(token: Token): Origin = {
     val opt = Option(token)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 129312160b1b7..e79ebd12d19c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -654,10 +654,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
    */
   override def visitRowFormatDelimited(
       ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) {
-    // Collect the entries if any.
-    def entry(key: String, value: Token): Seq[(String, String)] = {
-      Option(value).toSeq.map(x => key -> string(x))
-    }
     // TODO we need proper support for the NULL format.
     val entries =
       entry("field.delim", ctx.fieldsTerminatedBy) ++
@@ -756,9 +752,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
       // expects a seq of pairs in which the old parsers' token names are used as keys.
      // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
       // retrieving the key value pairs ourselves.
-      def entry(key: String, value: Token): Seq[(String, String)] = {
-        Option(value).map(t => key -> t.getText).toSeq
-      }
       val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
         entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
         entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
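Both local `entry` helpers removed from `SparkSqlParser.scala` above are superseded by the shared `ParserUtils.entry` added in the first hunk. This is not a pure deduplication: the transform-script copy previously used `t.getText`, which keeps the raw token text including the surrounding quotes, while the shared helper goes through `string(x)`, which strips the quotes and unescapes the SQL literal, so a delimiter written as `'\n'` is now recorded as an actual newline character (the parser test at the end of this patch asserts exactly that). The sketch below is illustration only, with a stub `Token` trait standing in for ANTLR's `org.antlr.v4.runtime.Token`; it shows the null-handling idiom the helper relies on: an absent clause contributes no entry at all.

```scala
// Illustration only: a stub Token instead of org.antlr.v4.runtime.Token.
trait Token { def getText: String }

object EntrySketch {
  // Mirrors ParserUtils.entry: Option(value) turns a null token into an
  // empty Seq, so an optional ROW FORMAT clause that was not given
  // simply contributes no (key, value) pair.
  def entry(key: String, value: Token): Seq[(String, String)] =
    Option(value).toSeq.map(t => key -> t.getText)

  def main(args: Array[String]): Unit = {
    val fieldDelim = new Token { def getText: String = "," }
    println(entry("field.delim", fieldDelim)) // List((field.delim,,))
    println(entry("line.delim", null))        // List() -- clause absent
  }
}
```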
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index aa800000e0515..3a190840c01e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -65,6 +65,14 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
 
   def isHive23OrSpark: Boolean
 
+  // In Hive 1.2, the string representation of a decimal omits trailing zeroes.
+  // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary.
+  val decimalToString: Column => Column = if (isHive23OrSpark) {
+    c => c.cast("string")
+  } else {
+    c => c.cast("decimal(1, 0)").cast("string")
+  }
+
   def createScriptTransformationExec(
       input: Seq[Expression],
       script: String,
@@ -130,13 +138,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         |FROM v
       """.stripMargin)
 
-      // In Hive 1.2, the string representation of a decimal omits trailing zeroes.
-      // But in Hive 2.3, it is always padded to 18 digits with trailing zeroes if necessary.
-      val decimalToString: Column => Column = if (isHive23OrSpark) {
-        c => c.cast("string")
-      } else {
-        c => c.cast("decimal(1, 0)").cast("string")
-      }
       checkAnswer(query, identity, df.select(
         'a.cast("string"),
         'b.cast("string"),
@@ -311,6 +312,66 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
       }
     }
   }
+
+  test("SPARK-32608: Script Transform ROW FORMAT DELIMITED should format values correctly") {
+    withTempView("v") {
+      val df = Seq(
+        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
+        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
+        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
+      ).toDF("a", "b", "c", "d", "e") // Note: column d's data type is decimal(38, 18)
+      df.createTempView("v")
+
+      // input/output with the same delimiters
+      checkAnswer(
+        sql(
+          s"""
+            |SELECT TRANSFORM(a, b, c, d, e)
+            | ROW FORMAT DELIMITED
+            | FIELDS TERMINATED BY ','
+            | COLLECTION ITEMS TERMINATED BY '#'
+            | MAP KEYS TERMINATED BY '@'
+            | LINES TERMINATED BY '\n'
+            | NULL DEFINED AS 'null'
+            | USING 'cat' AS (a, b, c, d, e)
+            | ROW FORMAT DELIMITED
+            | FIELDS TERMINATED BY ','
+            | COLLECTION ITEMS TERMINATED BY '#'
+            | MAP KEYS TERMINATED BY '@'
+            | LINES TERMINATED BY '\n'
+            | NULL DEFINED AS 'NULL'
+            |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          decimalToString('d),
+          'e.cast("string")).collect())
+
+      // input/output with different delimiters
+      checkAnswer(
+        sql(
+          s"""
+            |SELECT TRANSFORM(a, b, c, d, e)
+            | ROW FORMAT DELIMITED
+            | FIELDS TERMINATED BY ','
+            | LINES TERMINATED BY '\n'
+            | NULL DEFINED AS 'null'
+            | USING 'cat' AS (value)
+            | ROW FORMAT DELIMITED
+            | FIELDS TERMINATED BY '&'
+            | LINES TERMINATED BY '\n'
+            | NULL DEFINED AS 'NULL'
+            |FROM v
+          """.stripMargin), identity, df.select(
+          concat_ws(",",
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            decimalToString('d),
+            'e.cast("string"))).collect())
+    }
+  }
 }
 
 case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
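Hoisting `decimalToString` to a member lets the new SPARK-32608 test reuse it when building expected answers: Hive 2.3 pads a `decimal(38, 18)` value to 18 fractional digits, so a plain string cast matches, while Hive 1.2 drops the trailing zeroes, so the expected value is first narrowed to `decimal(1, 0)`. One step worth spelling out for the second `checkAnswer`: the output delimiter `'&'` never occurs in the lines `cat` echoes back, so each whole line, with its fields still joined by `','`, lands in the single `value` column, which is why the expected answer is built with `concat_ws(",", ...)`. A hypothetical spark-shell session (the value mirrors column `d` above) shows the two expected renderings:

```scala
// Hypothetical spark-shell session; spark.implicits._ is already in scope there.
import org.apache.spark.sql.functions.col

// A Scala BigDecimal column gets Spark's default decimal(38, 18) type.
val df = Seq(BigDecimal(1.0)).toDF("d")

// Expected string on the Hive 2.3 / Spark path: trailing zeroes kept.
df.select(col("d").cast("string")).show(false)
// 1.000000000000000000

// Expected string on the Hive 1.2 path: narrow first, zeroes dropped.
df.select(col("d").cast("decimal(1, 0)").cast("string")).show(false)
// 1
```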
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 62712cf72eb59..ed3ac0e2ed51d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -24,8 +24,8 @@ import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Concat, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, RepartitionByExpression, Sort}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.datasources.{CreateTable, RefreshResource}
 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, StaticSQLConf}
@@ -38,6 +38,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType
  * defined in the Catalyst module.
  */
 class SparkSqlParserSuite extends AnalysisTest {
+  import org.apache.spark.sql.catalyst.dsl.expressions._
 
   val newConf = new SQLConf
   private lazy val parser = new SparkSqlParser(newConf)
@@ -330,4 +331,44 @@ class SparkSqlParserSuite extends AnalysisTest {
     assertEqual("ADD FILE /path with space/abc.txt", AddFileCommand("/path with space/abc.txt"))
     assertEqual("ADD JAR /path with space/abc.jar", AddJarCommand("/path with space/abc.jar"))
   }
+
+  test("SPARK-32608: script transform with ROW FORMAT DELIMITED") {
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        | ROW FORMAT DELIMITED
+        | FIELDS TERMINATED BY ','
+        | COLLECTION ITEMS TERMINATED BY '#'
+        | MAP KEYS TERMINATED BY '@'
+        | LINES TERMINATED BY '\n'
+        | NULL DEFINED AS 'null'
+        | USING 'cat' AS (a, b, c)
+        | ROW FORMAT DELIMITED
+        | FIELDS TERMINATED BY ','
+        | COLLECTION ITEMS TERMINATED BY '#'
+        | MAP KEYS TERMINATED BY '@'
+        | LINES TERMINATED BY '\n'
+        | NULL DEFINED AS 'NULL'
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", StringType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(
+          Seq(("TOK_TABLEROWFORMATFIELD", ","),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+            ("TOK_TABLEROWFORMATLINES", "\n"),
+            ("TOK_TABLEROWFORMATNULL", "null")),
+          Seq(("TOK_TABLEROWFORMATFIELD", ","),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+            ("TOK_TABLEROWFORMATLINES", "\n"),
+            ("TOK_TABLEROWFORMATNULL", "NULL")), None, None,
+          List.empty, List.empty, None, None, false)))
+  }
 }
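The parser test pins down the exact `ScriptInputOutputSchema` the shared helper now produces; in particular, `LINES TERMINATED BY '\n'` arrives as the pair `("TOK_TABLEROWFORMATLINES", "\n")` with a real newline rather than the two characters backslash-n, because the literal is unescaped during parsing. A hypothetical REPL check against a Spark build (the `SparkSqlParser(conf)` constructor matches the suite above) would look like this:

```scala
// Hypothetical check in a Scala REPL with Spark on the classpath.
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.internal.SQLConf

val parser = new SparkSqlParser(new SQLConf)
val plan = parser.parsePlan(
  """
    |SELECT TRANSFORM(a)
    | ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
    | USING 'cat' AS (a)
    |FROM testData
  """.stripMargin)

// The ScriptTransformation node's input schema should carry
// ("TOK_TABLEROWFORMATFIELD", ",") after this change.
println(plan.treeString)
```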