Skip to content

Commit 91fc7a2

Browse files
author
Davies Liu
committed
disable codegen for ScalaUDF
1 parent 44573a3 commit 91fc7a2

File tree

6 files changed

+20
-52
lines changed

6 files changed

+20
-52
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -957,4 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
957957
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
958958
override def eval(input: Row): Any = converter(f(input))
959959

960+
// TODO(davies): make ScalaUdf work with codegen
961+
override def isThreadSafe: Boolean = false
960962
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
349349
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
350350
val eval1 = left.gen(ctx)
351351
val eval2 = right.gen(ctx)
352-
val compCode = ctx.compFunc(dataType)(eval1.primitive, eval2.primitive)
352+
val compCode = ctx.genCmop(dataType, eval1.primitive, eval2.primitive)
353353

354354
eval1.code + eval2.code + s"""
355355
boolean ${ev.isNull} = false;
@@ -401,7 +401,7 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
401401
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
402402
val eval1 = left.gen(ctx)
403403
val eval2 = right.gen(ctx)
404-
val compCode = ctx.compFunc(dataType)(eval1.primitive, eval2.primitive)
404+
val compCode = ctx.genCmop(dataType, eval1.primitive, eval2.primitive)
405405

406406
eval1.code + eval2.code + s"""
407407
boolean ${ev.isNull} = false;

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -161,28 +161,21 @@ class CodeGenContext {
161161
/**
162162
* Returns a function to generate equal expression in Java
163163
*/
164-
def equalFunc(dataType: DataType): ((String, String) => String) = dataType match {
165-
case BinaryType => { case (eval1, eval2) =>
166-
s"java.util.Arrays.equals($eval1, $eval2)" }
164+
def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
165+
case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
167166
case IntegerType | BooleanType | LongType | DoubleType | FloatType | ShortType | ByteType
168-
| DateType =>
169-
{ case (eval1, eval2) => s"$eval1 == $eval2" }
170-
case other =>
171-
{ case (eval1, eval2) => s"$eval1.equals($eval2)" }
167+
| DateType => s"$c1 == $c2"
168+
case other => s"$c1.equals($c2)"
172169
}
173170

174171
/**
175172
* Return a function to generate compare expression in Java
176173
*/
177-
def compFunc(dataType: DataType): (String, String) => String = dataType match {
178-
case BinaryType => {
179-
case (c1, c2) =>
180-
s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
181-
}
182-
case IntegerType | LongType | DoubleType | FloatType | ShortType | ByteType | DateType => {
183-
case (c1, c2) => s"$c1 - $c2"
184-
}
185-
case other => { case (c1, c2) => s"$c1.compare($c2)" }
174+
def genCmop(dataType: DataType, c1: String, c2: String): String = dataType match {
175+
case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
176+
case IntegerType | LongType | DoubleType | FloatType | ShortType | ByteType | DateType =>
177+
s"$c1 - $c2"
178+
case other => s"$c1.compare($c2)"
186179
}
187180

188181
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ case class CreateStruct(children: Seq[NamedExpression]) extends Expression {
6161
override lazy val dataType: StructType = {
6262
assert(resolved,
6363
s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
64+
println(s"$children")
6465
val fields = children.map { child =>
6566
StructField(child.name, child.dataType, child.nullable, child.metadata)
6667
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
261261
${cond.code}
262262
if (${keyEval.isNull} && ${cond.isNull} ||
263263
!${keyEval.isNull} && !${cond.isNull}
264-
&& ${ctx.equalFunc(key.dataType)(keyEval.primitive, cond.primitive)}) {
264+
&& ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
265265
$got = true;
266266
${res.code}
267267
${ev.isNull} = ${res.isNull};

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -250,37 +250,9 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
250250
}
251251

252252
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
253-
left.dataType match {
254-
case dt: NumericType if ctx.isNativeType(dt) => defineCodeGen (ctx, ev, {
255-
(c1, c3) => s"$c1 $symbol $c3"
256-
})
257-
case DateType => defineCodeGen (ctx, ev, {
258-
(c1, c3) => s"$c1 $symbol $c3"
259-
})
260-
case BinaryType =>
261-
val eval1 = left.gen(ctx)
262-
val eval2 = right.gen(ctx)
263-
s"""
264-
${eval1.code}
265-
boolean ${ev.isNull} = ${eval1.isNull};
266-
boolean ${ev.primitive} = ${ctx.defaultValue(dataType)};
267-
if (!${ev.isNull}) {
268-
${eval2.code}
269-
if (!${eval2.isNull}) {
270-
${ev.primitive} = org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary(
271-
${eval1.primitive}, ${eval2.primitive}) $symbol 0;
272-
} else {
273-
${ev.isNull} = true;
274-
}
275-
}
276-
"""
277-
case TimestampType =>
278-
// java.sql.Timestamp does not have compare()
279-
super.genCode(ctx, ev)
280-
case other => defineCodeGen (ctx, ev, {
281-
(c1, c2) => s"$c1.compare($c2) $symbol 0"
282-
})
283-
}
253+
defineCodeGen(ctx, ev, {
254+
(c1, c2) => s"${ctx.genCmop(left.dataType, c1, c2)} $symbol 0"
255+
})
284256
}
285257

286258
protected def evalInternal(evalE1: Any, evalE2: Any): Any =
@@ -301,7 +273,7 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
301273
else java.util.Arrays.equals(l.asInstanceOf[Array[Byte]], r.asInstanceOf[Array[Byte]])
302274
}
303275
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
304-
defineCodeGen(ctx, ev, ctx.equalFunc(left.dataType))
276+
defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2))
305277
}
306278
}
307279

@@ -327,7 +299,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
327299
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
328300
val eval1 = left.gen(ctx)
329301
val eval2 = right.gen(ctx)
330-
val equalCode = ctx.equalFunc(left.dataType)(eval1.primitive, eval2.primitive)
302+
val equalCode = ctx.genEqual(left.dataType, eval1.primitive, eval2.primitive)
331303
ev.isNull = "false"
332304
eval1.code + eval2.code + s"""
333305
boolean ${ev.primitive} = (${eval1.isNull} && ${eval2.isNull}) ||

0 commit comments

Comments
 (0)