1717
1818package org .apache .spark .ml .classification ;
1919
20- import scala .Tuple2 ;
21-
2220import java .io .Serializable ;
2321import java .lang .Math ;
2422import java .util .ArrayList ;
3432import org .apache .spark .sql .DataFrame ;
3533import org .apache .spark .sql .SQLContext ;
3634import static org .apache .spark .mllib .classification .LogisticRegressionSuite .generateLogisticInputAsList ;
37- import org .apache .spark .api .java .function .Function ;
3835import org .apache .spark .mllib .linalg .Vector ;
39- import org .apache .spark .ml .LabeledPoint ;
36+ import org .apache .spark .mllib . regression .LabeledPoint ;
4037import org .apache .spark .sql .Row ;
4138
4239
@@ -47,7 +44,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
4744 private transient DataFrame dataset ;
4845
4946 private transient JavaRDD <LabeledPoint > datasetRDD ;
50- private transient JavaRDD <Vector > featuresRDD ;
5147 private double eps = 1e-5 ;
5248
5349 @ Before
@@ -60,9 +56,6 @@ public void setUp() {
6056 points .add (new LabeledPoint (lp .label (), lp .features ()));
6157 }
6258 datasetRDD = jsc .parallelize (points , 2 );
63- featuresRDD = datasetRDD .map (new Function <LabeledPoint , Vector >() {
64- @ Override public Vector call (LabeledPoint lp ) { return lp .features (); }
65- });
6659 dataset = jsql .applySchema (datasetRDD , LabeledPoint .class );
6760 dataset .registerTempTable ("dataset" );
6861 }
@@ -79,13 +72,13 @@ public void logisticRegressionDefaultParams() {
7972 assert (lr .getLabelCol ().equals ("label" ));
8073 LogisticRegressionModel model = lr .fit (dataset );
8174 model .transform (dataset ).registerTempTable ("prediction" );
82- DataFrame predictions = jsql .sql ("SELECT label, score , prediction FROM prediction" );
75+ DataFrame predictions = jsql .sql ("SELECT label, probability , prediction FROM prediction" );
8376 predictions .collectAsList ();
8477 // Check defaults
8578 assert (model .getThreshold () == 0.5 );
8679 assert (model .getFeaturesCol ().equals ("features" ));
8780 assert (model .getPredictionCol ().equals ("prediction" ));
88- assert (model .getScoreCol ().equals ("score " ));
81+ assert (model .getProbabilityCol ().equals ("probability " ));
8982 }
9083
9184 @ Test
@@ -95,17 +88,17 @@ public void logisticRegressionWithSetters() {
9588 .setMaxIter (10 )
9689 .setRegParam (1.0 )
9790 .setThreshold (0.6 )
98- .setScoreCol ( "probability " );
91+ .setProbabilityCol ( "myProbability " );
9992 LogisticRegressionModel model = lr .fit (dataset );
100- assert (model .fittingParamMap ().get (lr .maxIter ()). get ( ) == 10 );
101- assert (model .fittingParamMap ().get (lr .regParam ()).get () == 1.0 );
102- assert (model .fittingParamMap ().get (lr .threshold ()).get () == 0.6 );
93+ assert (model .fittingParamMap ().apply (lr .maxIter ()) == 10 );
94+ assert (model .fittingParamMap ().apply (lr .regParam ()).equals ( 1.0 ) );
95+ assert (model .fittingParamMap ().apply (lr .threshold ()).equals ( 0.6 ) );
10396 assert (model .getThreshold () == 0.6 );
10497
10598 // Modify model params, and check that the params worked.
10699 model .setThreshold (1.0 );
107100 model .transform (dataset ).registerTempTable ("predAllZero" );
108- SchemaRDD predAllZero = jsql .sql ("SELECT prediction, probability FROM predAllZero" );
101+ SchemaRDD predAllZero = jsql .sql ("SELECT prediction, myProbability FROM predAllZero" );
109102 for (Row r : predAllZero .collectAsList ()) {
110103 assert (r .getDouble (0 ) == 0.0 );
111104 }
@@ -117,7 +110,7 @@ public void logisticRegressionWithSetters() {
117110 predictions.collectAsList();
118111 */
119112
120- model .transform (dataset , model .threshold ().w (0.0 ), model .scoreCol ().w ("myProb" ))
113+ model .transform (dataset , model .threshold ().w (0.0 ), model .probabilityCol ().w ("myProb" ))
121114 .registerTempTable ("predNotAllZero" );
122115 SchemaRDD predNotAllZero = jsql .sql ("SELECT prediction, myProb FROM predNotAllZero" );
123116 boolean foundNonZero = false ;
@@ -128,54 +121,37 @@ public void logisticRegressionWithSetters() {
128121
129122 // Call fit() with new params, and check as many params as we can.
130123 LogisticRegressionModel model2 = lr .fit (dataset , lr .maxIter ().w (5 ), lr .regParam ().w (0.1 ),
131- lr .threshold ().w (0.4 ), lr .scoreCol ().w ("theProb" ));
132- assert (model2 .fittingParamMap ().get (lr .maxIter ()). get ( ) == 5 );
133- assert (model2 .fittingParamMap ().get (lr .regParam ()).get () == 0.1 );
134- assert (model2 .fittingParamMap ().get (lr .threshold ()).get () == 0.4 );
124+ lr .threshold ().w (0.4 ), lr .probabilityCol ().w ("theProb" ));
125+ assert (model2 .fittingParamMap ().apply (lr .maxIter ()) == 5 );
126+ assert (model2 .fittingParamMap ().apply (lr .regParam ()).equals ( 0.1 ) );
127+ assert (model2 .fittingParamMap ().apply (lr .threshold ()).equals ( 0.4 ) );
135128 assert (model2 .getThreshold () == 0.4 );
136- assert (model2 .getScoreCol ().equals ("theProb" ));
129+ assert (model2 .getProbabilityCol ().equals ("theProb" ));
137130 }
138131
132+ @ SuppressWarnings ("unchecked" )
139133 @ Test
140134 public void logisticRegressionPredictorClassifierMethods () {
141135 LogisticRegression lr = new LogisticRegression ();
142-
143- // fit() vs. train()
144- LogisticRegressionModel model1 = lr .fit (dataset );
145- LogisticRegressionModel model2 = lr .train (datasetRDD );
146- assert (model1 .intercept () == model2 .intercept ());
147- assert (model1 .weights ().equals (model2 .weights ()));
148- assert (model1 .numClasses () == model2 .numClasses ());
149- assert (model1 .numClasses () == 2 );
150-
151- // transform() vs. predict()
152- model1 .transform (dataset ).registerTempTable ("transformed" );
153- SchemaRDD trans = jsql .sql ("SELECT prediction FROM transformed" );
154- JavaRDD <Double > preds = model1 .predict (featuresRDD );
155- for (scala .Tuple2 <Row , Double > trans_pred : trans .toJavaRDD ().zip (preds ).collect ()) {
156- double t = trans_pred ._1 ().getDouble (0 );
157- double p = trans_pred ._2 ();
158- assert (t == p );
136+ LogisticRegressionModel model = lr .fit (dataset );
137+ assert (model .numClasses () == 2 );
138+
139+ model .transform (dataset ).registerTempTable ("transformed" );
140+ SchemaRDD trans1 = jsql .sql ("SELECT rawPrediction, probability FROM transformed" );
141+ for (Row row : trans1 .collect ()) {
142+ Vector raw = (Vector )row .get (0 );
143+ Vector prob = (Vector )row .get (1 );
144+ assert (raw .size () == 2 );
145+ assert (prob .size () == 2 );
146+ double probFromRaw1 = 1.0 / (1.0 + Math .exp (-raw .apply (1 )));
147+ assert (Math .abs (prob .apply (1 ) - probFromRaw1 ) < eps );
148+ assert (Math .abs (prob .apply (0 ) - (1.0 - probFromRaw1 )) < eps );
159149 }
160150
161- // Check various types of predictions.
162- JavaRDD <Vector > rawPredictions = model1 .predictRaw (featuresRDD );
163- JavaRDD <Vector > probabilities = model1 .predictProbabilities (featuresRDD );
164- JavaRDD <Double > predictions = model1 .predict (featuresRDD );
165- double threshold = model1 .getThreshold ();
166- for (Tuple2 <Vector , Vector > raw_prob : rawPredictions .zip (probabilities ).collect ()) {
167- Vector raw = raw_prob ._1 ();
168- Vector prob = raw_prob ._2 ();
169- for (int i = 0 ; i < raw .size (); ++i ) {
170- double r = raw .apply (i );
171- double p = prob .apply (i );
172- double pFromR = 1.0 / (1.0 + Math .exp (-r ));
173- assert (Math .abs (r - pFromR ) < eps );
174- }
175- }
176- for (Tuple2 <Vector , Double > prob_pred : probabilities .zip (predictions ).collect ()) {
177- Vector prob = prob_pred ._1 ();
178- double pred = prob_pred ._2 ();
151+ SchemaRDD trans2 = jsql .sql ("SELECT prediction, probability FROM transformed" );
152+ for (Row row : trans2 .collect ()) {
153+ double pred = row .getDouble (0 );
154+ Vector prob = (Vector )row .get (1 );
179155 double probOfPred = prob .apply ((int )pred );
180156 for (int i = 0 ; i < prob .size (); ++i ) {
181157 assert (probOfPred >= prob .apply (i ));
0 commit comments