@@ -21,7 +21,6 @@ import org.scalatest.FunSuite
2121
2222import org .apache .spark .mllib .classification .LogisticRegressionSuite .generateLogisticInput
2323import org .apache .spark .mllib .linalg .Vector
24- import org .apache .spark .mllib .regression .LabeledPoint
2524import org .apache .spark .mllib .util .MLlibTestSparkContext
2625import org .apache .spark .mllib .util .TestingUtils ._
2726import org .apache .spark .sql .{DataFrame , Row , SQLContext }
@@ -107,39 +106,26 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
107106 import sqlContext ._
108107 val lr = new LogisticRegression
109108
110- // fit() vs. train()
111- val model1 = lr.fit(dataset)
112- val rdd = dataset.select(' label , ' features ).map { case Row (label : Double , features : Vector ) =>
113- LabeledPoint (label, features)
114- }
115- val featuresRDD = rdd.map(_.features)
116- val model2 = lr.train(rdd)
117- assert(model1.intercept == model2.intercept)
118- assert(model1.weights.equals(model2.weights))
119- assert(model1.numClasses == model2.numClasses)
120- assert(model1.numClasses === 2 )
121-
122- // transform() vs. predict()
123- val trans = model1.transform(dataset).select(' prediction )
124- val preds = model1.predict(rdd.map(_.features))
125- trans.zip(preds).collect().foreach { case (Row (pred1 : Double ), pred2 : Double ) =>
126- assert(pred1 == pred2)
109+ val model = lr.fit(dataset)
110+ assert(model.numClasses === 2 )
111+
112+ val threshold = model.getThreshold
113+ val results = model.transform(dataset)
114+
115+ // Compare rawPrediction with probability
116+ results.select(' rawPrediction , ' probability ).collect().map {
117+ case Row (raw : Vector , prob : Vector ) =>
118+ val raw2prob : (Double => Double ) = (m) => 1.0 / (1.0 + math.exp(- m))
119+ raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) =>
120+ assert(r ~== p relTol eps)
121+ }
127122 }
128123
129- // Check various types of predictions.
130- val rawPredictions = model1.predictRaw(featuresRDD)
131- val probabilities = model1.predictProbabilities(featuresRDD)
132- val predictions = model1.predict(featuresRDD)
133- val threshold = model1.getThreshold
134- rawPredictions.zip(probabilities).collect().foreach { case (raw : Vector , prob : Vector ) =>
135- val computeProbFromRaw : (Double => Double ) = (m) => 1.0 / (1.0 + math.exp(- m))
136- raw.toArray.map(computeProbFromRaw).zip(prob.toArray).foreach { case (r, p) =>
137- assert(r ~== p relTol eps)
138- }
139- }
140- probabilities.zip(predictions).collect().foreach { case (prob : Vector , pred : Double ) =>
141- val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
142- assert(pred == predFromProb)
124+ // Compare prediction with probability
125+ results.select(' prediction , ' probability ).collect().map {
126+ case Row (pred : Double , prob : Vector ) =>
127+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
128+ assert(pred == predFromProb)
143129 }
144130 }
145131}
0 commit comments