
Commit c39e3cf

[SPARK-33934][SQL] Support user defined script command wrapper like hive
1 parent 49aa6eb commit c39e3cf

File tree

6 files changed: +248 −7 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -2918,6 +2918,14 @@ object SQLConf {
       .checkValues(LegacyBehaviorPolicy.values.map(_.toString))
       .createWithDefault(LegacyBehaviorPolicy.EXCEPTION.toString)
 
+  val SCRIPT_TRANSFORMATION_COMMAND_WRAPPER =
+    buildConf("spark.sql.scriptTransformation.commandWrapper")
+      .internal()
+      .doc("Command wrapper for executor to execute transformation script.")
+      .version("3.2.0")
+      .stringConf
+      .createWithDefault("/bin/bash -c")
+
   val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT =
     buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds")
      .internal()
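
For orientation, a minimal sketch of how this new knob could be changed at runtime. The key and its default come from the diff above; the session setup, the alternate wrapper value, and the query are illustrative assumptions, not part of this commit:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()

// Default is "/bin/bash -c"; per the tests added below, an empty string
// makes Spark launch the transformation command with no wrapper at all.
spark.conf.set("spark.sql.scriptTransformation.commandWrapper", "/bin/sh -c")

// Any subsequent TRANSFORM query then runs through the chosen wrapper:
spark.sql("SELECT TRANSFORM(a) USING 'cat' AS (a) FROM VALUES (1) AS t(a)").show()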

sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 62 additions & 5 deletions
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
+import java.io._
 import java.nio.charset.StandardCharsets
+import java.util.ArrayList
 import java.util.concurrent.TimeUnit
 
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkException, SparkFiles, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -46,6 +47,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   def child: SparkPlan
   def ioschema: ScriptTransformationIOSchema
 
+  type ProcParameters = (OutputStream, Process, InputStream, CircularBuffer)
+
   protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
     input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
   }
@@ -55,8 +58,11 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
   override def doExecute(): RDD[InternalRow] = {
+    val hadoopConf = sqlContext.sessionState.newHadoopConf()
+    hadoopConf.set(SQLConf.SCRIPT_TRANSFORMATION_COMMAND_WRAPPER.key,
+      conf.getConf(SQLConf.SCRIPT_TRANSFORMATION_COMMAND_WRAPPER))
     val broadcastedHadoopConf =
-      new SerializableConfiguration(sqlContext.sessionState.newHadoopConf())
+      new SerializableConfiguration(hadoopConf)
 
     child.execute().mapPartitions { iter =>
       if (iter.hasNext) {
@@ -69,9 +75,11 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
     }
   }
 
-  protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
-    val cmd = List("/bin/bash", "-c", script)
+  protected def initProc(hadoopConf: Configuration): ProcParameters = {
+    val wrapper = splitArgs(hadoopConf.get(SQLConf.SCRIPT_TRANSFORMATION_COMMAND_WRAPPER.key))
+    val cmd = wrapper.toList ++ List(script)
     val builder = new ProcessBuilder(cmd.asJava)
+      .directory(new File(SparkFiles.getRootDirectory()))
 
     val proc = builder.start()
     val inputStream = proc.getInputStream
@@ -181,6 +189,55 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
     }
   }
 
+  def splitArgs(args: String): Array[String] = {
+    val OUTSIDE = 1
+    val SINGLEQ = 2
+    val DOUBLEQ = 3
+    val argList = new ArrayList[String]
+    val ch = args.toCharArray
+    val clen = ch.length
+    var state = OUTSIDE
+    var argstart = 0
+    var c = 0
+    while (c <= clen) {
+      val last = c == clen
+      var lastState = state
+      var endToken = false
+      if (!last) {
+        if (ch(c) == '\'') {
+          if (state == OUTSIDE) {
+            state = SINGLEQ
+          } else if (state == SINGLEQ) {
+            state = OUTSIDE
+          }
+          endToken = state != lastState
+        } else if (ch(c) == '"') {
+          if (state == OUTSIDE) {
+            state = DOUBLEQ
+          } else if (state == DOUBLEQ) {
+            state = OUTSIDE
+          }
+          endToken = state != lastState
+        } else if (ch(c) == ' ') {
+          if (state == OUTSIDE) {
+            endToken = true
+          }
+        }
+      }
+      if (last || endToken) {
+        if (c == argstart) {
+          // unquoted space
+        } else {
+          argList.add(args.substring(argstart, c))
+        }
+        argstart = c + 1
+        lastState = state
+      }
+      c += 1
+    }
+    argList.toArray(new Array[String](0))
+  }
+
   private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
     val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
     attr.dataType match {
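
To make the wrapper tokenization concrete, a few hand-traced inputs and the arrays the splitArgs state machine above produces. These are illustrations worked through by hand, not tests from this commit:

// Whitespace splits tokens only outside quotes; quote characters toggle
// state, end the current token, and are stripped at token boundaries.
splitArgs("/bin/bash -c")           // Array("/bin/bash", "-c")
splitArgs("")                       // Array(): an empty wrapper contributes nothing
splitArgs("env 'LC_ALL=C' perl")    // Array("env", "LC_ALL=C", "perl")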

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ case class SparkScriptTransformationExec(
       inputIterator: Iterator[InternalRow],
       hadoopConf: Configuration): Iterator[InternalRow] = {
 
-    val (outputStream, proc, inputStream, stderrBuffer) = initProc
+    val (outputStream, proc, inputStream, stderrBuffer) = initProc(hadoopConf)
 
     val outputProjection = new InterpretedProjection(inputExpressionsWithoutSerde, child.output)
sql/core/src/test/resources/test_script.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
+#!/usr/bin/python
+
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements. See the NOTICE file distributed with
 # this work for additional information regarding copyright ownership.
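
Why the shebang matters: under the default wrapper the script path is handed to bash as a single command string, so executing the file directly relies on the kernel's shebang handling plus the executable bit. A hand-written sketch of the resulting invocation (the path is illustrative):

// USING '/path/to/test_script.py' with the default wrapper turns into:
val cmd = List("/bin/bash", "-c", "/path/to/test_script.py")
// bash then execs the file itself, which needs "#!/usr/bin/python" and +x;
// without the executable bit the suite below expects "Permission denied".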

sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala

Lines changed: 174 additions & 0 deletions
@@ -470,6 +470,180 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         Row("3\u00014\u00015") :: Nil)
     }
   }
+
+  test("SPARK-33934: Check default execute command wrapper") {
+    assume(TestUtils.testCommandAvailable("python"))
+    val scriptFilePath = copyAndGetResourceFile("test_script.py", ".py").getAbsoluteFile
+    withTempView("v") {
+      val df = Seq(
+        (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
+        (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
+        (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
+      ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
+      df.createTempView("v")
+
+      // test '/bin/bash -c python /path/to/script.py'
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT
+             |TRANSFORM(a, b, c, d, e)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             | USING 'python $scriptFilePath' AS (a, b, c, d, e)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+
+      // test '/bin/bash -c /path/to/script.py' with script not executable
+      val e1 = intercept[TestFailedException] {
+        checkAnswer(
+          sql(
+            s"""
+               |SELECT
+               |TRANSFORM(a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               | USING '$scriptFilePath' AS (a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               |FROM v
+            """.stripMargin), identity, df.select(
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            'd.cast("string"),
+            'e.cast("string")).collect())
+      }.getMessage
+      assert(e1.contains("Permission denied"))
+
+      // test '/bin/bash -c /path/to/script.py' with script executable
+      scriptFilePath.setExecutable(true)
+      checkAnswer(
+        sql(
+          s"""
+             |SELECT
+             |TRANSFORM(a, b, c, d, e)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             | USING '$scriptFilePath' AS (a, b, c, d, e)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             |FROM v
+          """.stripMargin), identity, df.select(
+          'a.cast("string"),
+          'b.cast("string"),
+          'c.cast("string"),
+          'd.cast("string"),
+          'e.cast("string")).collect())
+
+      scriptFilePath.setExecutable(false)
+      sql(s"ADD FILE ${scriptFilePath.getAbsolutePath}")
+
+      // test '/bin/bash -c script.py'
+      val e2 = intercept[TestFailedException] {
+        checkAnswer(
+          sql(
+            s"""
+               |SELECT TRANSFORM(a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               | USING '${scriptFilePath.getName}' AS (a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               |FROM v
+            """.stripMargin), identity, df.select(
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            'd.cast("string"),
+            'e.cast("string")).collect())
+      }.getMessage()
+      assert(e2.contains("command not found"))
+    }
+  }
+
+  test("SPARK-33934: Check execute command wrapper is empty") {
+    assume(TestUtils.testCommandAvailable("python"))
+    val scriptFilePath = copyAndGetResourceFile(
+      "test_script.py", "_test_empty.py").getAbsoluteFile
+    withTempView("v") {
+      withSQLConf(SQLConf.SCRIPT_TRANSFORMATION_COMMAND_WRAPPER.key -> "") {
+        val df = Seq(
+          (1, "1", 1.0, BigDecimal(1.0), new Timestamp(1)),
+          (2, "2", 2.0, BigDecimal(2.0), new Timestamp(2)),
+          (3, "3", 3.0, BigDecimal(3.0), new Timestamp(3))
+        ).toDF("a", "b", "c", "d", "e") // Note column d's data type is Decimal(38, 18)
+        df.createTempView("v")
+
+        scriptFilePath.setExecutable(true)
+        sql(s"ADD FILE ${scriptFilePath.getAbsolutePath}")
+        sql(
+          s"""
+             |SELECT TRANSFORM(a)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             | USING 'pwd' AS (a)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '&'
+             |FROM (SELECT 1 AS a) TEMP
+          """.stripMargin).show(false)
+
+        sql(
+          s"""
+             |SELECT TRANSFORM(a)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '\t'
+             | USING 'ls' AS (a)
+             | ROW FORMAT DELIMITED
+             | FIELDS TERMINATED BY '&'
+             |FROM (SELECT 1 AS a) TEMP
+          """.stripMargin).show(false)
+
+        // test 'python script.py'
+        checkAnswer(
+          sql(
+            s"""
+               |SELECT TRANSFORM(a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               | USING 'python ${scriptFilePath.getName}' AS (a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               |FROM v
+            """.stripMargin), identity, df.select(
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            'd.cast("string"),
+            'e.cast("string")).collect())
+
+        // test 'script.py'
+        checkAnswer(
+          sql(
+            s"""
+               |SELECT TRANSFORM(a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               | USING '${scriptFilePath.getName}' AS (a, b, c, d, e)
+               | ROW FORMAT DELIMITED
+               | FIELDS TERMINATED BY '\t'
+               |FROM v
+            """.stripMargin), identity, df.select(
+            'a.cast("string"),
+            'b.cast("string"),
+            'c.cast("string"),
+            'd.cast("string"),
+            'e.cast("string")).collect())
+      }
+    }
+  }
 }
 
 case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
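
A design note the tests above lean on: initProc now launches the child process with its working directory set to SparkFiles' root, the directory where ADD FILE ships resources on executors. A rough sketch of that relationship (the path is illustrative; a running session named `spark` is assumed):

spark.sql("ADD FILE /tmp/test_script.py")   // ships the script to every executor
// Files added this way land under SparkFiles' root, which is now the
// working directory of the launched transformation process:
val workDir = org.apache.spark.SparkFiles.getRootDirectory()
// so a relative reference like USING 'python test_script.py' resolves there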

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ case class HiveScriptTransformationExec(
       inputIterator: Iterator[InternalRow],
       hadoopConf: Configuration): Iterator[InternalRow] = {
 
-    val (outputStream, proc, inputStream, stderrBuffer) = initProc
+    val (outputStream, proc, inputStream, stderrBuffer) = initProc(hadoopConf)
 
     val (inputSerde, inputSoi) = initInputSerDe(ioschema, input).getOrElse((null, null))
