
Commit d931b01

[SQL] Two DataFrame fixes.
- Removed DataFrame.apply for projection & filtering since they are extremely confusing.
- Added implicits for RDD[Int], RDD[Long], and RDD[String].

Author: Reynold Xin <[email protected]>

Closes #4543 from rxin/df-cleanup and squashes the following commits:

81ec915 [Reynold Xin] [SQL] More DataFrame fixes.
1 parent fa6bdc6 commit d931b01

File tree: 5 files changed, +119 -57 lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 11 additions & 28 deletions
@@ -85,20 +85,24 @@ trait DataFrame extends RDDApi[Row] {
 
   protected[sql] def logicalPlan: LogicalPlan
 
+  override def toString =
+    try {
+      schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
+    } catch {
+      case NonFatal(e) =>
+        s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+    }
+
   /** Left here for backward compatibility. */
   @deprecated("1.3.0", "use toDataFrame")
   def toSchemaRDD: DataFrame = this
 
   /**
    * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
    */
-  def toDataFrame: DataFrame = this
-
-  override def toString =
-    try schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") catch {
-      case NonFatal(e) =>
-        s"Invalid tree; ${e.getMessage}:\n$queryExecution"
-    }
+  // This is declared with parentheses to prevent the Scala compiler from treating
+  // `rdd.toDataFrame("1")` as invoking this toDataFrame and then apply on the returned DataFrame.
+  def toDataFrame(): DataFrame = this
 
   /**
    * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
@@ -234,16 +238,6 @@ trait DataFrame extends RDDApi[Row] {
    */
   def col(colName: String): Column
 
-  /**
-   * Selects a set of expressions, wrapped in a Product.
-   * {{{
-   *   // The following two are equivalent:
-   *   df.apply(($"colA", $"colB" + 1))
-   *   df.select($"colA", $"colB" + 1)
-   * }}}
-   */
-  def apply(projection: Product): DataFrame
-
   /**
    * Returns a new [[DataFrame]] with an alias set.
    */
@@ -317,17 +311,6 @@ trait DataFrame extends RDDApi[Row] {
    */
   def where(condition: Column): DataFrame
 
-  /**
-   * Filters rows using the given condition. This is a shorthand meant for Scala.
-   * {{{
-   *   // The following are equivalent:
-   *   peopleDf.filter($"age" > 15)
-   *   peopleDf.where($"age" > 15)
-   *   peopleDf($"age" > 15)
-   * }}}
-   */
-  def apply(condition: Column): DataFrame
-
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
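
The removed apply overloads have explicit equivalents that stay on DataFrame; the removed scaladoc above already spells them out. A short sketch using the placeholder DataFrames from that scaladoc (df, peopleDf):

    // Projection: instead of the removed df.apply(Product), call select directly.
    df.select($"colA", $"colB" + 1)

    // Filtering: instead of the removed df.apply(Column), call filter or where.
    peopleDf.filter($"age" > 15)
    peopleDf.where($"age" > 15)

The empty parentheses on the new toDataFrame() are deliberate: as the added comment notes, a parameterless toDataFrame would let the compiler parse rdd.toDataFrame("1") as rdd.toDataFrame followed by an apply call on the result.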

sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala

Lines changed: 7 additions & 17 deletions
@@ -49,8 +49,10 @@ private[sql] class DataFrameImpl protected[sql](
   extends DataFrame {
 
   /**
-   * A constructor that automatically analyzes the logical plan. This reports error eagerly
-   * as the [[DataFrame]] is constructed.
+   * A constructor that automatically analyzes the logical plan.
+   *
+   * This reports error eagerly as the [[DataFrame]] is constructed, unless
+   * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
    */
   def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
     this(sqlContext, {
@@ -158,7 +160,7 @@ private[sql] class DataFrameImpl protected[sql](
   }
 
   override def show(): Unit = {
-    println(showString)
+    println(showString())
   }
 
   override def join(right: DataFrame): DataFrame = {
@@ -205,14 +207,6 @@ private[sql] class DataFrameImpl protected[sql](
     Column(sqlContext, Project(Seq(expr), logicalPlan), expr)
   }
 
-  override def apply(projection: Product): DataFrame = {
-    require(projection.productArity >= 1)
-    select(projection.productIterator.map {
-      case c: Column => c
-      case o: Any => Column(Literal(o))
-    }.toSeq :_*)
-  }
-
   override def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
 
   override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
@@ -259,10 +253,6 @@ private[sql] class DataFrameImpl protected[sql](
     filter(condition)
   }
 
-  override def apply(condition: Column): DataFrame = {
-    filter(condition)
-  }
-
   override def groupBy(cols: Column*): GroupedData = {
     new GroupedData(this, cols.map(_.expr))
   }
@@ -323,7 +313,7 @@ private[sql] class DataFrameImpl protected[sql](
   override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
 
   override def repartition(numPartitions: Int): DataFrame = {
-    sqlContext.applySchema(rdd.repartition(numPartitions), schema)
+    sqlContext.createDataFrame(rdd.repartition(numPartitions), schema)
   }
 
   override def distinct: DataFrame = Distinct(logicalPlan)
@@ -401,7 +391,7 @@ private[sql] class DataFrameImpl protected[sql](
     val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
 
     new Iterator[String] {
-      override def hasNext() = iter.hasNext
+      override def hasNext = iter.hasNext
       override def next(): String = {
         JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
         gen.flush()
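
Two of the one-line fixes deserve a note. First, applySchema gives way to createDataFrame, the SQLContext method that pairs an RDD[Row] with an explicit schema (the same call the fixed repartition now makes). A minimal sketch of that pattern, assuming a SparkContext named sc and a SQLContext named sqlContext are in scope:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // Rebuild a DataFrame from an RDD[Row] plus an explicit schema.
    val rows = sc.parallelize(Seq(Row(1), Row(2), Row(3)))
    val schema = StructType(StructField("id", IntegerType) :: Nil)
    val df = sqlContext.createDataFrame(rows, schema)

Second, dropping the parentheses from hasNext makes the override match scala.collection.Iterator, which declares `def hasNext: Boolean` without a parameter list.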

sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala

Lines changed: 0 additions & 4 deletions
@@ -80,8 +80,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def col(colName: String): Column = err()
 
-  override def apply(projection: Product): DataFrame = err()
-
   override def select(cols: Column*): DataFrame = err()
 
   override def select(col: String, cols: String*): DataFrame = err()
@@ -98,8 +96,6 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def where(condition: Column): DataFrame = err()
 
-  override def apply(condition: Column): DataFrame = err()
-
   override def groupBy(cols: Column*): GroupedData = err()
 
   override def groupBy(col1: String, cols: String*): GroupedData = err()

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 46 additions & 8 deletions
@@ -180,21 +180,59 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   object implicits {
     // scalastyle:on
-    /**
-     * Creates a DataFrame from an RDD of case classes.
-     *
-     * @group userf
-     */
+
+    /** Creates a DataFrame from an RDD of case classes or tuples. */
     implicit def rddToDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
       self.createDataFrame(rdd)
     }
 
-    /**
-     * Creates a DataFrame from a local Seq of Product.
-     */
+    /** Creates a DataFrame from a local Seq of Product. */
     implicit def localSeqToDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
       self.createDataFrame(data)
     }
+
+    // Do NOT add more implicit conversions. They are likely to break source compatibility by
+    // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
+    // because of [[DoubleRDDFunctions]].
+
+    /** Creates a single column DataFrame from an RDD[Int]. */
+    implicit def intRddToDataFrame(data: RDD[Int]): DataFrame = {
+      val dataType = IntegerType
+      val rows = data.mapPartitions { iter =>
+        val row = new SpecificMutableRow(dataType :: Nil)
+        iter.map { v =>
+          row.setInt(0, v)
+          row: Row
+        }
+      }
+      self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+    }
+
+    /** Creates a single column DataFrame from an RDD[Long]. */
+    implicit def longRddToDataFrame(data: RDD[Long]): DataFrame = {
+      val dataType = LongType
+      val rows = data.mapPartitions { iter =>
+        val row = new SpecificMutableRow(dataType :: Nil)
+        iter.map { v =>
+          row.setLong(0, v)
+          row: Row
+        }
+      }
+      self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+    }
+
+    /** Creates a single column DataFrame from an RDD[String]. */
+    implicit def stringRddToDataFrame(data: RDD[String]): DataFrame = {
+      val dataType = StringType
+      val rows = data.mapPartitions { iter =>
+        val row = new SpecificMutableRow(dataType :: Nil)
+        iter.map { v =>
+          row.setString(0, v)
+          row: Row
+        }
+      }
+      self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))
+    }
   }
 
   /**
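
A usage sketch of the new implicits, roughly mirroring the test suite below (assumes a SQLContext named sqlContext, with its sparkContext bound to sc): each primitive RDD becomes a single-column DataFrame whose column is named "_1", and toDataFrame(colNames) renames it.

    import sqlContext.implicits._

    // RDD[Int] converts to a one-column DataFrame; rename the default "_1" column.
    val ints = sc.parallelize(1 to 10).toDataFrame("intCol")

    // RDD[Long] and RDD[String] work the same way.
    val strs = sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol")

RDD[Double] is deliberately left out: SparkContext already installs an implicit conversion to DoubleRDDFunctions, so a second implicit on RDD[Double] could make existing call sites ambiguous, which is exactly the compatibility hazard the new comment warns about.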
sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala (new file)

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc}
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+
+class DataFrameImplicitsSuite extends QueryTest {
+
+  test("RDD of tuples") {
+    checkAnswer(
+      sc.parallelize(1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+      (1 to 10).map(i => Row(i, i.toString)))
+  }
+
+  test("Seq of tuples") {
+    checkAnswer(
+      (1 to 10).map(i => (i, i.toString)).toDataFrame("intCol", "strCol"),
+      (1 to 10).map(i => Row(i, i.toString)))
+  }
+
+  test("RDD[Int]") {
+    checkAnswer(
+      sc.parallelize(1 to 10).toDataFrame("intCol"),
+      (1 to 10).map(i => Row(i)))
+  }
+
+  test("RDD[Long]") {
+    checkAnswer(
+      sc.parallelize(1L to 10L).toDataFrame("longCol"),
+      (1L to 10L).map(i => Row(i)))
+  }
+
+  test("RDD[String]") {
+    checkAnswer(
+      sc.parallelize(1 to 10).map(_.toString).toDataFrame("stringCol"),
+      (1 to 10).map(i => Row(i.toString)))
+  }
+}
