Skip to content

Commit 0a16da9

Browse files
committed
Fixed Linear/Logistic RegressionSuites
1 parent c3c8da5 commit 0a16da9

File tree

2 files changed

+18
-55
lines changed

2 files changed

+18
-55
lines changed

mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala

Lines changed: 18 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,6 @@ import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
2323
import org.apache.spark.mllib.linalg.Vector
24-
import org.apache.spark.mllib.regression.LabeledPoint
2524
import org.apache.spark.mllib.util.MLlibTestSparkContext
2625
import org.apache.spark.mllib.util.TestingUtils._
2726
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -107,39 +106,26 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
107106
import sqlContext._
108107
val lr = new LogisticRegression
109108

110-
// fit() vs. train()
111-
val model1 = lr.fit(dataset)
112-
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
113-
LabeledPoint(label, features)
114-
}
115-
val featuresRDD = rdd.map(_.features)
116-
val model2 = lr.train(rdd)
117-
assert(model1.intercept == model2.intercept)
118-
assert(model1.weights.equals(model2.weights))
119-
assert(model1.numClasses == model2.numClasses)
120-
assert(model1.numClasses === 2)
121-
122-
// transform() vs. predict()
123-
val trans = model1.transform(dataset).select('prediction)
124-
val preds = model1.predict(rdd.map(_.features))
125-
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
126-
assert(pred1 == pred2)
109+
val model = lr.fit(dataset)
110+
assert(model.numClasses === 2)
111+
112+
val threshold = model.getThreshold
113+
val results = model.transform(dataset)
114+
115+
// Compare rawPrediction with probability
116+
results.select('rawPrediction, 'probability).collect().map {
117+
case Row(raw: Vector, prob: Vector) =>
118+
val raw2prob: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
119+
raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) =>
120+
assert(r ~== p relTol eps)
121+
}
127122
}
128123

129-
// Check various types of predictions.
130-
val rawPredictions = model1.predictRaw(featuresRDD)
131-
val probabilities = model1.predictProbabilities(featuresRDD)
132-
val predictions = model1.predict(featuresRDD)
133-
val threshold = model1.getThreshold
134-
rawPredictions.zip(probabilities).collect().foreach { case (raw: Vector, prob: Vector) =>
135-
val computeProbFromRaw: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
136-
raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
137-
assert(r ~== p relTol eps)
138-
}
139-
}
140-
probabilities.zip(predictions).collect().foreach { case (prob: Vector, pred: Double) =>
141-
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
142-
assert(pred == predFromProb)
124+
// Compare prediction with probability
125+
results.select('prediction, 'probability).collect().map {
126+
case Row(pred: Double, prob: Vector) =>
127+
val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
128+
assert(pred == predFromProb)
143129
}
144130
}
145131
}

mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala

Lines changed: 0 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -69,27 +69,4 @@ class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
6969
assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
7070
assert(model2.getPredictionCol == "thePred")
7171
}
72-
73-
test("linear regression: Predictor, Regressor methods") {
74-
val sqlContext = this.sqlContext
75-
import sqlContext._
76-
val lr = new LinearRegression
77-
78-
// fit() vs. train()
79-
val model1 = lr.fit(dataset)
80-
val rdd = dataset.select('label, 'features).map { case Row(label: Double, features: Vector) =>
81-
LabeledPoint(label, features)
82-
}
83-
val features = rdd.map(_.features)
84-
val model2 = lr.train(rdd)
85-
assert(model1.intercept == model2.intercept)
86-
assert(model1.weights.equals(model2.weights))
87-
88-
// transform() vs. predict()
89-
val trans = model1.transform(dataset).select('prediction)
90-
val preds = model1.predict(rdd.map(_.features))
91-
trans.zip(preds).collect().foreach { case (Row(pred1: Double), pred2: Double) =>
92-
assert(pred1 == pred2)
93-
}
94-
}
9572
}

0 commit comments

Comments
 (0)