Commit 130692f
[SPARK-7324][SQL] DataFrame.dropDuplicates
1 parent 91dc3df commit 130692f

3 files changed: +105 -4 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 34 additions & 2 deletions
@@ -755,8 +755,6 @@ def groupBy(self, *cols):
         jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)
 
-    groupby = groupBy
-
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -793,6 +791,36 @@ def subtract(self, other):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
+    def dropDuplicates(self, subset=None):
+        """Return a new :class:`DataFrame` with duplicate rows removed,
+        optionally only considering certain columns.
+
+        >>> from pyspark.sql import Row
+        >>> df = sc.parallelize([ \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=5, height=80), \
+            Row(name='Alice', age=10, height=80)]).toDF()
+        >>> df.dropDuplicates().show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        | 10|    80|Alice|
+        +---+------+-----+
+
+        >>> df.dropDuplicates(['name', 'height']).show()
+        +---+------+-----+
+        |age|height| name|
+        +---+------+-----+
+        |  5|    80|Alice|
+        +---+------+-----+
+        """
+        if subset is None:
+            jdf = self._jdf.dropDuplicates()
+        else:
+            jdf = self._jdf.dropDuplicates(self._jseq(subset))
+        return DataFrame(jdf, self.sql_ctx)
+
     def dropna(self, how='any', thresh=None, subset=None):
         """Returns a new :class:`DataFrame` omitting rows with null values.
@@ -1012,6 +1040,10 @@ def toPandas(self):
         import pandas as pd
         return pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
+    # Pandas compatibility
+    groupby = groupBy
+    drop_duplicates = dropDuplicates
+
 
 # Having SchemaRDD for backward compatibility (for docs)
 class SchemaRDD(DataFrame):

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 36 additions & 2 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.sql
 import java.io.CharArrayWriter
 import java.sql.DriverManager
 
-
 import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
@@ -42,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
-import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
+import org.apache.spark.sql.json.JacksonGenerator
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
@@ -932,6 +931,40 @@ class DataFrame private[sql](
     }
   }
 
+  /**
+   * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `distinct`.
+   * @group dfops
+   */
+  def dropDuplicates(): DataFrame = dropDuplicates(this.columns)
+
+  /**
+   * (Scala-specific) Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Seq[String]): DataFrame = {
+    val groupCols = colNames.map(resolve)
+    val groupColExprIds = groupCols.map(_.exprId)
+    val aggCols = logicalPlan.output.map { attr =>
+      if (groupColExprIds.contains(attr.exprId)) {
+        attr
+      } else {
+        Alias(First(attr), attr.name)()
+      }
+    }
+    Aggregate(groupCols, aggCols, logicalPlan)
+  }
+
+  /**
+   * Returns a new [[DataFrame]] with duplicate rows removed, considering only
+   * the subset of columns.
+   *
+   * @group dfops
+   */
+  def dropDuplicates(colNames: Array[String]): DataFrame = dropDuplicates(colNames.toSeq)
+
   /**
    * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
    * If no columns are given, this function computes statistics for all numerical columns.
@@ -1089,6 +1122,7 @@ class DataFrame private[sql](
 
   /**
    * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+   * This is an alias for `dropDuplicates`.
    * @group dfops
    */
   override def distinct: DataFrame = Distinct(logicalPlan)
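
Implementation note: the subset overload resolves the requested columns, then builds an Aggregate that groups on them and wraps every other output attribute in `First(...)`, so the result keeps the full schema and the non-subset values come from the first row seen in each group. Below is a rough sketch of the logically equivalent query through the public API; the `df` built here, its column names, and the use of `TestSQLContext` are illustrative assumptions, not part of this commit:

import org.apache.spark.sql.functions.first
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Illustrative data; any DataFrame with a "key" column would do.
val df = TestSQLContext.sparkContext
  .parallelize((1, "a", 10) :: (1, "a", 20) :: (2, "b", 30) :: Nil)
  .toDF("key", "value1", "value2")

// Roughly what df.dropDuplicates(Seq("key")) plans: group on the subset and
// keep the first value seen for each remaining column, leaving the schema intact.
val deduped = df
  .groupBy("key")
  .agg(first("value1").as("value1"), first("value2").as("value2"))

The design point is that the subset only controls which rows count as duplicates; every column still survives into the result.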

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 35 additions & 0 deletions
@@ -457,4 +457,39 @@ class DataFrameSuite extends QueryTest {
     assert(complexData.filter(complexData("m")("1") === 1).count() == 1)
     assert(complexData.filter(complexData("s")("key") === 1).count() == 1)
   }
+
+  test("SPARK-7324 dropDuplicates") {
+    val testData = TestSQLContext.sparkContext.parallelize(
+      (2, 1, 2) :: (1, 1, 1) ::
+      (1, 2, 1) :: (2, 1, 2) ::
+      (2, 2, 2) :: (2, 2, 1) ::
+      (2, 1, 1) :: (1, 1, 2) ::
+      (1, 2, 2) :: (1, 2, 1) :: Nil).toDF("key", "value1", "value2")
+
+    checkAnswer(
+      testData.dropDuplicates(),
+      Seq(Row(2, 1, 2), Row(1, 1, 1), Row(1, 2, 1),
+        Row(2, 2, 2), Row(2, 1, 1), Row(2, 2, 1),
+        Row(1, 1, 2), Row(1, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key", "value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1", "value2")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("key")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value1")),
+      Seq(Row(2, 1, 2), Row(1, 2, 1)))
+
+    checkAnswer(
+      testData.dropDuplicates(Seq("value2")),
+      Seq(Row(2, 1, 2), Row(1, 1, 1)))
+  }
 }
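
Two properties the new API implies but the suite checks only indirectly, sketched here as hedged assertions reusing `testData` from the test above (row counts taken from the expected answers already listed):

// With no subset, dropDuplicates considers every column, so it agrees with
// distinct: both keep the 8 unique rows of the 10-row input.
assert(testData.dropDuplicates().count() == testData.distinct.count())

// The Array overload simply forwards to the Seq one, for Java callers.
assert(testData.dropDuplicates(Array("key")).count() ==
  testData.dropDuplicates(Seq("key")).count())  // 2 rows: keys 1 and 2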
