From d11d5b95ef82a208c579daa0073bdc072a682be5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Fri, 1 May 2015 23:50:12 +0800 Subject: [PATCH 01/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 12 ++++++++++++ python/pyspark/sql/tests.py | 6 ++++++ .../main/scala/org/apache/spark/sql/Column.scala | 14 ++++++++++++++ .../apache/spark/sql/ColumnExpressionSuite.scala | 6 ++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 11 +++++++++++ 5 files changed, 49 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5908ebc990a56..a4cbc7396e386 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1289,6 +1289,18 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) + @ignore_unicode_prefix + def between(self, col1, col2): + """ A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + + >>> df[df.col1.between(col2, col3)].collect() + [Row(col1=5, col2=6, col3=8)] + """ + #sc = SparkContext._active_spark_context + jc = self > col1 & self < col2 + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5640bb5ea2346..206e3b7fd08f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -426,6 +426,12 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + def test_between_function(self): + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF() + self.assertEqual([False, True, False], + df.select(df.a.between(df.b, df.c)).collect()) + + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 33f9d0b37d006..8e0eab7918131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -295,6 +295,20 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def eqNullSafe(other: Any): Column = this <=> other + /** + * Between col1 and col2. + * + * @group java_expr_ops + */ + def between(col1: String, col2: String): Column = between(Column(col1), Column(col2)) + + /** + * Between col1 and col2. + * + * @group java_expr_ops + */ + def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr)) + /** * True if the current expression is null. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6322faf4d9907..0a81f884e9a16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -208,6 +208,12 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } + test("between") { + checkAnswer( + testData4.filter($"a".between($"b", $"c")), + testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2))) + } + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 225b51bd73d6c..487d07249922f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -57,6 +57,17 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") + case class TestData4(a: Int, b: Int, c: Int) + val testData4 = + TestSQLContext.sparkContext.parallelize( + TestData4(0, 1, 2) :: + TestData4(1, 2, 3) :: + TestData4(2, 1, 0) :: + TestData4(2, 2, 4) :: + TestData4(3, 1, 6) :: + TestData4(3, 2, 0) :: Nil, 2).toDF() + testData4.registerTempTable("TestData4") + case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From baf839b4a4aa8d7d4ab8cdb1a5b82affd3ce376e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 09:39:17 +0800 Subject: [PATCH 02/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 7 +++---- python/pyspark/sql/tests.py | 4 ++-- .../main/scala/org/apache/spark/sql/Column.scala | 13 +++++++++---- .../apache/spark/sql/ColumnExpressionSuite.scala | 2 +- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a4cbc7396e386..8c09bf23f3cc0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1290,15 +1290,14 @@ def cast(self, dataType): return Column(jc) @ignore_unicode_prefix - def between(self, col1, col2): + def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. 
- >>> df[df.col1.between(col2, col3)].collect() + >>> df[df.col1.between(lowerBound, upperBound)].collect() [Row(col1=5, col2=6, col3=8)] """ - #sc = SparkContext._active_spark_context - jc = self > col1 & self < col2 + jc = (self >= lowerBound) & (self <= upperBound) return Column(jc) def __repr__(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 206e3b7fd08f2..b5faedfe15e46 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -427,8 +427,8 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF() - self.assertEqual([False, True, False], + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() + self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8e0eab7918131..b51b6368eeb56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -296,18 +296,23 @@ class Column(protected[sql] val expr: Expression) extends Logging { def eqNullSafe(other: Any): Column = this <=> other /** - * Between col1 and col2. + * True if the current column is between the lower bound and upper bound, inclusive. * * @group java_expr_ops */ - def between(col1: String, col2: String): Column = between(Column(col1), Column(col2)) + def between(lowerBound: String, upperBound: String): Column = { + between(Column(lowerBound), Column(upperBound)) + } /** - * Between col1 and col2. + * True if the current column is between the lower bound and upper bound, inclusive. * * @group java_expr_ops */ - def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr)) + def between(lowerBound: Column, upperBound: Column): Column = { + And(GreaterThanOrEqual(this.expr, lowerBound.expr), + LessThanOrEqual(this.expr, upperBound.expr)) + } /** * True if the current expression is null. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 0a81f884e9a16..b63c1814adc3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -211,7 +211,7 @@ class ColumnExpressionSuite extends QueryTest { test("between") { checkAnswer( testData4.filter($"a".between($"b", $"c")), - testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2))) + testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) } val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( From 7d623680d2c726a53b9e36c78f654e34c40f3dba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 14:17:10 +0800 Subject: [PATCH 03/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8c09bf23f3cc0..2538bd139bb3f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1297,8 +1297,7 @@ def between(self, lowerBound, upperBound): >>> df[df.col1.between(lowerBound, upperBound)].collect() [Row(col1=5, col2=6, col3=8)] """ - jc = (self >= lowerBound) & (self <= upperBound) - return Column(jc) + return (self >= lowerBound) & (self <= upperBound) def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') From f080f8d118f00e4f27936d55e74d391bac690c33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 22:00:12 +0800 Subject: [PATCH 04/10] update pep8 --- python/pyspark/sql/tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index edf9f95a8ce65..000dab99ea730 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,7 +439,9 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), + Row(a=2, b=1, c=3), + Row(a=4, b=1, c=4)]).toDF() self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) From c6e49bcdb37c7e783849464858552293bcbea4fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sun, 3 May 2015 11:09:20 +0800 Subject: [PATCH 05/10] undo --- python/pyspark/sql/dataframe.py | 10 ---------- python/pyspark/sql/tests.py | 8 -------- .../scala/org/apache/spark/sql/Column.scala | 19 ------------------- .../spark/sql/ColumnExpressionSuite.scala | 6 ------ .../scala/org/apache/spark/sql/TestData.scala | 11 ----------- 5 files changed, 54 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5d1e7b630bf3a..e9fd17ed4ce94 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1330,16 +1330,6 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) - @ignore_unicode_prefix - def between(self, lowerBound, upperBound): - """ A boolean expression that is evaluated to true if the value of this - expression is between the given columns. 
- - >>> df[df.col1.between(lowerBound, upperBound)].collect() - [Row(col1=5, col2=6, col3=8)] - """ - return (self >= lowerBound) & (self <= upperBound) - def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 000dab99ea730..613efc0ac029d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -438,14 +438,6 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] - def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), - Row(a=2, b=1, c=3), - Row(a=4, b=1, c=4)]).toDF() - self.assertEqual([False, True, True], - df.select(df.a.between(df.b, df.c)).collect()) - - def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b51b6368eeb56..33f9d0b37d006 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -295,25 +295,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def eqNullSafe(other: Any): Column = this <=> other - /** - * True if the current column is between the lower bound and upper bound, inclusive. - * - * @group java_expr_ops - */ - def between(lowerBound: String, upperBound: String): Column = { - between(Column(lowerBound), Column(upperBound)) - } - - /** - * True if the current column is between the lower bound and upper bound, inclusive. - * - * @group java_expr_ops - */ - def between(lowerBound: Column, upperBound: Column): Column = { - And(GreaterThanOrEqual(this.expr, lowerBound.expr), - LessThanOrEqual(this.expr, upperBound.expr)) - } - /** * True if the current expression is null. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b63c1814adc3d..6322faf4d9907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -208,12 +208,6 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } - test("between") { - checkAnswer( - testData4.filter($"a".between($"b", $"c")), - testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) - } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 487d07249922f..225b51bd73d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -57,17 +57,6 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") - case class TestData4(a: Int, b: Int, c: Int) - val testData4 = - TestSQLContext.sparkContext.parallelize( - TestData4(0, 1, 2) :: - TestData4(1, 2, 3) :: - TestData4(2, 1, 0) :: - TestData4(2, 2, 4) :: - TestData4(3, 1, 6) :: - TestData4(3, 2, 0) :: Nil, 2).toDF() - testData4.registerTempTable("TestData4") - case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From d6cc28d08eea876c96866b4c1626b49bf3ba4b73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sun, 3 May 2015 13:53:33 +0800 Subject: [PATCH 06/10] undo --- .../org/apache/spark/sql/DataFrame.scala | 32 ++++++++----- .../org/apache/spark/sql/GroupedData.scala | 8 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 46 ++++++++++++++++--- 3 files changed, 68 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c421006c8fd2d..be68b2fa41db2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,31 +20,30 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.DriverManager -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - import com.fasterxml.jackson.core.JsonFactory - import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} 
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { @@ -887,6 +886,15 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] without duplicates under the given columns. + * @group dfops + */ + def dropDuplicates(subset: Seq[String] = this.columns): DataFrame = { + import org.apache.spark.sql.functions.{first => columnFirst} + new GroupedData(this, subset.map(colName => resolve(colName))).agg(columns.map(columnFirst)) + } + /** * Computes statistics for numeric columns, including count, mean, stddev, min, and max. * If no columns are given, this function computes statistics for all numerical columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 53ad67372e024..9e45a65ed6cc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -161,6 +161,14 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) } + def agg(exprs: Seq[Column]): DataFrame = { + val aggExprs = exprs.map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + /** * Count the number of rows for each group. * The resulting [[DataFrame]] will also contain the grouping columns. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e286fef23caa4..1cce98bc3176d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql -import scala.language.postfixOps - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.sql +import org.apache.spark.sql.test.TestSQLContext.{logicalPlanToSparkQuery, sql} +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext} +import org.apache.spark.sql.types._ + +import scala.language.postfixOps class DataFrameSuite extends QueryTest { @@ -613,4 +612,39 @@ class DataFrameSuite extends QueryTest { Row(new java.math.BigDecimal(2.0))) TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + + test("SPARK-7324 dropDuplicates") { + val testData = TestSQLContext.sparkContext.parallelize( + (2, 1, 2) :: (1, 1, 1) :: + (1, 2, 1) :: (2, 1, 2) :: + (2, 2, 2) :: (2, 2, 1) :: + (2, 1, 1) :: (1, 1, 2) :: + (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2") + + checkAnswer( + testData.dropDuplicates(), + Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1), + Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1), + Row(1, 1, 2), Row(1, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("key", "value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("value1", "value2")), + Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2))) + + checkAnswer( + testData.dropDuplicates(Seq("key")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) + + checkAnswer( + testData.dropDuplicates(Seq("value1")), + Seq(Row(2, 1, 2), Row(1, 2, 1))) + + checkAnswer( + testData.dropDuplicates(Seq("value2")), + Seq(Row(2, 1, 2), Row(1, 1, 1))) + } } From aab51ef2e4aa0a97eba5324bb37dc2d08f55896a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sun, 3 May 2015 14:12:02 +0800 Subject: [PATCH 07/10] update --- .../org/apache/spark/sql/DataFrame.scala | 23 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 11 +++++---- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index be68b2fa41db2..393a33afae179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -20,30 +20,31 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.DriverManager +import scala.collection.JavaConversions._ +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal + import com.fasterxml.jackson.core.JsonFactory + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedAttribute, UnresolvedRelation} +import 
org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD} import org.apache.spark.sql.jdbc.JDBCWriteDetails import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.types._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect} import org.apache.spark.util.Utils -import scala.collection.JavaConversions._ -import scala.language.implicitConversions -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.TypeTag -import scala.util.control.NonFatal - private[sql] object DataFrame { def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1cce98bc3176d..06e4983ba49ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql +import scala.language.postfixOps + import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ -import org.apache.spark.sql.test.TestSQLContext.{logicalPlanToSparkQuery, sql} -import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, TestSQLContext} import org.apache.spark.sql.types._ - -import scala.language.postfixOps +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} +import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery +import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.TestSQLContext.sql class DataFrameSuite extends QueryTest { From b6f187949e5dc00477cc7c52f9e8ee9a20987215 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sun, 3 May 2015 21:46:21 +0800 Subject: [PATCH 08/10] update --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 393a33afae179..0c9dcf85aec24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -887,11 +887,17 @@ class DataFrame private[sql]( } } + /** + * Returns a new [[DataFrame]] without duplicates. + * @group dfops + */ + def dropDuplicates(): DataFrame = dropDuplicates(this.columns) + /** * Returns a new [[DataFrame]] without duplicates under the given columns. 
* @group dfops */ - def dropDuplicates(subset: Seq[String] = this.columns): DataFrame = { + def dropDuplicates(subset: Seq[String]): DataFrame = { import org.apache.spark.sql.functions.{first => columnFirst} new GroupedData(this, subset.map(colName => resolve(colName))).agg(columns.map(columnFirst)) } From 571869ec4815105681cffcdabdd147dd00751348 Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Wed, 6 May 2015 13:42:58 +0800 Subject: [PATCH 09/10] Remove useless code. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 3 ++- .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 8 -------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0c9dcf85aec24..f00c78caa6aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -899,7 +899,8 @@ class DataFrame private[sql]( */ def dropDuplicates(subset: Seq[String]): DataFrame = { import org.apache.spark.sql.functions.{first => columnFirst} - new GroupedData(this, subset.map(colName => resolve(colName))).agg(columns.map(columnFirst)) + val columnFirsts = columns.map(columnFirst) + groupBy(subset.head, subset.tail : _*).agg(columnFirsts.head, columnFirsts.tail : _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 9e45a65ed6cc9..53ad67372e024 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -161,14 +161,6 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) } - def agg(exprs: Seq[Column]): DataFrame = { - val aggExprs = exprs.map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) - } - /** * Count the number of rows for each group. * The resulting [[DataFrame]] will also contain the grouping columns. From 1de87911a41bb659d8142b35ec6c71de1b39ae8a Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Tue, 12 May 2015 09:36:53 +0800 Subject: [PATCH 10/10] Update --- .../scala/org/apache/spark/sql/DataFrame.scala | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f00c78caa6aab..a86de847dfd79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -891,16 +891,14 @@ class DataFrame private[sql]( * Returns a new [[DataFrame]] without duplicates. * @group dfops */ - def dropDuplicates(): DataFrame = dropDuplicates(this.columns) - - /** - * Returns a new [[DataFrame]] without duplicates under the given columns. 
- * @group dfops - */ - def dropDuplicates(subset: Seq[String]): DataFrame = { + def dropDuplicates(subset: Seq[String] = this.columns): DataFrame = { import org.apache.spark.sql.functions.{first => columnFirst} - val columnFirsts = columns.map(columnFirst) - groupBy(subset.head, subset.tail : _*).agg(columnFirsts.head, columnFirsts.tail : _*) + if (subset.length == 0) { + sqlContext.emptyDataFrame + } else { + val columnFirsts = columns.map(columnFirst) + groupBy(subset.head, subset.tail: _*).agg(columnFirsts.head, columnFirsts.tail: _*) + } } /**
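
For context on the end state of this series: the `Column.between` additions from patches 01-04 are reverted in patch 05, so the final API surface added here is `dropDuplicates(subset: Seq[String] = this.columns)` on DataFrame. Below is a minimal usage sketch, not part of the patches themselves; it assumes the TestSQLContext implicits used by DataFrameSuite are in scope and reuses the same illustrative column names ("key", "value1", "value2"):

    import org.apache.spark.sql.test.TestSQLContext
    import org.apache.spark.sql.test.TestSQLContext.implicits._

    // Sample data mirroring the "SPARK-7324 dropDuplicates" test.
    val df = TestSQLContext.sparkContext.parallelize(
      (1, 1, 1) :: (1, 1, 2) :: (2, 1, 2) :: (2, 1, 2) :: Nil)
      .toDF("key", "value1", "value2")

    df.dropDuplicates()                      // all columns considered: 3 distinct rows remain
    df.dropDuplicates(Seq("key", "value1"))  // one row kept per (key, value1) pair: 2 rows
    df.dropDuplicates(Seq.empty)             // empty subset yields sqlContext.emptyDataFrame

Since deduplication is implemented via groupBy plus first(), which row survives within each duplicate group is not guaranteed, only that one row per group is kept.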