
Commit bcb9549

committed
Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame)
1 parent 9872424 commit bcb9549

10 files changed: +79 -134 lines changed

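The recurring change across these files replaces the internal Catalyst DSL (Star(None), f.call(...).attr) with the public DataFrame API: select($"*", ...) plus callUDF with an explicit return DataType, since Spark SQL cannot infer a schema from a plain Scala function. A minimal sketch of the new pattern, assuming a Spark 1.3-era SQLContext with Dsl._ in scope and a DataFrame df holding a Double column "x" (names here are illustrative, not from the commit):

import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.types.DoubleType

val doubler: Double => Double = _ * 2.0
// Keep every existing column and append a derived one; the DoubleType argument
// declares the UDF's return type, which Spark cannot infer from a Scala function.
val withY = df.select($"*", callUDF(doubler, DoubleType, df("x")).as("y"))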

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 15 additions & 15 deletions
@@ -21,9 +21,9 @@ import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
 import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}


 /**
@@ -95,7 +95,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @param paramMap additional parameters, overwrite embedded params
    * @return transformed dataset
    */
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     // This default implementation should be overridden as needed.

     // Check schema
@@ -162,12 +162,9 @@ private[ml] object ClassificationModel {
    * @return (number of columns added, transformed dataset)
    */
   private[ml] def transformColumnsImpl[FeaturesType](
-      dataset: SchemaRDD,
+      dataset: DataFrame,
       model: ClassificationModel[FeaturesType, _],
-      map: ParamMap): (Int, SchemaRDD) = {
-
-    import org.apache.spark.sql.catalyst.dsl._
-    import dataset.sqlContext._
+      map: ParamMap): (Int, DataFrame) = {

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
@@ -176,22 +173,25 @@ private[ml] object ClassificationModel {
     if (map(model.rawPredictionCol) != "") {
       // output raw prediction
       val features2raw: FeaturesType => Vector = model.predictRaw
-      tmpData = tmpData.select(Star(None),
-        features2raw.call(map(model.featuresCol).attr) as map(model.rawPredictionCol))
+      tmpData = tmpData.select($"*",
+        callUDF(features2raw, new VectorUDT,
+          tmpData(map(model.featuresCol))).as(map(model.rawPredictionCol)))
       numColsOutput += 1
       if (map(model.predictionCol) != "") {
         val raw2pred: Vector => Double = (rawPred) => {
           rawPred.toArray.zipWithIndex.maxBy(_._1)._2
         }
-        tmpData = tmpData.select(Star(None),
-          raw2pred.call(map(model.rawPredictionCol).attr) as map(model.predictionCol))
+        tmpData = tmpData.select($"*",
+          callUDF(raw2pred, DoubleType,
+            tmpData(map(model.rawPredictionCol))).as(map(model.predictionCol)))
         numColsOutput += 1
       }
     } else if (map(model.predictionCol) != "") {
       // output prediction
       val features2pred: FeaturesType => Double = model.predict
-      tmpData = tmpData.select(Star(None),
-        features2pred.call(map(model.featuresCol).attr) as map(model.predictionCol))
+      tmpData = tmpData.select($"*",
+        callUDF(features2pred, DoubleType,
+          tmpData(map(model.featuresCol))).as(map(model.predictionCol)))
       numColsOutput += 1
     }
     (numColsOutput, tmpData)
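For intuition, the raw2pred closure above is just an argmax over the raw scores; a pure-Scala rendering with a made-up score array (no Spark needed):

// Hypothetical raw scores for three classes; the predicted label is the index
// of the largest score, returned as a Double exactly as raw2pred does above.
val rawPred = Array(0.2, 1.7, 0.4)
val predicted: Double = rawPred.zipWithIndex.maxBy(_._1)._2
assert(predicted == 1.0)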

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 17 additions & 34 deletions
@@ -20,12 +20,10 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
-import org.apache.spark.sql._
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dsl._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel


@@ -55,10 +53,10 @@ class LogisticRegression
   def setMaxIter(value: Int): this.type = set(maxIter, value)
   def setThreshold(value: Double): this.type = set(threshold, value)

-  override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+  override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val oldDataset = extractLabeledPoints(dataset, paramMap)
-    val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) {
       oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
     }
@@ -106,25 +104,10 @@ class LogisticRegressionModel private[ml] (
     1.0 / (1.0 + math.exp(-m))
   }

-  /*
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
-    val scoreFunction = udf { v: Vector =>
-      val margin = BLAS.dot(v, weights)
-      1.0 / (1.0 + math.exp(-margin))
-    }
-    val t = map(threshold)
-    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 }
-    dataset
-      .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol)))
-  */
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)

-    import dataset.sqlContext._
     val map = this.paramMap ++ paramMap

     // Output selected columns only.
@@ -136,8 +119,8 @@ class LogisticRegressionModel private[ml] (
     var numColsOutput = 0
     if (map(rawPredictionCol) != "") {
       val features2raw: Vector => Vector = predictRaw
-      tmpData = tmpData.select(Star(None),
-        features2raw.call(map(featuresCol).attr) as map(rawPredictionCol))
+      tmpData = tmpData.select($"*",
+        callUDF(features2raw, new VectorUDT, tmpData(map(featuresCol))).as(map(rawPredictionCol)))
       numColsOutput += 1
     }
     if (map(probabilityCol) != "") {
@@ -146,12 +129,12 @@ class LogisticRegressionModel private[ml] (
         val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
         Vectors.dense(1.0 - prob1, prob1)
       }
-      tmpData = tmpData.select(Star(None),
-        raw2prob.call(map(rawPredictionCol).attr) as map(probabilityCol))
+      tmpData = tmpData.select($"*",
+        callUDF(raw2prob, new VectorUDT, tmpData(map(rawPredictionCol))).as(map(probabilityCol)))
     } else {
       val features2prob: Vector => Vector = predictProbabilities
-      tmpData = tmpData.select(Star(None),
-        features2prob.call(map(featuresCol).attr) as map(probabilityCol))
+      tmpData = tmpData.select($"*",
+        callUDF(features2prob, new VectorUDT, tmpData(map(featuresCol))).as(map(probabilityCol)))
     }
     numColsOutput += 1
   }
@@ -161,19 +144,19 @@ class LogisticRegressionModel private[ml] (
     val predict: Vector => Double = (probs) => {
       if (probs(1) > t) 1.0 else 0.0
     }
-    tmpData = tmpData.select(Star(None),
-      predict.call(map(probabilityCol).attr) as map(predictionCol))
+    tmpData = tmpData.select($"*",
+      callUDF(predict, DoubleType, tmpData(map(probabilityCol))).as(map(predictionCol)))
   } else if (map(rawPredictionCol) != "") {
     val predict: Vector => Double = (rawPreds) => {
       val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
       if (prob1 > t) 1.0 else 0.0
     }
-    tmpData = tmpData.select(Star(None),
-      predict.call(map(rawPredictionCol).attr) as map(predictionCol))
+    tmpData = tmpData.select($"*",
+      callUDF(predict, DoubleType, tmpData(map(rawPredictionCol))).as(map(predictionCol)))
   } else {
     val predict: Vector => Double = this.predict
-    tmpData = tmpData.select(Star(None),
-      predict.call(map(featuresCol).attr) as map(predictionCol))
+    tmpData = tmpData.select($"*",
+      callUDF(predict, DoubleType, tmpData(map(featuresCol))).as(map(predictionCol)))
   }
   numColsOutput += 1
 }
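For the binary case, the transform chain above reduces to a sigmoid followed by a threshold; a pure-Scala check with made-up numbers:

// rawPreds(1) is the margin for class 1; the sigmoid maps it to a probability,
// and the threshold t turns the probability into a 0/1 label.
val rawPreds = Array(-1.2, 1.2)
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
val t = 0.5
val prediction = if (prob1 > t) 1.0 else 0.0
assert(prediction == 1.0) // sigmoid(1.2) ≈ 0.77 > 0.5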

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 6 additions & 7 deletions
@@ -20,8 +20,8 @@ package org.apache.spark.ml.classification
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DataType, StructType}


@@ -91,10 +91,8 @@ abstract class ProbabilisticClassificationModel[
    * @param paramMap additional parameters, overwrite embedded params
    * @return transformed dataset
    */
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     // This default implementation should be overridden as needed.
-    import dataset.sqlContext._
-    import org.apache.spark.sql.catalyst.dsl._

     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
@@ -118,8 +116,9 @@ abstract class ProbabilisticClassificationModel[
       val features2probs: FeaturesType => Vector = (features) => {
         tmpModel.predictProbabilities(features)
       }
-      outputData.select(Star(None),
-        features2probs.call(map(featuresCol).attr) as map(probabilityCol))
+      outputData.select($"*",
+        callUDF(features2probs, new VectorUDT,
+          outputData(map(featuresCol))).as(map(probabilityCol)))
     } else {
       if (numColsOutput == 0) {
         this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 3 additions & 3 deletions
@@ -18,11 +18,11 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
+import org.apache.spark.ml.Evaluator
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.DoubleType


@@ -52,7 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
     checkInputColumn(schema, map(labelCol), DoubleType)

     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
-    val scoreAndLabels = dataset.select(map(rawPredictionCol).attr, map(labelCol).attr)
+    val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
       .map { case Row(rawPrediction: Vector, label: Double) =>
         (rawPrediction(1), label)
       }
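The (score, label) pairs extracted above feed BinaryClassificationMetrics; a minimal sketch assuming an existing SparkContext named sc (the two data points are made up):

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// Perfectly separated scores give an area under the ROC curve of 1.0.
val pairs = sc.parallelize(Seq((0.9, 1.0), (0.1, 0.0)))
val metrics = new BinaryClassificationMetrics(pairs)
println(metrics.areaUnderROC()) // 1.0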

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 9 additions & 11 deletions
@@ -23,8 +23,8 @@ import org.apache.spark.ml.param._
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Dsl._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}


@@ -85,7 +85,7 @@ abstract class Predictor[
   def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
   def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]

-  override def fit(dataset: SchemaRDD, paramMap: ParamMap): M = {
+  override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
     // This handles a few items such as schema validation.
     // Developers only need to implement train().
     transformSchema(dataset.schema, paramMap, logging = true)
@@ -108,7 +108,7 @@ abstract class Predictor[
    * @return Fitted model
    */
   @DeveloperApi
-  protected def train(dataset: SchemaRDD, paramMap: ParamMap): M
+  protected def train(dataset: DataFrame, paramMap: ParamMap): M

  /**
   * :: DeveloperApi ::
@@ -131,10 +131,9 @@ abstract class Predictor[
    * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
    * and put it in an RDD with strong types.
    */
-  protected def extractLabeledPoints(dataset: SchemaRDD, paramMap: ParamMap): RDD[LabeledPoint] = {
-    import dataset.sqlContext._
+  protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
     val map = this.paramMap ++ paramMap
-    dataset.select(map(labelCol).attr, map(featuresCol).attr)
+    dataset.select(map(labelCol), map(featuresCol))
       .map { case Row(label: Double, features: Vector) =>
         LabeledPoint(label, features)
       }
@@ -184,10 +183,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
    * @param paramMap additional parameters, overwrite embedded params
    * @return transformed dataset with [[predictionCol]] of type [[Double]]
    */
-  override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     // This default implementation should be overridden as needed.
-    import org.apache.spark.sql.catalyst.dsl._
-    import dataset.sqlContext._

     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
@@ -206,7 +203,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
       val pred: FeaturesType => Double = (features) => {
         tmpModel.predict(features)
       }
-      dataset.select(Star(None), pred.call(map(featuresCol).attr) as map(predictionCol))
+      dataset.select($"*",
+        callUDF(pred, DoubleType, dataset(map(featuresCol))).as(map(predictionCol)))
     } else {
       this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
         " since no output columns were set.")

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
 import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.mllib.regression.LinearRegressionWithSGD
-import org.apache.spark.sql._
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.storage.StorageLevel


@@ -47,10 +47,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
   def setRegParam(value: Double): this.type = set(regParam, value)
   def setMaxIter(value: Int): this.type = set(maxIter, value)

-  override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = {
+  override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
     // Extract columns from data. If dataset is persisted, do not persist oldDataset.
     val oldDataset = extractLabeledPoints(dataset, paramMap)
-    val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) {
       oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
     }
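Since SchemaRDD.getStorageLevel is gone, the storage check now goes through the DataFrame's underlying RDD. A hedged sketch of the guard as a reusable helper; the unpersist at the end is the symmetric cleanup one would expect after training, not something shown in this hunk:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel

// Cache the extracted RDD only if the caller has not already persisted the input.
def withPersistence[T](dataset: DataFrame, oldDataset: RDD[T])(body: => Unit): Unit = {
  val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
  if (handlePersistence) oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
  body
  if (handlePersistence) oldDataset.unpersist()
}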

mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java

Lines changed: 7 additions & 20 deletions
@@ -19,7 +19,6 @@

 import java.io.Serializable;
 import java.lang.Math;
-import java.util.ArrayList;
 import java.util.List;

 import org.junit.After;
@@ -28,12 +27,11 @@

 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
 import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.Row;


@@ -50,11 +48,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
     jsql = new SQLContext(jsc);
-    List<LabeledPoint> points = new ArrayList<LabeledPoint>();
-    for (org.apache.spark.mllib.regression.LabeledPoint lp:
-        generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
-      points.add(new LabeledPoint(lp.label(), lp.features()));
-    }
+    List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     datasetRDD = jsc.parallelize(points, 2);
     dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
     dataset.registerTempTable("dataset");
@@ -98,21 +92,14 @@ public void logisticRegressionWithSetters() {
     // Modify model params, and check that the params worked.
     model.setThreshold(1.0);
     model.transform(dataset).registerTempTable("predAllZero");
-    SchemaRDD predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+    DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
     for (Row r: predAllZero.collectAsList()) {
       assert(r.getDouble(0) == 0.0);
     }
     // Call transform with params, and check that the params worked.
-    /* TODO: USE THIS
-    model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
-      .registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
-    predictions.collectAsList();
-    */
-
     model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
       .registerTempTable("predNotAllZero");
-    SchemaRDD predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+    DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
     boolean foundNonZero = false;
     for (Row r: predNotAllZero.collectAsList()) {
       if (r.getDouble(0) != 0.0) foundNonZero = true;
@@ -137,7 +124,7 @@ public void logisticRegressionPredictorClassifierMethods() {
     assert(model.numClasses() == 2);

     model.transform(dataset).registerTempTable("transformed");
-    SchemaRDD trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+    DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
     for (Row row: trans1.collect()) {
       Vector raw = (Vector)row.get(0);
       Vector prob = (Vector)row.get(1);
@@ -148,7 +135,7 @@ public void logisticRegressionPredictorClassifierMethods() {
       assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
     }

-    SchemaRDD trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+    DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
     for (Row row: trans2.collect()) {
       double pred = row.getDouble(0);
       Vector prob = (Vector)row.get(1);