
Commit 4090d95

Merge remote-tracking branch 'apache/master' into SPARK-5957

2 parents: 4fee9e7 + 68d1faa
File tree: 2 files changed, +178 -0 lines


sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 144 additions & 0 deletions
@@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    */
   def fill(valueMap: Map[String, Any]): DataFrame = fill0(valueMap.toSeq)
 
+  /**
+   * Replaces values matching keys in `replacement` map with the corresponding values.
+   * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
+   * If `col` is "*", then the replacement is applied on all string columns or numeric columns.
+   *
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
+   *   df.replace("height", ImmutableMap.of(1.0, 2.0));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
+   *   df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
+   *   df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
+   * }}}
+   *
+   * @param col name of the column to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
+    replace[T](col, replacement.toMap : Map[T, T])
+  }
+
+  /**
+   * Replaces values matching keys in `replacement` map with the corresponding values.
+   * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
+   *
+   * {{{
+   *   import com.google.common.collect.ImmutableMap;
+   *
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
+   *   df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
+   *   df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
+   * }}}
+   *
+   * @param cols list of columns to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
+    replace(cols.toSeq, replacement.toMap)
+  }
+
+  /**
+   * (Scala-specific) Replaces values matching keys in `replacement` map.
+   * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
+   * If `col` is "*", then the replacement is applied on all string columns or numeric columns.
+   *
+   * {{{
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
+   *   df.replace("height", Map(1.0 -> 2.0))
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
+   *   df.replace("name", Map("UNKNOWN" -> "unnamed"))
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
+   *   df.replace("*", Map("UNKNOWN" -> "unnamed"))
+   * }}}
+   *
+   * @param col name of the column to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
+    if (col == "*") {
+      replace0(df.columns, replacement)
+    } else {
+      replace0(Seq(col), replacement)
+    }
+  }
+
+  /**
+   * (Scala-specific) Replaces values matching keys in `replacement` map.
+   * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
+   *
+   * {{{
+   *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
+   *   df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0))
+   *
+   *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
+   *   df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed"))
+   * }}}
+   *
+   * @param cols list of columns to apply the value replacement
+   * @param replacement value replacement map, as explained above
+   */
+  def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = replace0(cols, replacement)
+
+  private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
+    if (replacement.isEmpty || cols.isEmpty) {
+      return df
+    }
+
+    // replacementMap is either Map[String, String] or Map[Double, Double]
+    val replacementMap: Map[_, _] = replacement.head._2 match {
+      case v: String => replacement
+      case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
+    }
+
+    // targetColumnType is either DoubleType or StringType
+    val targetColumnType = replacement.head._1 match {
+      case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
+      case _: String => StringType
+    }
+
+    val columnEquals = df.sqlContext.analyzer.resolver
+    val projections = df.schema.fields.map { f =>
+      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
+      if (f.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType && shouldReplace) {
+        replaceCol(f, replacementMap)
+      } else if (f.dataType == targetColumnType && shouldReplace) {
+        replaceCol(f, replacementMap)
+      } else {
+        df.col(f.name)
+      }
+    }
+    df.select(projections : _*)
+  }
+
   private def fill0(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
     values.foreach { case (colName, replaceValue) =>
@@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
   private def fillCol[T](col: StructField, replacement: T): Column = {
     coalesce(df.col(col.name), lit(replacement).cast(col.dataType)).as(col.name)
   }
+
+  /**
+   * Returns a [[Column]] expression that replaces values matching keys in `replacementMap` with
+   * the corresponding values, using [[CaseWhen]].
+   *
+   * TODO: This can be optimized to use broadcast join when replacementMap is large.
+   */
+  private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
+    val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
+      df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
+        lit(target).cast(col.dataType).expr :: Nil
+    }.toSeq
+    new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
+  }
+
+  private def convertToDouble(v: Any): Double = v match {
+    case v: Float => v.toDouble
+    case v: Double => v
+    case v: Long => v.toDouble
+    case v: Int => v.toDouble
+    case v => throw new IllegalArgumentException(
+      s"Unsupported value type ${v.getClass.getName} ($v).")
+  }
 }
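
For orientation, here is a minimal usage sketch of the API added above; the DataFrame name `people` and its columns are assumptions for illustration, not part of this commit:

    // Hypothetical DataFrame `people` with a string column "name" and a double column "height".
    val cleaned = people.na.replace("name", Map("UNKNOWN" -> "unnamed"))
    val adjusted = cleaned.na.replace("height", Map(164.3 -> 165.0))
    // Internally, replaceCol rewrites each targeted column as a CaseWhen projection,
    // roughly: CASE WHEN height = 164.3 THEN 165.0 ELSE height END AS height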

sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -154,4 +154,38 @@ class DataFrameNaFunctionsSuite extends QueryTest {
       ))),
       Row("test", null, 1, 2.2))
   }
+
+  test("replace") {
+    val input = createDF()
+
+    // Replace two numeric columns: age and height
+    val out = input.na.replace(Seq("age", "height"), Map(
+      16 -> 61,
+      60 -> 6,
+      164.3 -> 461.3  // Alice is really tall
+    ))
+
+    checkAnswer(
+      out,
+      Row("Bob", 61, 176.5) ::
+        Row("Alice", null, 461.3) ::
+        Row("David", 6, null) ::
+        Row("Amy", null, null) ::
+        Row(null, null, null) :: Nil)
+
+    // Replace only the age column
+    val out1 = input.na.replace("age", Map(
+      16 -> 61,
+      60 -> 6,
+      164.3 -> 461.3  // Alice is really tall
+    ))
+
+    checkAnswer(
+      out1,
+      Row("Bob", 61, 176.5) ::
+        Row("Alice", null, 164.3) ::
+        Row("David", 6, null) ::
+        Row("Amy", null, null) ::
+        Row(null, null, null) :: Nil)
+  }
 }
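
The `createDF()` helper used by this test lives elsewhere in DataFrameNaFunctionsSuite and is not part of this diff. Inferred from the expected rows above, the fixture presumably resembles the sketch below (the schema and the `sqlContext` in scope are assumptions):

    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.types._

    // Presumed fixture: columns name (string), age (int), height (double), with nulls.
    def createDF(): DataFrame = {
      val schema = StructType(Seq(
        StructField("name", StringType, nullable = true),
        StructField("age", IntegerType, nullable = true),
        StructField("height", DoubleType, nullable = true)))
      val rows = Seq(
        Row("Bob", 16, 176.5),
        Row("Alice", null, 164.3),
        Row("David", 60, null),
        Row("Amy", null, null),
        Row(null, null, null))
      sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(rows), schema)
    }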
