
Commit 82f340b

Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites
1 parent 0a16da9 commit 82f340b

File tree: 5 files changed, 49 additions and 91 deletions


mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 3 additions & 3 deletions
@@ -149,7 +149,7 @@ class LogisticRegressionModel private[ml] (
     if (map(probabilityCol) != "") {
       if (map(rawPredictionCol) != "") {
         val raw2prob: Vector => Vector = (rawPreds) => {
-          val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1))
+          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
           Vectors.dense(1.0 - prob1, prob1)
         }
         tmpData = tmpData.select(Star(None),
@@ -171,7 +171,7 @@ class LogisticRegressionModel private[ml] (
         predict.call(map(probabilityCol).attr) as map(predictionCol))
     } else if (map(rawPredictionCol) != "") {
       val predict: Vector => Double = (rawPreds) => {
-        val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1))
+        val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
         if (prob1 > t) 1.0 else 0.0
       }
       tmpData = tmpData.select(Star(None),
@@ -207,7 +207,7 @@ class LogisticRegressionModel private[ml] (

   override protected def predictRaw(features: Vector): Vector = {
     val m = margin(features)
-    Vectors.dense(-m, m)
+    Vectors.dense(0.0, m)
   }

   override protected def copy(): LogisticRegressionModel = {
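
For context, the core fix is an operator-precedence bug in the sigmoid: without parentheses, 1.0 / 1.0 + math.exp(-m) parses as (1.0 / 1.0) + math.exp(-m), i.e. 1 + e^(-m), which is always greater than 1, instead of the intended 1 / (1 + e^(-m)). The sketch below is a minimal standalone Scala illustration of the corrected mapping from a raw prediction to a probability vector; the object and method names are illustrative only, not the actual model API.

import org.apache.spark.mllib.linalg.{Vector, Vectors}

object SigmoidSketch {
  // Correct logistic function: 1 / (1 + e^(-m)).
  // The buggy form, 1.0 / 1.0 + math.exp(-m), evaluates to 1 + e^(-m).
  def sigmoid(m: Double): Double = 1.0 / (1.0 + math.exp(-m))

  // Mirrors raw2prob above: with predictRaw returning Vectors.dense(0.0, m),
  // the class-1 probability is sigmoid(raw(1)) and class 0 gets the complement.
  def rawToProb(raw: Vector): Vector = {
    val prob1 = sigmoid(raw(1))
    Vectors.dense(1.0 - prob1, prob1)
  }

  def main(args: Array[String]): Unit = {
    val raw = Vectors.dense(0.0, 2.0)      // margin m = 2.0
    println(rawToProb(raw))                // roughly [0.1192, 0.8808]
    println(1.0 / 1.0 + math.exp(-2.0))    // buggy expression: about 1.1353, not a probability
  }
}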

mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ public void pipeline() {
       .setStages(new PipelineStage[] {scaler, lr});
     PipelineModel model = pipeline.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
     predictions.collectAsList();
   }
 }

mllib/src/test/java/org/apache/spark/ml/classification/JavaLinearRegressionSuite.java

Lines changed: 5 additions & 26 deletions
@@ -30,12 +30,12 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.ml.LabeledPoint;
 import org.apache.spark.ml.regression.LinearRegression;
 import org.apache.spark.ml.regression.LinearRegressionModel;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite
   .generateLogisticInputAsList;
 import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.api.java.JavaSQLContext;
 import org.apache.spark.sql.api.java.JavaSchemaRDD;
 import org.apache.spark.sql.api.java.Row;
@@ -93,35 +93,14 @@ public void linearRegressionWithSetters() {
       .setMaxIter(10)
       .setRegParam(1.0);
     LinearRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
-    assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
+    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));

     // Call fit() with new params, and check as many params as we can.
     LinearRegressionModel model2 =
       lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
-    assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
-    assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
+    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
     assert(model2.getPredictionCol().equals("thePred"));
   }
-
-  @Test
-  public void linearRegressionPredictorClassifierMethods() {
-    LinearRegression lr = new LinearRegression();
-
-    // fit() vs. train()
-    LinearRegressionModel model1 = lr.fit(dataset);
-    LinearRegressionModel model2 = lr.train(datasetRDD);
-    assert(model1.intercept() == model2.intercept());
-    assert(model1.weights().equals(model2.weights()));
-
-    // transform() vs. predict()
-    model1.transform(dataset).registerTempTable("transformed");
-    JavaSchemaRDD trans = jsql.sql("SELECT prediction FROM transformed");
-    JavaRDD<Double> preds = model1.predict(featuresRDD);
-    for (Tuple2<Row, Double> trans_pred: trans.zip(preds).collect()) {
-      double t = trans_pred._1().getDouble(0);
-      double p = trans_pred._2();
-      assert(t == p);
-    }
-  }
 }

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 32 additions & 56 deletions
@@ -17,8 +17,6 @@

 package org.apache.spark.ml.classification;

-import scala.Tuple2;
-
 import java.io.Serializable;
 import java.lang.Math;
 import java.util.ArrayList;
@@ -34,9 +32,8 @@
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.ml.LabeledPoint;
+import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.sql.Row;


@@ -47,7 +44,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
   private transient DataFrame dataset;

   private transient JavaRDD<LabeledPoint> datasetRDD;
-  private transient JavaRDD<Vector> featuresRDD;
   private double eps = 1e-5;

   @Before
@@ -60,9 +56,6 @@ public void setUp() {
       points.add(new LabeledPoint(lp.label(), lp.features()));
     }
     datasetRDD = jsc.parallelize(points, 2);
-    featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
-      @Override public Vector call(LabeledPoint lp) { return lp.features(); }
-    });
     dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
     dataset.registerTempTable("dataset");
   }
@@ -79,13 +72,13 @@ public void logisticRegressionDefaultParams() {
     assert(lr.getLabelCol().equals("label"));
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
     predictions.collectAsList();
     // Check defaults
     assert(model.getThreshold() == 0.5);
     assert(model.getFeaturesCol().equals("features"));
     assert(model.getPredictionCol().equals("prediction"));
-    assert(model.getScoreCol().equals("score"));
+    assert(model.getProbabilityCol().equals("probability"));
   }

   @Test
@@ -95,17 +88,17 @@ public void logisticRegressionWithSetters() {
       .setMaxIter(10)
       .setRegParam(1.0)
       .setThreshold(0.6)
-      .setScoreCol("probability");
+      .setProbabilityCol("myProbability");
     LogisticRegressionModel model = lr.fit(dataset);
-    assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
-    assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
-    assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6);
+    assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+    assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+    assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
     assert(model.getThreshold() == 0.6);

     // Modify model params, and check that the params worked.
     model.setThreshold(1.0);
     model.transform(dataset).registerTempTable("predAllZero");
-    SchemaRDD predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero");
+    SchemaRDD predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
     for (Row r: predAllZero.collectAsList()) {
       assert(r.getDouble(0) == 0.0);
     }
@@ -117,7 +110,7 @@ public void logisticRegressionWithSetters() {
     predictions.collectAsList();
     */

-    model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb"))
+    model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
     SchemaRDD predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
     boolean foundNonZero = false;
@@ -128,54 +121,37 @@ public void logisticRegressionWithSetters() {

     // Call fit() with new params, and check as many params as we can.
     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
-      lr.threshold().w(0.4), lr.scoreCol().w("theProb"));
-    assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
-    assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
-    assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4);
+      lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+    assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+    assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+    assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
     assert(model2.getThreshold() == 0.4);
-    assert(model2.getScoreCol().equals("theProb"));
+    assert(model2.getProbabilityCol().equals("theProb"));
   }

+  @SuppressWarnings("unchecked")
   @Test
   public void logisticRegressionPredictorClassifierMethods() {
     LogisticRegression lr = new LogisticRegression();
-
-    // fit() vs. train()
-    LogisticRegressionModel model1 = lr.fit(dataset);
-    LogisticRegressionModel model2 = lr.train(datasetRDD);
-    assert(model1.intercept() == model2.intercept());
-    assert(model1.weights().equals(model2.weights()));
-    assert(model1.numClasses() == model2.numClasses());
-    assert(model1.numClasses() == 2);
-
-    // transform() vs. predict()
-    model1.transform(dataset).registerTempTable("transformed");
-    SchemaRDD trans = jsql.sql("SELECT prediction FROM transformed");
-    JavaRDD<Double> preds = model1.predict(featuresRDD);
-    for (scala.Tuple2<Row, Double> trans_pred: trans.toJavaRDD().zip(preds).collect()) {
-      double t = trans_pred._1().getDouble(0);
-      double p = trans_pred._2();
-      assert(t == p);
+    LogisticRegressionModel model = lr.fit(dataset);
+    assert(model.numClasses() == 2);
+
+    model.transform(dataset).registerTempTable("transformed");
+    SchemaRDD trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+    for (Row row: trans1.collect()) {
+      Vector raw = (Vector)row.get(0);
+      Vector prob = (Vector)row.get(1);
+      assert(raw.size() == 2);
+      assert(prob.size() == 2);
+      double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
+      assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
+      assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
     }

-    // Check various types of predictions.
-    JavaRDD<Vector> rawPredictions = model1.predictRaw(featuresRDD);
-    JavaRDD<Vector> probabilities = model1.predictProbabilities(featuresRDD);
-    JavaRDD<Double> predictions = model1.predict(featuresRDD);
-    double threshold = model1.getThreshold();
-    for (Tuple2<Vector, Vector> raw_prob: rawPredictions.zip(probabilities).collect()) {
-      Vector raw = raw_prob._1();
-      Vector prob = raw_prob._2();
-      for (int i = 0; i < raw.size(); ++i) {
-        double r = raw.apply(i);
-        double p = prob.apply(i);
-        double pFromR = 1.0 / (1.0 + Math.exp(-r));
-        assert(Math.abs(r - pFromR) < eps);
-      }
-    }
-    for (Tuple2<Vector, Double> prob_pred: probabilities.zip(predictions).collect()) {
-      Vector prob = prob_pred._1();
-      double pred = prob_pred._2();
+    SchemaRDD trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+    for (Row row: trans2.collect()) {
+      double pred = row.getDouble(0);
+      Vector prob = (Vector)row.get(1);
       double probOfPred = prob.apply((int)pred);
       for (int i = 0; i < prob.size(); ++i) {
         assert(probOfPred >= prob.apply(i));
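
The threshold checks in this suite (setThreshold(1.0) forcing every prediction to 0, then threshold 0.0 allowing non-zero predictions) follow the decision rule used in this commit: predict class 1 when the class-1 probability exceeds the threshold. Below is a minimal standalone Scala sketch of that rule; the helper name is illustrative, not the model's transform code.

// Predict 1.0 when the class-1 probability exceeds the threshold, else 0.0,
// matching "if (prob1 > t) 1.0 else 0.0" in LogisticRegression.scala above.
def predict(prob1: Double, threshold: Double): Double =
  if (prob1 > threshold) 1.0 else 0.0

val probs = Seq(0.1, 0.4, 0.6, 0.9)
assert(probs.map(predict(_, 1.0)).forall(_ == 0.0))  // threshold = 1.0: all predictions are 0
assert(probs.map(predict(_, 0.0)).exists(_ != 0.0))  // threshold = 0.0: some predictions are 1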

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 8 additions & 5 deletions
@@ -82,7 +82,9 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
       .select('prediction, 'myProbability)
       .collect()
       .map { case Row(pred: Double, prob: Vector) => pred }
-    assert(predAllZero.forall(_ === 0.0))
+    assert(predAllZero.forall(_ === 0),
+      s"With threshold=1.0, expected predictions to be all 0, but only" +
+      s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
     // Call transform with params, and check that the params worked.
     val predNotAllZero =
       model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
@@ -115,10 +117,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
     // Compare rawPrediction with probability
     results.select('rawPrediction, 'probability).collect().map {
       case Row(raw: Vector, prob: Vector) =>
-        val raw2prob: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
-        raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) =>
-          assert(r ~== p relTol eps)
-        }
+        assert(raw.size === 2)
+        assert(prob.size === 2)
+        val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+        assert(prob(1) ~== probFromRaw1 relTol eps)
+        assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
     }

     // Compare prediction with probability
