package org.apache.spark.ml.classification;

+ import scala.Tuple2;
+
import java.io.Serializable;
+ import java.lang.Math;
+ import java.util.ArrayList;
import java.util.List;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

+ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
- import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+ import org.apache.spark.api.java.function.Function;
+ import org.apache.spark.mllib.linalg.Vector;
+ import org.apache.spark.ml.LabeledPoint;
+ import org.apache.spark.sql.Row;
+

public class JavaLogisticRegressionSuite implements Serializable {

  private transient JavaSparkContext jsc;
  private transient SQLContext jsql;
  private transient DataFrame dataset;

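+   // RDD versions of the dataset for the train()/predict() tests below, plus a tolerance
+   // for floating-point comparisons.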
+   private transient JavaRDD<LabeledPoint> datasetRDD;
+   private transient JavaRDD<Vector> featuresRDD;
+   private double eps = 1e-5;
+
  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
    jsql = new SQLContext(jsc);
-     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
-     dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
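+     // Convert the generated mllib LabeledPoints into ml.LabeledPoints for the new API.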
+     List<LabeledPoint> points = new ArrayList<LabeledPoint>();
+     for (org.apache.spark.mllib.regression.LabeledPoint lp :
+         generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
+       points.add(new LabeledPoint(lp.label(), lp.features()));
+     }
+     datasetRDD = jsc.parallelize(points, 2);
+     featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
+       @Override public Vector call(LabeledPoint lp) { return lp.features(); }
+     });
+     dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+     dataset.registerTempTable("dataset");
  }

  @After
@@ -51,29 +74,112 @@ public void tearDown() {
  }

  @Test
-   public void logisticRegression() {
+   public void logisticRegressionDefaultParams() {
    LogisticRegression lr = new LogisticRegression();
+     assert(lr.getLabelCol().equals("label"));
    LogisticRegressionModel model = lr.fit(dataset);
    model.transform(dataset).registerTempTable("prediction");
    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
    predictions.collectAsList();
+     // Check defaults
+     assert(model.getThreshold() == 0.5);
+     assert(model.getFeaturesCol().equals("features"));
+     assert(model.getPredictionCol().equals("prediction"));
+     assert(model.getScoreCol().equals("score"));
  }

  @Test
  public void logisticRegressionWithSetters() {
+     // Set params, train, and check as many params as we can.
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
-       .setRegParam(1.0);
+       .setRegParam(1.0)
+       .setThreshold(0.6)
+       .setScoreCol("probability");
    LogisticRegressionModel model = lr.fit(dataset);
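+     // fittingParamMap should record the param values that were used for training.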
+     assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
+     assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
+     assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6);
+     assert(model.getThreshold() == 0.6);
+
+     // Modify model params, and check that the params worked.
+     model.setThreshold(1.0);
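+     // With threshold = 1.0, no probability can exceed the threshold, so all predictions should be 0.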
+     model.transform(dataset).registerTempTable("predAllZero");
+     DataFrame predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero");
+     for (Row r : predAllZero.collectAsList()) {
+       assert(r.getDouble(0) == 0.0);
+     }
+     // Call transform with params, and check that the params worked.
+     /* TODO: USE THIS
    model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
-       .registerTempTable("prediction");
+       .registerTempTable("prediction");
    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
    predictions.collectAsList();
+     */
+
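+     // With threshold = 0.0, at least some predictions should be nonzero (class 1).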
+     model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb"))
+       .registerTempTable("predNotAllZero");
+     DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+     boolean foundNonZero = false;
+     for (Row r : predNotAllZero.collectAsList()) {
+       if (r.getDouble(0) != 0.0) foundNonZero = true;
+     }
+     assert(foundNonZero);
+
+     // Call fit() with new params, and check as many params as we can.
+     LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
+       lr.threshold().w(0.4), lr.scoreCol().w("theProb"));
+     assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
+     assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
+     assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4);
+     assert(model2.getThreshold() == 0.4);
+     assert(model2.getScoreCol().equals("theProb"));
  }

  @Test
-   public void logisticRegressionFitWithVarargs() {
+   public void logisticRegressionPredictorClassifierMethods() {
    LogisticRegression lr = new LogisticRegression();
-     lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+
+     // fit() vs. train()
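+     // fit() on the DataFrame and train() on the underlying RDD should produce equivalent models.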
+     LogisticRegressionModel model1 = lr.fit(dataset);
+     LogisticRegressionModel model2 = lr.train(datasetRDD);
+     assert(model1.intercept() == model2.intercept());
+     assert(model1.weights().equals(model2.weights()));
+     assert(model1.numClasses() == model2.numClasses());
+     assert(model1.numClasses() == 2);
+
+     // transform() vs. predict()
+     model1.transform(dataset).registerTempTable("transformed");
+     DataFrame trans = jsql.sql("SELECT prediction FROM transformed");
+     JavaRDD<Double> preds = model1.predict(featuresRDD);
+     for (scala.Tuple2<Row, Double> trans_pred : trans.toJavaRDD().zip(preds).collect()) {
+       double t = trans_pred._1().getDouble(0);
+       double p = trans_pred._2();
+       assert(t == p);
+     }
+
+     // Check various types of predictions.
+     JavaRDD<Vector> rawPredictions = model1.predictRaw(featuresRDD);
+     JavaRDD<Vector> probabilities = model1.predictProbabilities(featuresRDD);
+     JavaRDD<Double> predictions = model1.predict(featuresRDD);
+     double threshold = model1.getThreshold();
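+     // Each probability should be the sigmoid of the corresponding raw (margin) prediction.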
+     for (Tuple2<Vector, Vector> raw_prob : rawPredictions.zip(probabilities).collect()) {
+       Vector raw = raw_prob._1();
+       Vector prob = raw_prob._2();
+       for (int i = 0; i < raw.size(); ++i) {
+         double r = raw.apply(i);
+         double p = prob.apply(i);
+         double pFromR = 1.0 / (1.0 + Math.exp(-r));
+         assert(Math.abs(p - pFromR) < eps);  // probability matches sigmoid(rawPrediction)
+       }
+     }
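+     // The predicted class should be the one with the largest probability.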
+     for (Tuple2<Vector, Double> prob_pred : probabilities.zip(predictions).collect()) {
+       Vector prob = prob_pred._1();
+       double pred = prob_pred._2();
+       double probOfPred = prob.apply((int)pred);
+       for (int i = 0; i < prob.size(); ++i) {
+         assert(probOfPred >= prob.apply(i));
+       }
+     }
  }
}