Skip to content

Commit 8316d5e

Browse files
committed
fixes after rebasing on master
1 parent fc62406 commit 8316d5e

File tree

4 files changed

+6
-70
lines changed

4 files changed

+6
-70
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ object DeveloperApiExample {
4040
val conf = new SparkConf().setAppName("DeveloperApiExample")
4141
val sc = new SparkContext(conf)
4242
val sqlContext = new SQLContext(sc)
43-
import sqlContext._
43+
import sqlContext.implicits._
4444

4545
// Prepare training data.
46-
val training = sparkContext.parallelize(Seq(
46+
val training = sc.parallelize(Seq(
4747
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
4848
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
4949
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
@@ -61,7 +61,7 @@ object DeveloperApiExample {
6161
val model = lr.fit(training)
6262

6363
// Prepare test data.
64-
val test = sparkContext.parallelize(Seq(
64+
val test = sc.parallelize(Seq(
6565
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
6666
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
6767
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ object SimpleParamsExample {
8181
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
8282

8383
// Prepare test data.
84-
val test = sparkContext.parallelize(Seq(
84+
val test = sc.parallelize(Seq(
8585
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
8686
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
8787
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ private[ml] object ClassificationModel {
181181
val raw2pred: Vector => Double = (rawPred) => {
182182
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
183183
}
184-
tmpData = tmpData.select($"*",
185-
callUDF(raw2pred, col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
184+
tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
185+
col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
186186
numColsOutput += 1
187187
}
188188
} else if (map(model.predictionCol) != "") {

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import org.apache.spark.ml.param._
2222
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
2323
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
2424
import org.apache.spark.sql.DataFrame
25-
import org.apache.spark.sql.Dsl._
2625
import org.apache.spark.storage.StorageLevel
2726

2827

@@ -103,69 +102,6 @@ class LogisticRegressionModel private[ml] (
103102
1.0 / (1.0 + math.exp(-m))
104103
}
105104

106-
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
107-
// Check schema
108-
transformSchema(dataset.schema, paramMap, logging = true)
109-
110-
val map = this.paramMap ++ paramMap
111-
112-
// Output selected columns only.
113-
// This is a bit complicated since it tries to avoid repeated computation.
114-
// rawPrediction (-margin, margin)
115-
// probability (1.0-score, score)
116-
// prediction (max margin)
117-
var tmpData = dataset
118-
var numColsOutput = 0
119-
if (map(rawPredictionCol) != "") {
120-
val features2raw: Vector => Vector = predictRaw
121-
tmpData = tmpData.select($"*",
122-
callUDF(features2raw, col(map(featuresCol))).as(map(rawPredictionCol)))
123-
numColsOutput += 1
124-
}
125-
if (map(probabilityCol) != "") {
126-
if (map(rawPredictionCol) != "") {
127-
val raw2prob: Vector => Vector = (rawPreds) => {
128-
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
129-
Vectors.dense(1.0 - prob1, prob1)
130-
}
131-
tmpData = tmpData.select($"*",
132-
callUDF(raw2prob, col(map(rawPredictionCol))).as(map(probabilityCol)))
133-
} else {
134-
val features2prob: Vector => Vector = predictProbabilities
135-
tmpData = tmpData.select($"*",
136-
callUDF(features2prob, col(map(featuresCol))).as(map(probabilityCol)))
137-
}
138-
numColsOutput += 1
139-
}
140-
if (map(predictionCol) != "") {
141-
val t = map(threshold)
142-
if (map(probabilityCol) != "") {
143-
val predict: Vector => Double = (probs) => {
144-
if (probs(1) > t) 1.0 else 0.0
145-
}
146-
tmpData = tmpData.select($"*",
147-
callUDF(predict, col(map(probabilityCol))).as(map(predictionCol)))
148-
} else if (map(rawPredictionCol) != "") {
149-
val predict: Vector => Double = (rawPreds) => {
150-
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
151-
if (prob1 > t) 1.0 else 0.0
152-
}
153-
tmpData = tmpData.select($"*",
154-
callUDF(predict, col(map(rawPredictionCol))).as(map(predictionCol)))
155-
} else {
156-
val predict: Vector => Double = this.predict
157-
tmpData = tmpData.select($"*",
158-
callUDF(predict, col(map(featuresCol))).as(map(predictionCol)))
159-
}
160-
numColsOutput += 1
161-
}
162-
if (numColsOutput == 0) {
163-
this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
164-
" since no output columns were set.")
165-
}
166-
tmpData
167-
}
168-
169105
override val numClasses: Int = 2
170106

171107
/**

0 commit comments

Comments (0)