From cb876e46cae1b90aaa0c681b4c6dbc66480e2950 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 31 Oct 2016 08:59:00 +0000 Subject: [PATCH 1/3] Fix a compilation error in codegen due to splitExpression. --- .../expressions/ReferenceToExpressions.scala | 22 +++++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 37 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index 127797c0974bb..e99cc28906099 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -67,11 +67,27 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) } + // SPARK-18125: The children vars are local variables. If the result expression uses + // splitExpression, those variables cannot be accessed so compilation fails. + // To fix it, we use class variables to hold those local variables. + val initClassChildVars = childrenVars.map { childVar => + val childVarInClass = ctx.freshName("childVarInClass") + ctx.addMutableState(ctx.javaType(childVar.dataType), childVarInClass, "") + val isNullInClass = ctx.freshName("childVarInClassIsNull") + ctx.addMutableState("boolean", isNullInClass, "") + LambdaVariable(childVarInClass, isNullInClass, childVar.dataType) + } + + val initClassChildVarsCode = initClassChildVars.zipWithIndex.map { case (childVarInClass, i) => + s"${childVarInClass.value} = ${childrenVars(i).value};\n" + + s"${childVarInClass.isNull} = ${childrenVars(i).isNull};" + }.mkString("\n") + val resultGen = result.transform { - case b: BoundReference => childrenVars(b.ordinal) + case b: BoundReference => initClassChildVars(b.ordinal) }.genCode(ctx) - ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code, - isNull = resultGen.isNull, value = resultGen.value) + ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + initClassChildVarsCode + + resultGen.code, isNull = resultGen.isNull, value = resultGen.value) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cc367acae2ba4..334f978ab6b6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -919,6 +919,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext { df.withColumn("b", expr("0")).as[ClassData] .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() }) } + + test("SPARK-18125: Spark generated code causes CompileException") { + val data = Array( + Route("a", "b", 1), + Route("a", "b", 2), + Route("a", "c", 2), + Route("a", "d", 10), + Route("b", "a", 1), + Route("b", "a", 5), + Route("b", "c", 6)) + val ds = sparkContext.parallelize(data).toDF.as[Route] + + val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r))) + .groupByKey(r => (r.src, r.dest)) + .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) => + GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes) + }.map(_._2) + + val expected = Seq( + GroupedRoutes("a", "d", Seq(Route("a", "d", 10))), + GroupedRoutes("b", "c", Seq(Route("b", "c", 6))), + GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))), + GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))), + GroupedRoutes("a", "c", Seq(Route("a", "c", 2))) + ) + + implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] { + override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = { + x.toString.compareTo(y.toString) + } + } + + checkDatasetUnorderly(grped, expected: _*) + } } case class Generic[T](id: T, value: Double) @@ -991,3 +1025,6 @@ object DatasetTransform { ds.map(_ + 1) } } + +case class Route(src: String, dest: String, cost: Int) +case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) From 0b660e02480bb3d193daf4acc997c1c0ca040930 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 31 Oct 2016 13:53:48 +0000 Subject: [PATCH 2/3] Refactor it a bit. --- .../expressions/ReferenceToExpressions.scala | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index e99cc28906099..42786744e9c37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -63,31 +63,33 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childrenGen = children.map(_.genCode(ctx)) - val childrenVars = childrenGen.zip(children).map { - case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) - } + val (childrenVars, classChildrenVars) = childrenGen.zip(children).map { + case (childGen, child) => + val childVar = LambdaVariable(childGen.value, childGen.isNull, child.dataType) - // SPARK-18125: The children vars are local variables. If the result expression uses - // splitExpression, those variables cannot be accessed so compilation fails. - // To fix it, we use class variables to hold those local variables. - val initClassChildVars = childrenVars.map { childVar => - val childVarInClass = ctx.freshName("childVarInClass") - ctx.addMutableState(ctx.javaType(childVar.dataType), childVarInClass, "") - val isNullInClass = ctx.freshName("childVarInClassIsNull") - ctx.addMutableState("boolean", isNullInClass, "") - LambdaVariable(childVarInClass, isNullInClass, childVar.dataType) - } + // SPARK-18125: The children vars are local variables. If the result expression uses + // splitExpression, those variables cannot be accessed so compilation fails. + // To fix it, we use class variables to hold those local variables. + val classChildVarName = ctx.freshName("classChildVar") + val classChildVarIsNull = ctx.freshName("classChildVarIsNull") + ctx.addMutableState(ctx.javaType(childVar.dataType), classChildVarName, "") + ctx.addMutableState("boolean", classChildVarIsNull, "") + val classChildVar = + LambdaVariable(classChildVarName, classChildVarIsNull, childVar.dataType) + + (childVar, classChildVar) + }.unzip - val initClassChildVarsCode = initClassChildVars.zipWithIndex.map { case (childVarInClass, i) => - s"${childVarInClass.value} = ${childrenVars(i).value};\n" + - s"${childVarInClass.isNull} = ${childrenVars(i).isNull};" + val initClassChildrenVars = classChildrenVars.zipWithIndex.map { case (classChildrenVar, i) => + s"${classChildrenVar.value} = ${childrenVars(i).value};\n" + + s"${classChildrenVar.isNull} = ${childrenVars(i).isNull};" }.mkString("\n") val resultGen = result.transform { - case b: BoundReference => initClassChildVars(b.ordinal) + case b: BoundReference => classChildrenVars(b.ordinal) }.genCode(ctx) - ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + initClassChildVarsCode + + ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + initClassChildrenVars + resultGen.code, isNull = resultGen.isNull, value = resultGen.value) } } From 448abfac34d5454a3bc74e0a4df11736713179b9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 5 Nov 2016 02:11:57 +0000 Subject: [PATCH 3/3] Address comment. --- .../expressions/ReferenceToExpressions.scala | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index 42786744e9c37..6c75a7a50214f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -63,33 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childrenGen = children.map(_.genCode(ctx)) - val (childrenVars, classChildrenVars) = childrenGen.zip(children).map { + val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map { case (childGen, child) => - val childVar = LambdaVariable(childGen.value, childGen.isNull, child.dataType) - // SPARK-18125: The children vars are local variables. If the result expression uses // splitExpression, those variables cannot be accessed so compilation fails. // To fix it, we use class variables to hold those local variables. val classChildVarName = ctx.freshName("classChildVar") val classChildVarIsNull = ctx.freshName("classChildVarIsNull") - ctx.addMutableState(ctx.javaType(childVar.dataType), classChildVarName, "") + ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "") ctx.addMutableState("boolean", classChildVarIsNull, "") + val classChildVar = - LambdaVariable(classChildVarName, classChildVarIsNull, childVar.dataType) + LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType) - (childVar, classChildVar) - }.unzip + val initCode = s"${classChildVar.value} = ${childGen.value};\n" + + s"${classChildVar.isNull} = ${childGen.isNull};" - val initClassChildrenVars = classChildrenVars.zipWithIndex.map { case (classChildrenVar, i) => - s"${classChildrenVar.value} = ${childrenVars(i).value};\n" + - s"${classChildrenVar.isNull} = ${childrenVars(i).isNull};" - }.mkString("\n") + (classChildVar, initCode) + }.unzip val resultGen = result.transform { case b: BoundReference => classChildrenVars(b.ordinal) }.genCode(ctx) - ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + initClassChildrenVars + + ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") + resultGen.code, isNull = resultGen.isNull, value = resultGen.value) } }