Skip to content

Commit b4b5867

Browse files
author
root
committed
Merge remote branch 'upstream/master' into patch-6
2 parents b4c3869 + bc7041a commit b4b5867

File tree

6 files changed

+72
-10
lines changed

6 files changed

+72
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,13 @@ object ScalaReflection {
4747
val TypeRef(_, _, Seq(optType)) = t
4848
Schema(schemaFor(optType).dataType, nullable = true)
4949
case t if t <:< typeOf[Product] =>
50-
val params = t.member("<init>": TermName).asMethod.paramss
50+
val formalTypeArgs = t.typeSymbol.asClass.typeParams
51+
val TypeRef(_, _, actualTypeArgs) = t
52+
val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
5153
Schema(StructType(
5254
params.head.map { p =>
53-
val Schema(dataType, nullable) = schemaFor(p.typeSignature)
55+
val Schema(dataType, nullable) =
56+
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
5457
StructField(p.name.toString, dataType, nullable)
5558
}), nullable = true)
5659
// Need to decide if we actually need a special type here.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ package object dsl {
108108

109109
implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
110110

111+
def sum(e: Expression) = Sum(e)
112+
def sumDistinct(e: Expression) = SumDistinct(e)
113+
def count(e: Expression) = Count(e)
114+
def countDistinct(e: Expression*) = CountDistinct(e)
115+
def avg(e: Expression) = Average(e)
116+
def first(e: Expression) = First(e)
117+
def min(e: Expression) = Min(e)
118+
def max(e: Expression) = Max(e)
119+
def upper(e: Expression) = Upper(e)
120+
def lower(e: Expression) = Lower(e)
121+
111122
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
112123
// TODO more implicit class for literal?
113124
implicit class DslString(val s: String) extends ImplicitOperators {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ case class ComplexData(
6060
mapField: Map[Int, String],
6161
structField: PrimitiveData)
6262

63+
case class GenericData[A](
64+
genericField: A)
65+
6366
class ScalaReflectionSuite extends FunSuite {
6467
import ScalaReflection._
6568

@@ -128,4 +131,21 @@ class ScalaReflectionSuite extends FunSuite {
128131
nullable = true))),
129132
nullable = true))
130133
}
134+
135+
test("generic data") {
136+
val schema = schemaFor[GenericData[Int]]
137+
assert(schema === Schema(
138+
StructType(Seq(
139+
StructField("genericField", IntegerType, nullable = false))),
140+
nullable = true))
141+
}
142+
143+
test("tuple data") {
144+
val schema = schemaFor[(Int, String)]
145+
assert(schema === Schema(
146+
StructType(Seq(
147+
StructField("_1", IntegerType, nullable = false),
148+
StructField("_2", StringType, nullable = true))),
149+
nullable = true))
150+
}
131151
}

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,13 @@ class SchemaRDD(
133133
*
134134
* @group Query
135135
*/
136-
def select(exprs: NamedExpression*): SchemaRDD =
137-
new SchemaRDD(sqlContext, Project(exprs, logicalPlan))
136+
def select(exprs: Expression*): SchemaRDD = {
137+
val aliases = exprs.zipWithIndex.map {
138+
case (ne: NamedExpression, _) => ne
139+
case (e, i) => Alias(e, s"c$i")()
140+
}
141+
new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
142+
}
138143

139144
/**
140145
* Filters the output, only returning those rows where `condition` evaluates to true.

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest {
6060
Seq(Seq("1")))
6161
}
6262

63+
test("select with functions") {
64+
checkAnswer(
65+
testData.select(sum('value), avg('value), count(1)),
66+
Seq(Seq(5050.0, 50.5, 100)))
67+
68+
checkAnswer(
69+
testData2.select('a + 'b, 'a < 'b),
70+
Seq(
71+
Seq(2, false),
72+
Seq(3, true),
73+
Seq(3, false),
74+
Seq(4, false),
75+
Seq(4, false),
76+
Seq(5, false)))
77+
78+
checkAnswer(
79+
testData2.select(sumDistinct('a)),
80+
Seq(Seq(6)))
81+
}
82+
6383
test("sorting") {
6484
checkAnswer(
6585
testData2.orderBy('a.asc, 'b.asc),
@@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest {
110130

111131
test("average") {
112132
checkAnswer(
113-
testData2.groupBy()(Average('a)),
133+
testData2.groupBy()(avg('a)),
114134
2.0)
115135
}
116136

117137
test("null average") {
118138
checkAnswer(
119-
testData3.groupBy()(Average('b)),
139+
testData3.groupBy()(avg('b)),
120140
2.0)
121141

122142
checkAnswer(
123-
testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
143+
testData3.groupBy()(avg('b), countDistinct('b)),
124144
(2.0, 1) :: Nil)
125145
}
126146

@@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest {
130150

131151
test("null count") {
132152
checkAnswer(
133-
testData3.groupBy('a)('a, Count('b)),
153+
testData3.groupBy('a)('a, count('b)),
134154
Seq((1,0), (2, 1))
135155
)
136156

137157
checkAnswer(
138-
testData3.groupBy('a)('a, Count('a + 'b)),
158+
testData3.groupBy('a)('a, count('a + 'b)),
139159
Seq((1,0), (2, 1))
140160
)
141161

142162
checkAnswer(
143-
testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
163+
testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
144164
(2, 1, 2, 2, 1) :: Nil
145165
)
146166
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ import scala.collection.JavaConversions._
2626
* A set of test cases that validate partition and column pruning.
2727
*/
2828
class PruningSuite extends HiveComparisonTest {
29+
// MINOR HACK: You must run a query before calling reset the first time.
30+
TestHive.hql("SHOW TABLES")
31+
2932
// Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset
3033
// the environment to ensure all referenced tables in this suites are not cached in-memory.
3134
// Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details.

0 commit comments

Comments (0)