Commit a769aa7

address comment

1 parent 7916d72 commit a769aa7

File tree: 8 files changed (+138, -61 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 22 additions & 17 deletions
@@ -744,6 +744,27 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
     selectClause.hints.asScala.foldRight(withWindow)(withHints)
   }
 
+  // Decode an input/output format.
+  type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
+
+  protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): Format = {
+    // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
+    // expects a seq of pairs in which the old parsers' token names are used as keys.
+    // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
+    // retrieving the key value pairs ourselves.
+    def entry(key: String, value: Token): Seq[(String, String)] = {
+      Option(value).map(t => key -> t.getText).toSeq
+    }
+
+    val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++
+      entry("TOK_TABLEROWFORMATLINES", ctx.linesSeparatedBy) ++
+      entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs)
+
+    (entries, None, Seq.empty, None)
+  }
+
   /**
    * Create a [[ScriptInputOutputSchema]].
    */
@@ -754,26 +775,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       outRowFormat: RowFormatContext,
       recordReader: Token,
       schemaLess: Boolean): ScriptInputOutputSchema = {
-    // Decode and input/output format.
-    type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
 
     def format(fmt: RowFormatContext): Format = fmt match {
       case c: RowFormatDelimitedContext =>
-        // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
-        // expects a seq of pairs in which the old parsers' token names are used as keys.
-        // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
-        // retrieving the key value pairs ourselves.
-        def entry(key: String, value: Token): Seq[(String, String)] = {
-          Option(value).map(t => key -> t.getText).toSeq
-        }
-
-        val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
-          entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
-          entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
-          entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
-          entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)
-
-        (entries, None, Seq.empty, None)
+        getRowFormatDelimited(c)
 
       case c: RowFormatSerdeContext =>
         throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx)
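
Note (editorial, not part of the commit): the refactor hoists the delimited-row-format decoding into one protected helper so that this call site and the one in SparkSqlAstBuilder below share a single implementation. A minimal, self-contained sketch of the key/value-pair pattern the helper uses, with Option[String] standing in for the ANTLR Token fields (an assumption made purely for illustration):

object RowFormatSketch {
  // Same Format alias shape as in the diff above.
  type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])

  // An absent clause contributes no pair, exactly like Option(token).map(...).toSeq.
  def entry(key: String, value: Option[String]): Seq[(String, String)] =
    value.map(v => key -> v).toSeq

  def delimited(fieldsTerminatedBy: Option[String], linesSeparatedBy: Option[String]): Format = {
    val entries = entry("TOK_TABLEROWFORMATFIELD", fieldsTerminatedBy) ++
      entry("TOK_TABLEROWFORMATLINES", linesSeparatedBy)
    (entries, None, Seq.empty, None)
  }

  def main(args: Array[String]): Unit = {
    // Only the clauses that were present in ROW FORMAT show up as pairs.
    println(delimited(Some("'\\t'"), None))
  }
}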

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala

Lines changed: 93 additions & 1 deletion
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
 
 /**
  * Parser test cases for rules defined in [[CatalystSqlParser]] / [[AstBuilder]].
@@ -1031,4 +1031,96 @@ class PlanParserSuite extends AnalysisTest {
     assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b))
     assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b))
   }
+
+  test("SPARK-32106: TRANSFORM without serde") {
+    // verify schema less
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat'
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("key", StringType)(),
+          AttributeReference("value", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, true))
+    )
+
+    // verify without output schema
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat' AS (a, b, c)
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", StringType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, false)))
+
+    // verify with output schema
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |USING 'cat' AS (a int, b string, c long)
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", IntegerType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", LongType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(List.empty, List.empty, None, None,
+          List.empty, List.empty, None, None, false)))
+
+    // verify with ROW FORMAT DELIMITED
+    assertEqual(
+      """
+        |SELECT TRANSFORM(a, b, c)
+        |ROW FORMAT DELIMITED
+        |FIELDS TERMINATED BY '\t'
+        |COLLECTION ITEMS TERMINATED BY '\u0002'
+        |MAP KEYS TERMINATED BY '\u0003'
+        |LINES TERMINATED BY '\n'
+        |NULL DEFINED AS 'null'
+        |USING 'cat' AS (a, b, c)
+        |ROW FORMAT DELIMITED
+        |FIELDS TERMINATED BY '\t'
+        |COLLECTION ITEMS TERMINATED BY '\u0004'
+        |MAP KEYS TERMINATED BY '\u0005'
+        |LINES TERMINATED BY '\n'
+        |NULL DEFINED AS 'NULL'
+        |FROM testData
+      """.stripMargin,
+      ScriptTransformation(
+        Seq('a, 'b, 'c),
+        "cat",
+        Seq(AttributeReference("a", StringType)(),
+          AttributeReference("b", StringType)(),
+          AttributeReference("c", StringType)()),
+        UnresolvedRelation(TableIdentifier("testData")),
+        ScriptInputOutputSchema(
+          Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0002'"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "'\u0003'"),
+            ("TOK_TABLEROWFORMATLINES", "'\\n'"),
+            ("TOK_TABLEROWFORMATNULL", "'null'")),
+          Seq(("TOK_TABLEROWFORMATFIELD", "'\\t'"),
+            ("TOK_TABLEROWFORMATCOLLITEMS", "'\u0004'"),
+            ("TOK_TABLEROWFORMATMAPKEYS", "'\u0005'"),
+            ("TOK_TABLEROWFORMATLINES", "'\\n'"),
+            ("TOK_TABLEROWFORMATNULL", "'NULL'")), None, None,
+          List.empty, List.empty, None, None, false)))
+  }
 }
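
For reference (not part of the commit): the plan shapes asserted above correspond to queries that could be run end to end once no-serde TRANSFORM support (the subject of SPARK-32106) is in place. A sketch, assuming a local SparkSession, a testData view created on the fly, and a 'cat' command available on the host (all assumptions here):

import org.apache.spark.sql.SparkSession

object TransformSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("transform-sketch").getOrCreate()
    spark.range(3).selectExpr("id AS a", "id AS b", "id AS c").createOrReplaceTempView("testData")

    // Schema-less form: output defaults to the key/value string columns
    // asserted in the first test case above.
    spark.sql("SELECT TRANSFORM(a, b, c) USING 'cat' FROM testData").show()

    // Explicit output schema: columns come back with the declared types.
    spark.sql(
      "SELECT TRANSFORM(a, b, c) USING 'cat' AS (a int, b string, c long) FROM testData").show()

    spark.stop()
  }
}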

sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 5 additions & 6 deletions
@@ -108,7 +108,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
       prevLine: String =>
         new GenericInternalRow(
           prevLine.split(outputRowFormat)
-            .zip(fieldWriters)
+            .zip(outputFieldWriters)
             .map { case (data, writer) => writer(data) })
     } else {
       // In schema less mode, hive default serde will choose first two output column as output
@@ -182,7 +182,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
     }
   }
 
-  private lazy val fieldWriters: Seq[String => Any] = output.map { attr =>
+  private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
     val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
     attr.dataType match {
       case StringType => wrapperConvertException(data => data, converter)
@@ -218,10 +218,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
         converter)
       case udt: UserDefinedType[_] =>
         wrapperConvertException(data => udt.deserialize(data), converter)
-      case ArrayType(_, _) | MapType(_, _, _) | StructType(_) =>
-        throw new SparkException("TRANSFORM without serde don't support" +
-          " ArrayType/MapType/StructType as output data type")
-      case _ => wrapperConvertException(data => data, converter)
+      case dt =>
+        throw new SparkException("TRANSFORM without serde does not support " +
+          s"${dt.getClass.getSimpleName} as output data type")
     }
   }
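
Aside (not part of the commit): with this change, any type not matched earlier now fails fast with its own class name instead of falling through to a catch-all pass-through, which also makes the test assertions below type-specific. A simplified sketch of the per-column writer pattern, using real Spark types but not Spark's internal code:

import org.apache.spark.SparkException
import org.apache.spark.sql.types._

object FieldWriterSketch {
  // One String => Any converter per output column; unsupported types fail
  // fast and name the concrete type, mirroring the new error message.
  def writer(dt: DataType): String => Any = dt match {
    case StringType  => data => data
    case IntegerType => data => data.trim.toInt
    case other =>
      throw new SparkException("TRANSFORM without serde does not support " +
        s"${other.getClass.getSimpleName} as output data type")
  }

  def main(args: Array[String]): Unit = {
    // A script output line is split on the delimiter and zipped with the
    // writers, as in prevLine.split(outputRowFormat).zip(outputFieldWriters).
    val writers = Seq(writer(StringType), writer(IntegerType))
    val row = "a\t1".split("\t").zip(writers).map { case (d, w) => w(d) }
    println(row.toList) // List(a, 1)
  }
}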

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 1 addition & 19 deletions
@@ -689,30 +689,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
         recordReader,
         schemaLess)
     } else {
-
-      // Decode and input/output format.
-      type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
-
       def format(
           fmt: RowFormatContext,
           configKey: String,
           defaultConfigValue: String): Format = fmt match {
         case c: RowFormatDelimitedContext =>
-          // TODO we should use visitRowFormatDelimited function here. However HiveScriptIOSchema
-          // expects a seq of pairs in which the old parsers' token names are used as keys.
-          // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
-          // retrieving the key value pairs ourselves.
-          def entry(key: String, value: Token): Seq[(String, String)] = {
-            Option(value).map(t => key -> t.getText).toSeq
-          }
-
-          val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
-            entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
-            entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
-            entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
-            entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)
-
-          (entries, None, Seq.empty, None)
+          getRowFormatDelimited(c)
 
         case c: RowFormatSerdeContext =>
           // Use a serde format.

sql/core/src/test/resources/sql-tests/inputs/transform.sql

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ SELECT a, b, decode(c, 'UTF-8'), d, e, f, g, h, i, j, k, l FROM (
 FROM t
 ) tmp;
 
--- handle schema less
+-- SPARK-32388 handle schema less
 SELECT TRANSFORM(a)
 USING 'cat'
 FROM t;

sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala

Lines changed: 0 additions & 1 deletion
@@ -268,7 +268,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     }
   }
 
-
   test("SPARK-32106: TRANSFORM should respect DATETIME_JAVA8API_ENABLED (no serde)") {
     assume(TestUtils.testCommandAvailable("python"))
     Array(false, true).foreach { java8AapiEnable =>

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala

Lines changed: 7 additions & 7 deletions
@@ -62,7 +62,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
     }
   }
 
-  test("TRANSFORM don't support ArrayType/MapType/StructType as output data type (no serde)") {
+  test("TRANSFORM doesn't support ArrayType/MapType/StructType as output data type (no serde)") {
     assume(TestUtils.testCommandAvailable("/bin/bash"))
     // check for ArrayType
     val e1 = intercept[SparkException] {
@@ -73,8 +73,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
           |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
         """.stripMargin).collect()
     }.getMessage
-    assert(e1.contains("TRANSFORM without serde don't support" +
-      " ArrayType/MapType/StructType as output data type"))
+    assert(e1.contains("TRANSFORM without serde does not support" +
+      " ArrayType as output data type"))
 
     // check for MapType
     val e2 = intercept[SparkException] {
@@ -85,8 +85,8 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
           |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
         """.stripMargin).collect()
     }.getMessage
-    assert(e2.contains("TRANSFORM without serde don't support" +
-      " ArrayType/MapType/StructType as output data type"))
+    assert(e2.contains("TRANSFORM without serde does not support" +
+      " MapType as output data type"))
 
     // check for StructType
     val e3 = intercept[SparkException] {
@@ -97,7 +97,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with
           |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c)
         """.stripMargin).collect()
     }.getMessage
-    assert(e3.contains("TRANSFORM without serde don't support" +
-      " ArrayType/MapType/StructType as output data type"))
+    assert(e3.contains("TRANSFORM without serde does not support" +
+      " StructType as output data type"))
   }
 }

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala

Lines changed: 9 additions & 9 deletions
@@ -53,7 +53,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
     )
   }
 
-  private val serdeIOSchema: ScriptTransformationIOSchema = {
+  private val hiveIOSchema: ScriptTransformationIOSchema = {
     defaultIOSchema.copy(
       inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
       outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
@@ -71,7 +71,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
-        ioschema = serdeIOSchema
+        ioschema = hiveIOSchema
       ),
       rowsDf.collect())
     assert(uncaughtExceptionHandler.exception.isEmpty)
@@ -89,7 +89,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = ExceptionInjectingOperator(child),
-        ioschema = serdeIOSchema
+        ioschema = hiveIOSchema
       ),
       rowsDf.collect())
   }
@@ -110,7 +110,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
         script = "some_non_existent_command",
         output = Seq(AttributeReference("a", StringType)()),
         child = rowsDf.queryExecution.sparkPlan,
-        ioschema = serdeIOSchema)
+        ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }
     assert(e.getMessage.contains("Subprocess exited with status"))
@@ -131,7 +131,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
         script = "cat",
         output = Seq(AttributeReference("name", StringType)()),
         child = child,
-        ioschema = serdeIOSchema
+        ioschema = hiveIOSchema
       ),
       rowsDf.select("name").collect())
     assert(uncaughtExceptionHandler.exception.isEmpty)
@@ -148,7 +148,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
         script = "some_non_existent_command",
         output = Seq(AttributeReference("a", StringType)()),
         child = rowsDf.queryExecution.sparkPlan,
-        ioschema = serdeIOSchema)
+        ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }
     assert(e.getMessage.contains("Subprocess exited with status"))
@@ -212,7 +212,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
           StructField("col1", IntegerType, false),
           StructField("col2", StringType, true))))()),
         child = child,
-        ioschema = serdeIOSchema
+        ioschema = hiveIOSchema
       ),
       df.select('c, 'd, 'e).collect())
   }
@@ -256,7 +256,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
           AttributeReference("a", IntegerType)(),
           AttributeReference("b", CalendarIntervalType)()),
         child = df.queryExecution.sparkPlan,
-        ioschema = serdeIOSchema)
+        ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }
     assert(e1.getMessage.contains("scala.MatchError: CalendarIntervalType"))
@@ -269,7 +269,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T
           AttributeReference("a", IntegerType)(),
           AttributeReference("c", new TestUDT.MyDenseVectorUDT)()),
         child = df.queryExecution.sparkPlan,
-        ioschema = serdeIOSchema)
+        ioschema = hiveIOSchema)
       SparkPlanTest.executePlan(plan, hiveContext)
     }
     assert(e2.getMessage.contains(
