
Commit f2e7a69

imputer doc and example
1 parent ee91a0d commit f2e7a69

File tree

4 files changed: +178, -2 lines changed

docs/ml-features.md

Lines changed: 56 additions & 1 deletion
@@ -1284,6 +1284,61 @@ for more details on the API.

</div>

## Imputer

Imputation transformer for completing missing values in the dataset, using either the mean or the
median of the columns in which the missing values are located. The input columns should be of
`DoubleType` or `FloatType`. Currently `Imputer` does not support categorical features, and may
produce incorrect values for columns containing categorical features. All null values in the input
columns are treated as missing, and so are also imputed.

**Examples**

Suppose that we have a DataFrame with the columns `a` and `b`:

~~~
     a      |      b
------------|-----------
     1.0    | Double.NaN
     2.0    | Double.NaN
 Double.NaN |     3.0
     4.0    |     4.0
     5.0    |     5.0
~~~

By default, `Imputer` replaces every `Double.NaN` (the missing value) with the mean (the default
strategy) of the other values in the corresponding column. In our example, the surrogates for `a`
and `b` are 3.0 and 4.0 respectively. After transformation, the output columns no longer contain
any missing values.

~~~
     a      |      b     | out_a | out_b
------------|------------|-------|-------
     1.0    | Double.NaN |  1.0  |  4.0
     2.0    | Double.NaN |  2.0  |  4.0
 Double.NaN |     3.0    |  3.0  |  3.0
     4.0    |     4.0    |  4.0  |  4.0
     5.0    |     5.0    |  5.0  |  5.0
~~~

<div class="codetabs">
<div data-lang="scala" markdown="1">

Refer to the [Imputer Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Imputer)
for more details on the API.

{% include_example scala/org/apache/spark/examples/ml/ImputerExample.scala %}
</div>

<div data-lang="java" markdown="1">

Refer to the [Imputer Java docs](api/java/org/apache/spark/ml/feature/Imputer.html)
for more details on the API.

{% include_example java/org/apache/spark/examples/ml/JavaImputerExample.java %}
</div>
</div>

# Feature Selectors

## VectorSlicer
@@ -1625,4 +1680,4 @@ for more details on the API.
 
 {% include_example python/ml/min_hash_lsh_example.py %}
 </div>
-</div>
+</div>
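
Aside (not part of this commit): the surrogate values quoted in the new doc text can be sanity-checked by hand, since the "mean" strategy simply averages the non-missing entries of each column: (1.0 + 2.0 + 4.0 + 5.0) / 4 = 3.0 for `a` and (3.0 + 4.0 + 5.0) / 3 = 4.0 for `b`. A minimal, self-contained Scala sketch of that arithmetic (plain Scala, no Spark; names are illustrative only):

~~~
// Not part of the commit: verifies the surrogate values used in the ml-features.md example.
// NaN entries are excluded before taking the mean, matching the doc's description of "mean".
object SurrogateCheck {
  def meanIgnoringNaN(xs: Seq[Double]): Double = {
    val present = xs.filterNot(_.isNaN)
    present.sum / present.size
  }

  def main(args: Array[String]): Unit = {
    println(meanIgnoringNaN(Seq(1.0, 2.0, Double.NaN, 4.0, 5.0)))        // 3.0 -> surrogate for "a"
    println(meanIgnoringNaN(Seq(Double.NaN, Double.NaN, 3.0, 4.0, 5.0))) // 4.0 -> surrogate for "b"
  }
}
~~~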
examples/src/main/java/org/apache/spark/examples/ml/JavaImputerExample.java

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

// $example on$
import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.feature.Imputer;
import org.apache.spark.ml.feature.ImputerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.*;
// $example off$

import static org.apache.spark.sql.types.DataTypes.*;

public class JavaImputerExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaImputerExample")
      .getOrCreate();

    // $example on$
    List<Row> data = Arrays.asList(
      RowFactory.create(1.0, Double.NaN),
      RowFactory.create(2.0, Double.NaN),
      RowFactory.create(Double.NaN, 3.0),
      RowFactory.create(4.0, 4.0),
      RowFactory.create(5.0, 5.0)
    );
    StructType schema = new StructType(new StructField[]{
      createStructField("a", DoubleType, false),
      createStructField("b", DoubleType, false)
    });
    Dataset<Row> df = spark.createDataFrame(data, schema);

    // Replace missing values (Double.NaN) in columns "a" and "b" with the column mean
    Imputer imputer = new Imputer()
      .setStrategy("mean")
      .setInputCols(new String[]{"a", "b"})
      .setOutputCols(new String[]{"out_a", "out_b"});

    ImputerModel model = imputer.fit(df);
    model.transform(df).show();
    // $example off$

    spark.stop();
  }
}
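
Aside (not part of this commit): the doc section above also mentions the "median" strategy. Below is a hedged Scala sketch of how the same toy data might be imputed with it, additionally assuming the `setMissingValue` setter on `Imputer` for treating a sentinel such as -1.0 (rather than `Double.NaN`) as missing; object and app names are illustrative only.

~~~
// A sketch only, not part of this commit. Assumes Imputer's "median" strategy and a
// setMissingValue parameter; here -1.0 is treated as the missing-value sentinel.
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

object MedianImputerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("MedianImputerSketch").getOrCreate()

    val df = spark.createDataFrame(Seq(
      (1.0, -1.0),
      (2.0, -1.0),
      (-1.0, 3.0),
      (4.0, 4.0),
      (5.0, 5.0)
    )).toDF("a", "b")

    val imputer = new Imputer()
      .setStrategy("median")         // use the (approximate) median instead of the mean
      .setMissingValue(-1.0)         // treat -1.0 as missing rather than Double.NaN (assumed setter)
      .setInputCols(Array("a", "b"))
      .setOutputCols(Array("out_a", "out_b"))

    imputer.fit(df).transform(df).show()

    spark.stop()
  }
}
~~~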
examples/src/main/scala/org/apache/spark/examples/ml/ImputerExample.scala

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.feature.Imputer
// $example off$
import org.apache.spark.sql.SparkSession

object ImputerExample {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder
      .appName("ImputerExample")
      .getOrCreate()

    // $example on$
    val df = spark.createDataFrame(Seq(
      (1.0, Double.NaN),
      (2.0, Double.NaN),
      (Double.NaN, 3.0),
      (4.0, 4.0),
      (5.0, 5.0)
    )).toDF("a", "b")

    // Replace missing values (Double.NaN) in columns "a" and "b" with the column mean
    val imputer = new Imputer()
      .setStrategy("mean")
      .setInputCols(Array("a", "b"))
      .setOutputCols(Array("out_a", "out_b"))

    val model = imputer.fit(df)
    model.transform(df).show()
    // $example off$

    spark.stop()
  }
}
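
A possible follow-up to the Scala example (not part of this commit): after fitting, the computed surrogates can be inspected directly, assuming `ImputerModel` exposes them via a `surrogateDF` member. The snippet continues from the `imputer` and `df` values defined in the example above.

~~~
// Continuation of ImputerExample above (a sketch, not part of the commit).
// Assumes ImputerModel exposes the fitted per-column surrogates as a small DataFrame.
val model = imputer.fit(df)
model.surrogateDF.show()   // expected per the doc example: 3.0 for "a", 4.0 for "b"
model.transform(df).select("a", "b", "out_a", "out_b").show()
~~~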

mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala

Lines changed: 3 additions & 1 deletion
@@ -35,7 +35,7 @@ import org.apache.spark.sql.types._
 private[feature] trait ImputerParams extends Params with HasInputCols {
 
   /**
-   * The imputation strategy.
+   * The imputation strategy. Currently only "mean" and "median" are supported.
    * If "mean", then replace missing values using the mean value of the feature.
    * If "median", then replace missing values using the approximate median value of the feature.
    * Default: mean
@@ -75,6 +75,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    require(get(inputCols).isDefined, "Input cols must be defined first.")
+    require(get(outputCols).isDefined, "Output cols must be defined first.")
     require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" +
       s" duplicates: (${$(inputCols).mkString(", ")})")
     require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" +
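
The two new `require` calls make the failure mode explicit when the columns have not been set. A short Scala sketch of the behavior they introduce (assumption: the `require` failure surfaces as an `IllegalArgumentException` during `fit`; object and app names are illustrative):

~~~
// A sketch of the behavior added by the new require checks (not part of the commit).
// Fitting an Imputer whose input/output columns were never set should now fail fast
// with the new message instead of a generic missing-param error (assumed behavior).
import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession
import scala.util.Try

object ImputerValidationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("ImputerValidationSketch").getOrCreate()
    val df = spark.createDataFrame(Seq((1.0, 2.0), (3.0, 4.0))).toDF("a", "b")

    val unconfigured = new Imputer()   // neither setInputCols nor setOutputCols called
    val result = Try(unconfigured.fit(df))
    println(result)   // expected: Failure(...) carrying "Input cols must be defined first."

    spark.stop()
  }
}
~~~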
