Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ class DataFrame private[sql](
* @since 1.4.0
*/
def drop(colName: String): DataFrame = {
    // Delegate to the varargs overload with no additional column names.
    // (The diff residue previously left both the old and new delegation
    // lines in place; only the new-signature call is kept.)
    drop(colName, Seq() : _*)
  }

/**
Expand All @@ -1271,10 +1271,11 @@ class DataFrame private[sql](
* @since 1.6.0
*/
@scala.annotation.varargs
def drop(colNames: String*): DataFrame = {
def drop(colName: String, colNames: String*): DataFrame = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this break binary compatibility?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please tell me a bit more about what you meant? I am not familiar with binary compatibility issues, and compatibility between artifacts compiled against different Spark versions looks fine to me.

One thing I am a bit worried is the ambiguity of parameters (in terms of Java interoperability), which I could fix by adding an additional parameter (if this is a problem).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI: I am not sure about every version, but this at least works fine with Java 7 and Scala 2.10.4.

  • Scala
/**
 * Demonstrates that non-varargs overloads delegating to `@varargs` methods
 * remain callable from Java (the interop scenario discussed in this review).
 */
class TestCompatibility {
  // Single-string overload: forwards to the varargs variant with no extras.
  def test(a: String): Unit = test(a, Seq.empty[String]: _*)

  // Prints the fixed argument followed by each trailing argument, one per line.
  @varargs
  def test(a: String, b: String*): Unit = {
    println(a)
    b.foreach(println)
  }

  // Single-int overload: forwards to the pure-varargs variant.
  def test(a: Int): Unit = test(Seq(a): _*)

  // Prints every argument, one per line.
  @varargs
  def test(a: Int*): Unit = a.foreach(println)
}
  • Java
/** Exercises every overloaded {@code TestCompatibility.test} entry point from Java. */
public class Test {
    public static void main(String[] args) {
        final TestCompatibility compat = new TestCompatibility();
        // String overloads: one fixed argument plus zero or more varargs.
        compat.test("a");
        compat.test("a", "b");
        compat.test("a", "b", "c");
        // Int overloads: pure varargs.
        compat.test(1);
        compat.test(1, 2);
        compat.test(1, 2, 3);
    }
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind — the annotation is wrong; it was added in #9862, which did not get merged into 1.6, so there is no concern about changing the method signature.

val resolver = sqlContext.analyzer.resolver
val remainingCols =
schema.filter(f => colNames.forall(n => !resolver(f.name, n))).map(f => Column(f.name))
val remainingCols = schema.filter { f =>
(colName +: colNames).forall(n => !resolver(f.name, n))
}.map(f => Column(f.name))
if (remainingCols.size == this.schema.size) {
this
} else {
Expand All @@ -1291,16 +1292,34 @@ class DataFrame private[sql](
* @since 1.4.1
*/
def drop(col: Column): DataFrame = {
    // Delegate to the Column varargs overload. (The diff residue previously
    // left the removed old body line interleaved here; only the new
    // delegation is kept.)
    drop(Seq(col) : _*)
  }

/**
   * Returns a new [[DataFrame]] with columns dropped.
   * This version of drop accepts Column(s) rather than name(s).
   * This is a no-op if the DataFrame doesn't have column(s)
   * with equivalent expression(s).
   * @group dfops
   * @since 1.6.0
   */
  @scala.annotation.varargs
  def drop(cols: Column*): DataFrame = {
    val resolver = sqlContext.analyzer.resolver
    val attrs = this.logicalPlan.output
    // Resolve unresolved attribute references against this plan's output so
    // that a Column built from a name matches the corresponding attribute.
    val expressions = cols.map {
      case Column(u: UnresolvedAttribute) =>
        queryExecution.analyzed.resolveQuoted(u.name, resolver).getOrElse(u)
      case Column(expr: Expression) => expr
    }
    // Keep every output attribute that is not among the dropped expressions.
    val remainingCols = attrs.filter { attr =>
      !expressions.contains(attr)
    }.map(attr => Column(attr))
    if (remainingCols.size == this.schema.size) {
      // Nothing matched: return this DataFrame unchanged (no-op).
      this
    } else {
      this.select(remainingCols: _*)
    }
  }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(df.schema.map(_.name) === Seq("value"))
}

test("drop column using drop with column references") {
  // Build a single-row frame, then drop two of its three columns via
  // Column references; only column "c" (value 3) should remain.
  val input = Seq((0, 2, 3)).toDF("a", "b", "c")
  val result = input.drop(input("a"), input("b"))
  checkAnswer(result, Row(3))
  assert(result.schema.map(_.name) === Seq("c"))
}

test("drop unknown column (no-op) with column reference") {
val col = Column("random")
val df = testData.drop(col)
Expand Down