Skip to content

Commit c736220

Browse files
viirya authored and marmbrus committed
[SPARK-6635][SQL] DataFrame.withColumn should replace columns with identical column names
JIRA: https://issues.apache.org/jira/browse/SPARK-6635

Author: Liang-Chi Hsieh <[email protected]>

Closes apache#5541 from viirya/replace_with_column and squashes the following commits:

b539c7b [Liang-Chi Hsieh] For comment.
72f35b1 [Liang-Chi Hsieh] DataFrame.withColumn can replace original column with identical column name.
1 parent ce7ddab commit c736220

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,19 @@ class DataFrame private[sql](
747 747
* Returns a new [[DataFrame]] by adding a column.
748 748
* @group dfops
749 749
*/
750-
def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName))
750+
def withColumn(colName: String, col: Column): DataFrame = {
751+
val resolver = sqlContext.analyzer.resolver
752+
val replaced = schema.exists(f => resolver(f.name, colName))
753+
if (replaced) {
754+
val colNames = schema.map { field =>
755+
val name = field.name
756+
if (resolver(name, colName)) col.as(colName) else Column(name)
757+
}
758+
select(colNames :_*)
759+
} else {
760+
select(Column("*"), col.as(colName))
761+
}
762+
}
751 763

752 764
/**
753 765
* Returns a new [[DataFrame]] with a column renamed.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,14 @@ class DataFrameSuite extends QueryTest {
473 473
assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
474 474
}
475 475

476+
test("replace column using withColumn") {
477+
val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
478+
val df3 = df2.withColumn("x", df2("x") + 1)
479+
checkAnswer(
480+
df3.select("x"),
481+
Row(2) :: Row(3) :: Row(4) :: Nil)
482+
}
483+
476 484
test("withColumnRenamed") {
477 485
val df = testData.toDF().withColumn("newCol", col("key") + 1)
478 486
.withColumnRenamed("value", "valueRenamed")

0 commit comments

Comments
 (0)