@@ -192,6 +192,127 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
192192 */
193193 def fill (valueMap : Map [String , Any ]): DataFrame = fill0(valueMap.toSeq)
194194
/**
 * Replaces every value that appears as a key in `replacement` with the value mapped to it.
 * Keys and values of `replacement` must share one type, which can only be double or string.
 * If `col` is "*", then the replacement is applied on all string columns or numeric columns.
 *
 * This overload accepts a `java.util.Map`; it converts the map to a Scala `Map` and forwards
 * to the Scala-specific overload that takes a single column name.
 *
 * {{{
 *   import com.google.common.collect.ImmutableMap;
 *
 *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
 *   df.replace("height", ImmutableMap.of(1.0, 2.0));
 *
 *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
 *   df.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
 *
 *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
 *   df.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
 * }}}
 *
 * @param col name of the column to apply the value replacement
 * @param replacement value replacement map, as explained above
 */
def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
  // The explicit Scala Map type both triggers the Java-to-Scala conversion and selects the
  // Scala-specific overload instead of recursing into this one.
  val scalaReplacement: Map[T, T] = replacement.toMap
  replace[T](col, scalaReplacement)
}
219+
/**
 * Replaces every value that appears as a key in `replacement` with the value mapped to it,
 * in each of the listed columns.
 * Keys and values of `replacement` must share one type, which can only be double or string.
 *
 * This overload accepts a Java array and a `java.util.Map`; both are converted to their Scala
 * counterparts and forwarded to the Scala-specific overload that takes a `Seq` of names.
 *
 * {{{
 *   import com.google.common.collect.ImmutableMap;
 *
 *   // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
 *   df.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0, 2.0));
 *
 *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
 *   df.replace(new String[] {"firstname", "lastname"}, ImmutableMap.of("UNKNOWN", "unnamed"));
 * }}}
 *
 * @param cols list of columns to apply the value replacement
 * @param replacement value replacement map, as explained above
 */
def replace[T](cols: Array[String], replacement: java.util.Map[T, T]): DataFrame = {
  val columnNames: Seq[String] = cols.toSeq
  val scalaReplacement: Map[T, T] = replacement.toMap
  replace[T](columnNames, scalaReplacement)
}
240+
/**
 * (Scala-specific) Replaces values matching keys in `replacement` map.
 * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
 * If `col` is "*", then the replacement is applied on all string columns or numeric columns.
 *
 * {{{
 *   // Replaces all occurrences of 1.0 with 2.0 in column "height".
 *   df.replace("height", Map(1.0 -> 2.0))
 *
 *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "name".
 *   df.replace("name", Map("UNKNOWN" -> "unnamed"))
 *
 *   // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string columns.
 *   df.replace("*", Map("UNKNOWN" -> "unnamed"))
 * }}}
 *
 * @param col name of the column to apply the value replacement, or "*" for every column
 * @param replacement value replacement map, as explained above
 */
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
  // The wildcard is exactly "*" (no surrounding whitespace), as documented above; it expands
  // to every column of the DataFrame and replace0 then filters by column type.
  val targetColumns: Seq[String] = if (col == "*") df.columns else Seq(col)
  replace0(targetColumns, replacement)
}
267+
268+ /**
269+ * (Scala-specific) Replaces values matching keys in `replacement` map.
270+ * Key and value of `replacement` map must have the same type, and can only be doubles or strings.
271+ *
272+ * {{{
273+ * // Replaces all occurrences of 1.0 with 2.0 in column "height" and "weight".
274+ * df.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
275+ *
276+ * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column "firstname" and "lastname".
277+ * df.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" -> "unnamed");
278+ * }}}
279+ *
280+ * @param cols list of columns to apply the value replacement
281+ * @param replacement value replacement map, as explained above
282+ */
283+ def replace [T ](cols : Seq [String ], replacement : Map [T , T ]): DataFrame = replace0(cols, replacement)
284+
/**
 * Shared implementation behind the public `replace` overloads.
 *
 * Builds one projection per field of the schema: a field that is named in `cols` and whose
 * type is compatible with the replacement map is rewritten through `replaceCol`; every other
 * field is passed through unchanged.
 *
 * @param cols names of the columns to consider, matched with the analyzer's resolver
 * @param replacement map from values to be replaced to their substitutes
 * @return `df` itself when `cols` or `replacement` is empty, otherwise the new projection
 * @throws IllegalArgumentException if the first key is not a String, Double, Float, Integer
 *         or Long, or if the numeric conversion meets an unsupported value
 */
private def replace0[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
  if (replacement.isEmpty || cols.isEmpty) {
    // Nothing to replace, or nowhere to replace it.
    df
  } else {
    val (firstKey, firstValue) = replacement.head

    // replacementMap is either Map[String, String] or Map[Double, Double]: a string map is
    // used as is, anything else is normalised to doubles (which rejects unsupported types).
    val replacementMap: Map[_, _] = firstValue match {
      case _: String => replacement
      case _ => replacement.map { case (k, v) => (convertToDouble(k), convertToDouble(v)) }
    }

    // targetColumnType is either DoubleType or StringType. An unsupported key type is
    // reported as an IllegalArgumentException rather than an opaque MatchError.
    val targetColumnType = firstKey match {
      case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long => DoubleType
      case _: String => StringType
      case other =>
        val typeName = Option(other).map(_.getClass.getName).getOrElse("null")
        throw new IllegalArgumentException(s"Unsupported key type $typeName ($other).")
    }

    val columnEquals = df.sqlContext.analyzer.resolver
    val projections = df.schema.fields.map { f =>
      val shouldReplace = cols.exists(colName => columnEquals(colName, f.name))
      // A numeric replacement applies to every numeric column; otherwise the column type
      // has to equal the target type exactly.
      val typeMatches =
        (targetColumnType == DoubleType && f.dataType.isInstanceOf[NumericType]) ||
          f.dataType == targetColumnType
      if (shouldReplace && typeMatches) replaceCol(f, replacementMap) else df.col(f.name)
    }
    df.select(projections: _*)
  }
}
315+
195316 private def fill0 (values : Seq [(String , Any )]): DataFrame = {
196317 // Error handling
197318 values.foreach { case (colName, replaceValue) =>
@@ -228,4 +349,27 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
private def fillCol[T](col: StructField, replacement: T): Column = {
  // Current value of the column, looked up by name on the wrapped DataFrame.
  val current = df.col(col.name)
  // The replacement literal, cast to the column's own data type.
  val fallback = lit(replacement).cast(col.dataType)
  // coalesce picks `current` unless it is null; the alias keeps the original column name.
  coalesce(current, fallback).as(col.name)
}
352+
/**
 * Builds a [[Column]] expression that swaps each value equal to a key of `replacementMap` for
 * the value mapped to it, expressed as a single [[CaseWhen]]. Keys and values are both cast to
 * the column's data type, and the column itself is the fall-through branch, so values without
 * a matching key are left as they are. The result is aliased to the original column name.
 *
 * TODO: This can be optimized to use broadcast join when replacementMap is large.
 */
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
  val column = df.col(col.name)
  // CaseWhen takes a flat sequence: condition, result, condition, result, ..., else.
  val whenThenPairs: Seq[Expression] = replacementMap.toSeq.flatMap { case (source, target) =>
    val condition = column.equalTo(lit(source).cast(col.dataType)).expr
    val result = lit(target).cast(col.dataType).expr
    Seq(condition, result)
  }
  new Column(CaseWhen(whenThenPairs :+ column.expr)).as(col.name)
}
366+
/**
 * Widens a supported numeric value to `Double`.
 *
 * @param v the value to convert; must be a Float, Double, Long or Int
 * @return `v` as a `Double`
 * @throws IllegalArgumentException if `v` is null or of any other type
 */
private def convertToDouble(v: Any): Double = v match {
  // null matches none of the type patterns; reject it here so the catch-all below never
  // dereferences it (which would surface as a NullPointerException).
  case null => throw new IllegalArgumentException("Unsupported value: null.")
  case v: Float => v.toDouble
  case v: Double => v
  case v: Long => v.toDouble
  case v: Int => v.toDouble
  case v => throw new IllegalArgumentException(
    s"Unsupported value type ${v.getClass.getName} ($v).")
}
231375}
0 commit comments