
Commit a2b5408
Parent: b520b64

WIP: Code generation with scala reflection.

35 files changed: +1493 −64 lines

project/SparkBuild.scala

Lines changed: 3 additions & 0 deletions
@@ -495,6 +495,9 @@ object SparkBuild extends Build {
     // assumptions about the the expression ids being contiguous. Running tests in parallel breaks
     // this non-deterministically. TODO: FIX THIS.
     parallelExecution in Test := false,
+    addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.0-M8" cross CrossVersion.full),
+    libraryDependencies <+= scalaVersion(v => "org.scala-lang" % "scala-compiler" % v ),
+    libraryDependencies += "org.scalamacros" %% "quasiquotes" % "2.0.0-M8",
     libraryDependencies ++= Seq(
       "com.typesafe" %% "scalalogging-slf4j" % "1.0.1"
     )
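
These three settings wire up Scala 2.10-style quasiquotes: the macro paradise compiler plugin and the quasiquotes artifact make the q"..." interpolator available, while scala-compiler supplies the runtime reflection toolbox a code generator needs. A minimal sketch (not part of this commit) of what becomes possible once they are on the classpath:

import scala.reflect.runtime.universe._

object QuasiquoteCheck {
  def main(args: Array[String]): Unit = {
    // Build a reflection Tree with the q"..." interpolator instead of string concatenation.
    val tree = q"(x: Int) => x + 1"
    println(show(tree))  // prints a textual rendering of the generated tree
  }
}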

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
@@ -158,8 +158,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
    */
   object ImplicitGenerate extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case Project(Seq(Alias(g: Generator, _)), child) =>
-        Generate(g, join = false, outer = false, None, child)
+      case Project(Seq(Alias(g: Generator, alias)), child) =>
+        Generate(g, join = false, outer = false, Some(alias), child)
     }
   }
 
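
The rewritten rule keeps the user-supplied alias (for example the `n` in `SELECT explode(items) AS n FROM orders`) instead of discarding it when a lone generator in a projection is turned into a Generate node. A toy illustration with hypothetical stand-in classes (not Catalyst's real Project/Alias/Generate) of the before/after behaviour:

// Hypothetical miniature of the rule; only the alias handling matters here.
case class Alias(child: String, name: String)
case class Project(projectList: Seq[Alias], child: String)
case class Generate(generator: String, join: Boolean, outer: Boolean,
                    alias: Option[String], child: String)

object ImplicitGenerateDemo {
  def rewrite(plan: Project): Generate = plan match {
    case Project(Seq(Alias(g, alias)), child) =>
      // Previously: Generate(g, join = false, outer = false, None, child)
      Generate(g, join = false, outer = false, Some(alias), child)
  }

  def main(args: Array[String]): Unit = {
    println(rewrite(Project(Seq(Alias("explode(items)", "n")), "orders")))
    // Generate(explode(items),false,false,Some(n),orders)
  }
}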

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 1 addition & 1 deletion
@@ -259,7 +259,7 @@ trait HiveTypeCoercion {
       case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
       // Turn true into 1, and false into 0 if casting boolean into other types.
       case Cast(e, dataType) if e.dataType == BooleanType =>
-        Cast(If(e, Literal(1), Literal(0)), dataType)
+        If(e, Cast(Literal(1), dataType), Cast(Literal(0), dataType))
     }
   }
 
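
The cast is pushed into each branch of the conditional, so the If expression already carries the target data type rather than producing an integer that is cast afterwards; the observable value is unchanged. A plain-Scala analogue (illustrative only, using Double as the target type) of the two shapes:

object BooleanCastShapes {
  def oldShape(b: Boolean): Double = (if (b) 1 else 0).toDouble        // Cast(If(e, 1, 0), dataType)
  def newShape(b: Boolean): Double = if (b) 1.toDouble else 0.toDouble // If(e, Cast(1), Cast(0))

  def main(args: Array[String]): Unit =
    Seq(true, false).foreach(b => assert(oldShape(b) == newShape(b)))
}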

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] {
     plan.transform {
       case n: NoBind => n.asInstanceOf[TreeNode]
       case leafNode if leafNode.children.isEmpty => leafNode
+      case nb: NoBind => nb.asInstanceOf[TreeNode]
       case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e =>
         bindReference(e, unaryNode.children.head.output)
       }
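
For context on what `bindReference` does in the unary-node case: it resolves a named attribute against the child's output and replaces it with a positional reference, so later evaluation is an array index rather than a name lookup. A hypothetical miniature (not the real Catalyst classes):

case class Attribute(name: String)
case class BoundReference(ordinal: Int)

object BindReferenceSketch {
  // Find the attribute's position in the child's output schema.
  def bindReference(attr: Attribute, input: Seq[Attribute]): BoundReference =
    BoundReference(input.indexWhere(_.name == attr.name))

  def main(args: Array[String]): Unit = {
    val childOutput = Seq(Attribute("id"), Attribute("name"), Attribute("age"))
    println(bindReference(Attribute("age"), childOutput))  // BoundReference(2)
  }
}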

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneration.scala

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst._
+
+import org.apache.spark.sql.catalyst.types._
+
+object CodeGeneration
+
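
The object is still empty in this WIP commit, but together with the quasiquotes and scala-compiler dependencies it points at runtime code generation for expression evaluation. One way such a generator could look, sketched with a hypothetical Expr ADT and the runtime toolbox (this is not the commit's implementation):

import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox

sealed trait Expr
case class Column(ordinal: Int) extends Expr               // reads input(ordinal)
case class IntLit(value: Int) extends Expr
case class AddInt(left: Expr, right: Expr) extends Expr

object CodeGenerationSketch {
  private val toolBox = currentMirror.mkToolBox()
  import toolBox.u._

  // Recursively translate the toy expression tree into a Scala AST.
  private def gen(e: Expr): Tree = e match {
    case Column(i)    => q"input($i).asInstanceOf[Int]"
    case IntLit(v)    => q"$v"
    case AddInt(l, r) => q"${gen(l)} + ${gen(r)}"
  }

  // Wrap the generated body in a function literal and compile it at runtime.
  def compile(e: Expr): Seq[Any] => Any = {
    val tree = q"(input: Seq[Any]) => ${gen(e)}"
    toolBox.eval(tree).asInstanceOf[Seq[Any] => Any]
  }

  def main(args: Array[String]): Unit =
    println(compile(AddInt(Column(0), IntLit(1)))(Seq(41)))  // 42
}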

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 15 additions & 4 deletions
@@ -22,7 +22,7 @@ package org.apache.spark.sql.catalyst.expressions
  * new row. If the schema of the input row is specified, then the given expression will be bound to
  * that schema.
  */
-class Projection(expressions: Seq[Expression]) extends (Row => Row) {
+class InterpretedProjection(expressions: Seq[Expression]) extends (Row => Row) {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
@@ -40,7 +40,7 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
 }
 
 /**
- * Converts a [[Row]] to another Row given a sequence of expression that define each column of th
+ * Converts a [[Row]] to another Row given a sequence of expression that define each column of the
  * new row. If the schema of the input row is specified, then the given expression will be bound to
  * that schema.
  *
@@ -50,14 +50,19 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
  * has been called on the [[Iterator]] that produced it. Instead, the user must call `Row.copy()`
  * and hold on to the returned [[Row]] before calling `next()`.
  */
-case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
+case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
   private[this] val exprArray = expressions.toArray
-  private[this] val mutableRow = new GenericMutableRow(exprArray.size)
+  private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.size)
   def currentValue: Row = mutableRow
 
+  def target(row: MutableRow): MutableProjection = {
+    mutableRow = row
+    this
+  }
+
   def apply(input: Row): Row = {
     var i = 0
     while (i < exprArray.length) {
@@ -76,6 +81,12 @@ class JoinedRow extends Row {
   private[this] var row1: Row = _
   private[this] var row2: Row = _
 
+  def this(left: Row, right: Row) = {
+    this()
+    row1 = left
+    row2 = right
+  }
+
   /** Updates this JoinedRow to used point at two new base rows. Returns itself. */
   def apply(r1: Row, r2: Row): Row = {
     row1 = r1
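
Renaming the concrete classes to Interpreted* (with InterpretedMutableProjection now extending a MutableProjection abstraction) leaves room for code-generated implementations behind the same interface, and the new target(row) hook lets a caller supply the MutableRow that apply(...) overwrites. A hypothetical miniature (not the Catalyst classes) of why that hook is useful in a tight operator loop:

class ToyMutableRow(arity: Int) {
  val values = new Array[Any](arity)
  def update(i: Int, v: Any): Unit = values(i) = v
  override def toString: String = values.mkString("[", ", ", "]")
}

class ToyMutableProjection(exprs: Seq[Seq[Any] => Any]) {
  private var mutableRow = new ToyMutableRow(exprs.size)

  // Redirect output into a caller-owned row, mirroring the diff's target(...).
  def target(row: ToyMutableRow): ToyMutableProjection = {
    mutableRow = row
    this
  }

  def apply(input: Seq[Any]): ToyMutableRow = {
    var i = 0
    while (i < exprs.length) {
      mutableRow(i) = exprs(i)(input)
      i += 1
    }
    mutableRow
  }
}

object TargetDemo {
  def main(args: Array[String]): Unit = {
    val reused = new ToyMutableRow(2)
    val project = new ToyMutableProjection(
      Seq(row => row(0), row => row(1).asInstanceOf[Int] * 2)).target(reused)

    // Both iterations overwrite `reused` instead of allocating a new row.
    Seq(Seq("a", 1), Seq("b", 2)).foreach(input => println(project(input)))
  }
}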

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 28 additions & 0 deletions
@@ -180,6 +180,34 @@ class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
     values(i).asInstanceOf[String]
   }
 
+  override def hashCode(): Int = {
+    var result: Int = 37
+
+    var i = 0
+    while (i < values.length) {
+      val update: Int =
+        if (isNullAt(i)) {
+          0
+        } else {
+          apply(i) match {
+            case b: Boolean => if (b) 0 else 1
+            case b: Byte => b.toInt
+            case s: Short => s.toInt
+            case i: Int => i
+            case l: Long => (l ^ (l >>> 32)).toInt
+            case f: Float => java.lang.Float.floatToIntBits(f)
+            case d: Double =>
+              val b = java.lang.Double.doubleToLongBits(d)
+              (b ^ (b >>> 32)).toInt
+            case other => other.hashCode()
+          }
+        }
+      result = 37 * result + update
+      i += 1
+    }
+    result
+  }
+
   def copy() = this
 }
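
The new hashCode folds a per-column update into a 37-based accumulator, converting longs, floats and doubles to Ints the same way their java.lang wrappers do, so two GenericRows with equal column values hash identically however they were produced. A standalone restatement (illustrative, not Spark code) of the same mixing scheme over a plain Seq[Any]:

object RowHashSketch {
  // Nulls contribute 0; primitives are folded to Int as in the patch.
  def hash(values: Seq[Any]): Int =
    values.foldLeft(37) { (result, v) =>
      val update = v match {
        case null       => 0
        case b: Boolean => if (b) 0 else 1
        case b: Byte    => b.toInt
        case s: Short   => s.toInt
        case i: Int     => i
        case l: Long    => (l ^ (l >>> 32)).toInt
        case f: Float   => java.lang.Float.floatToIntBits(f)
        case d: Double  =>
          val bits = java.lang.Double.doubleToLongBits(d)
          (bits ^ (bits >>> 32)).toInt
        case other      => other.hashCode()
      }
      37 * result + update
    }

  def main(args: Array[String]): Unit = {
    // Equal column values give equal hashes regardless of the containing collection.
    println(hash(Seq(1, "a", 2.5)) == hash(Vector(1, "a", 2.5)))  // true
  }
}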

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   override def eval(input: Row): Any = {
     children.size match {
+      case 0 => function.asInstanceOf[() => Any]()
       case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
         function.asInstanceOf[(Any, Any) => Any](
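
The added arm handles zero-argument UDFs, which previously fell through the match. A minimal illustration (not Spark code) of what the new `case 0` does: the UDF is stored as AnyRef and invoked by casting it back to the nullary function type, ignoring the input row entirely.

object NullaryUdfDemo {
  def main(args: Array[String]): Unit = {
    val function: AnyRef = () => System.nanoTime()    // a zero-argument "UDF"
    val result = function.asInstanceOf[() => Any]()   // mirrors the diff's case 0
    println(result)
  }
}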
