
Commit ec754e2

fix input and output format
1 parent: 5bfa669

5 files changed: +75 −16 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala

Lines changed: 39 additions & 1 deletion

@@ -27,12 +27,15 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils}
 
 trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -87,6 +90,41 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
       }
     }
   }
+
+  def wrapper(data: String, dt: DataType): Any = {
+    dt match {
+      case StringType => data
+      case ByteType => JavaUtils.stringToBytes(data)
+      case IntegerType => data.toInt
+      case ShortType => data.toShort
+      case LongType => data.toLong
+      case FloatType => data.toFloat
+      case DoubleType => data.toDouble
+      case dt: DecimalType => BigDecimal(data)
+      case DateType if conf.datetimeJava8ApiEnabled =>
+        DateTimeUtils.stringToDate(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.daysToLocalDate).orNull
+      case DateType =>
+        DateTimeUtils.stringToDate(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.toJavaDate).orNull
+      case TimestampType if conf.datetimeJava8ApiEnabled =>
+        DateTimeUtils.stringToTimestamp(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.microsToInstant).orNull
+      case TimestampType =>
+        DateTimeUtils.stringToTimestamp(
+          UTF8String.fromString(data),
+          DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
+          .map(DateTimeUtils.toJavaTimestamp).orNull
+      case CalendarIntervalType => IntervalUtils.stringToInterval(UTF8String.fromString(data))
+      case dataType: DataType => data
+    }
+  }
 }
 
 abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
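The new `wrapper` method is the heart of this commit: everything a script writes to stdout comes back as a string, and the declared output schema drives the conversion back to typed values. A minimal sketch of the same idea in isolation (simplified; the date, timestamp, and interval branches are omitted here because they need the session time zone from `conf`):

import org.apache.spark.sql.types._

// Simplified stand-in for wrapper(): convert one stdout field according to
// its declared type. Unhandled types fall through as the raw string, which
// mirrors the `case dataType: DataType => data` default in the diff above.
def convertField(data: String, dt: DataType): Any = dt match {
  case StringType     => data
  case IntegerType    => data.toInt
  case ShortType      => data.toShort
  case LongType       => data.toLong
  case FloatType      => data.toFloat
  case DoubleType     => data.toDouble
  case _: DecimalType => BigDecimal(data)
  case _              => data
}

// A script that echoes "1<TAB>2.5" against an output schema (a INT, b DOUBLE):
val fields = "1\t2.5".split("\t")
val types  = Seq(IntegerType, DoubleType)
val row    = fields.zip(types).map { case (d, t) => convertField(d, t) }
// row: Array[Any] = Array(1, 2.5)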

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala

Lines changed: 12 additions & 4 deletions

@@ -29,7 +29,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
 import org.apache.spark.util.{CircularBuffer, RedirectThread}
 
 /**
@@ -67,7 +67,9 @@ case class SparkScriptTransformationExec(
       stderrBuffer,
       "Thread-ScriptTransformation-STDERR-Consumer").start()
 
-    val outputProjection = new InterpretedProjection(input, child.output)
+    val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+
+    val outputProjection = new InterpretedProjection(finalInput, child.output)
 
     // This new thread will consume the ScriptTransformation's input rows and write them to the
     // external process. That process's output will be read by this current thread.
@@ -116,11 +118,17 @@ case class SparkScriptTransformationExec(
         if (!ioschema.schemaLess) {
           new GenericInternalRow(
             prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-              .map(CatalystTypeConverters.convertToCatalyst))
+              .zip(output)
+              .map { case (data, dataType) =>
+                CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType))
+              })
         } else {
           new GenericInternalRow(
             prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
-              .map(CatalystTypeConverters.convertToCatalyst))
+              .zip(output)
+              .map { case (data, dataType) =>
+                CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType))
+              })
         }
       }
     }
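On the input side, the cast to string that previously lived in `SparkStrategies` now happens inside the exec node itself. A sketch of what `finalInput` amounts to, using hypothetical attributes and an assumed UTC session zone:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types._

// Hypothetical input columns for TRANSFORM(id, ts): an INT and a TIMESTAMP.
val id = AttributeReference("id", IntegerType)()
val ts = AttributeReference("ts", TimestampType)()

// Every column is cast to STRING before the writer thread serializes rows,
// and the session time zone is attached so timestamp rendering is stable.
val finalInput = Seq(id, ts).map(Cast(_, StringType).withTimeZone("UTC"))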

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala

Lines changed: 2 additions & 1 deletion

@@ -713,7 +713,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
           None
         }
         (Seq.empty, Option(name), props.toSeq, recordHandler)
-
+      // SPARK-32106: When there is no definition about format, we return empty result
+      // then we finally execute with SparkScriptTransformationExec
      case null =>
        (Nil, None, Seq.empty, None)
    }
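The new comment pins down the intent of the `case null` branch. As an illustration (assuming a SparkSession named `spark` and a build that includes this change), a TRANSFORM query with no ROW FORMAT clause at all now parses to an empty format spec and is planned as SparkScriptTransformationExec rather than requiring the Hive SerDe path:

// Hypothetical usage: no ROW FORMAT / SERDE clause anywhere in the query,
// so the parser yields (Nil, None, Seq.empty, None) for both input and
// output formats and planning picks the non-Hive exec node.
spark.range(3).createOrReplaceTempView("t")
spark.sql("SELECT TRANSFORM(id) USING 'cat' AS (id STRING) FROM t").show()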

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 2 additions & 4 deletions

@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.execution
 
-import java.time.ZoneId
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -40,7 +38,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
-import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -539,7 +537,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.ScriptTransformation(input, script, output, child, ioschema)
           if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty =>
         SparkScriptTransformationExec(
-          input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)),
+          input,
          script,
          output,
          planLater(child),
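Net effect of this file: the strategy layer goes back to passing `input` through untouched, and each exec node decides for itself whether to cast to string. That split is what the Hive change below relies on, since there the cast is only correct when no input SerDe is configured.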

sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala

Lines changed: 20 additions & 6 deletions

@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.HiveInspectors
 import org.apache.spark.sql.hive.HiveShim._
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, StringType}
 import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
 
 /**
@@ -78,17 +78,25 @@ case class HiveScriptTransformationExec(
       stderrBuffer,
       "Thread-ScriptTransformation-STDERR-Consumer").start()
 
-    val outputProjection = new InterpretedProjection(input, child.output)
-
     // This nullability is a performance optimization in order to avoid an Option.foreach() call
    // inside of a loop
    @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))
 
+    // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null
+    // We will use StringBuffer to pass data, in this case, we should cast data as string too.
+    val finalInput = if (inputSerde == null) {
+      input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+    } else {
+      input
+    }
+
+    val outputProjection = new InterpretedProjection(finalInput, child.output)
+
     // This new thread will consume the ScriptTransformation's input rows and write them to the
     // external process. That process's output will be read by this current thread.
     val writerThread = HiveScriptTransformationWriterThread(
       inputIterator.map(outputProjection),
-      input.map(_.dataType),
+      finalInput.map(_.dataType),
       inputSerde,
       inputSoi,
       ioschema,
@@ -178,11 +186,17 @@ case class HiveScriptTransformationExec(
         if (!ioschema.schemaLess) {
           new GenericInternalRow(
             prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
-              .map(CatalystTypeConverters.convertToCatalyst))
+              .zip(output)
+              .map { case (data, dataType) =>
+                CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType))
+              })
         } else {
           new GenericInternalRow(
             prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
-              .map(CatalystTypeConverters.convertToCatalyst))
+              .zip(output)
+              .map { case (data, dataType) =>
+                CatalystTypeConverters.convertToCatalyst(wrapper(data, dataType.dataType))
+              })
         }
       } else {
        val raw = outputSerde.deserialize(scriptOutputWritable)
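The conditional cast is easiest to read from the user's side. A sketch of the two paths, assuming a Hive-enabled SparkSession `spark` and an existing table `t`:

// 1) No SerDe clause: rows cross the pipe as delimited text, so the node
//    casts each input column to STRING (with the session time zone) first.
spark.sql("SELECT TRANSFORM(id, ts) USING 'cat' AS (a STRING, b STRING) FROM t")

// 2) Explicit input SerDe: typed values are handed to the SerDe's
//    ObjectInspector directly, so inputSerde != null and no cast is added.
spark.sql(
  """SELECT TRANSFORM(id)
    |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
    |USING 'cat' AS (a STRING)
    |FROM t""".stripMargin)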
