Commit 52f4fde

removing everything except for simple class hierarchy for classification

1 parent: d35bb5d
17 files changed: +8 -717 lines

mllib/src/main/scala/org/apache/spark/ml/LabeledPoint.scala

Lines changed: 2 additions & 4 deletions
@@ -11,14 +11,12 @@ import org.apache.spark.mllib.linalg.Vector
  */
 case class LabeledPoint(label: Double, features: Vector, weight: Double) {

-  /** Default constructor which sets instance weight to 1.0 */
-  def this(label: Double, features: Vector) = this(label, features, 1.0)
-
   override def toString: String = {
     "(%s,%s,%s)".format(label, features, weight)
   }
 }

 object LabeledPoint {
-  def apply(label: Double, features: Vector) = new LabeledPoint(label, features)
+  /** Constructor which sets instance weight to 1.0 */
+  def apply(label: Double, features: Vector) = new LabeledPoint(label, features, 1.0)
 }
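After this change the only weight-defaulting entry point is the companion apply; the auxiliary constructor is gone. A minimal usage sketch of the resulting API (Vectors.dense is the standard mllib factory; the three-argument call goes through the generated case-class apply):

    import org.apache.spark.ml.LabeledPoint
    import org.apache.spark.mllib.linalg.Vectors

    // Explicit weight via the generated three-argument case-class apply.
    val weighted = LabeledPoint(1.0, Vectors.dense(0.5, 1.5), 2.0)

    // Two-argument companion apply fills in the default weight of 1.0.
    val unweighted = LabeledPoint(0.0, Vectors.dense(0.5, 1.5))

    assert(unweighted.weight == 1.0)  // toString renders as (0.0,[0.5,1.5],1.0)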

mllib/src/main/scala/org/apache/spark/ml/classification/AdaBoost.scala

Lines changed: 0 additions & 209 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 4 additions & 17 deletions
@@ -18,12 +18,8 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.evaluation.ClassificationEvaluator
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.ml._
-import org.apache.spark.ml.impl.estimator.{HasDefaultEvaluator, PredictionModel, Predictor,
-  PredictorParams}
-import org.apache.spark.rdd.RDD

 @AlphaComponent
 private[classification] trait ClassifierParams extends PredictorParams
@@ -33,10 +29,9 @@ private[classification] trait ClassifierParams extends PredictorParams
  */
 abstract class Classifier[Learner <: Classifier[Learner, M], M <: ClassificationModel[M]]
   extends Predictor[Learner, M]
-  with ClassifierParams
-  with HasDefaultEvaluator {
+  with ClassifierParams {

-  override def defaultEvaluator: Evaluator = new ClassificationEvaluator
+  // TODO: defaultEvaluator (follow-up PR)
 }


@@ -60,14 +55,6 @@ private[ml] abstract class ClassificationModel[M <: ClassificationModel[M]]
   */
  def predictRaw(features: Vector): Vector

-  /**
-   * Compute this model's accuracy on the given dataset.
-   */
-  def accuracy(dataset: RDD[LabeledPoint]): Double = {
-    // TODO: Handle instance weights.
-    val predictionsAndLabels = dataset.map(lp => predict(lp.features))
-      .zip(dataset.map(_.label))
-    ClassificationEvaluator.computeMetric(predictionsAndLabels, "accuracy")
-  }
+  // TODO: accuracy(dataset: RDD[LabeledPoint]): Double (follow-up PR)

 }
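The deleted accuracy method is parked behind a TODO rather than replaced. For reference, the same computation is easy to express outside the model. A sketch under two assumptions: it lives inside the ml package (ClassificationModel is private[ml]), and predict(features: Vector): Double is available on the model, as the removed body implies:

    import org.apache.spark.ml.LabeledPoint
    import org.apache.spark.rdd.RDD

    /** Fraction of correct predictions; like the removed method, this ignores instance weights. */
    def accuracy[M <: ClassificationModel[M]](
        model: ClassificationModel[M],
        dataset: RDD[LabeledPoint]): Double = {
      val predictionsAndLabels = dataset.map(lp => model.predict(lp.features))
        .zip(dataset.map(_.label))
      val numCorrect = predictionsAndLabels.filter { case (pred, label) => pred == label }.count()
      numCorrect.toDouble / dataset.count()
    }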

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 0 additions & 2 deletions
@@ -58,8 +58,6 @@ private[classification] trait LogisticRegressionParams extends ClassifierParams
 class LogisticRegression extends Classifier[LogisticRegression, LogisticRegressionModel]
   with LogisticRegressionParams {

-  // TODO: Extend IterativeEstimator
-
   setRegParam(0.1)
   setMaxIter(100)
   setThreshold(0.5)
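The constructor body bakes in defaults of regParam = 0.1, maxIter = 100, and threshold = 0.5. A hypothetical usage sketch, assuming the set* methods shown above are public on the estimator:

    // Instantiate with the defaults above, then override selectively.
    val lr = new LogisticRegression()
    lr.setRegParam(0.01)  // loosen regularization from the 0.1 default
    lr.setMaxIter(50)     // fewer iterations than the 100 default
    // threshold stays at its 0.5 default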

mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

Lines changed: 0 additions & 67 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 1 addition & 11 deletions
@@ -21,7 +21,6 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.DoubleType

@@ -57,16 +56,8 @@ class BinaryClassificationEvaluator extends Evaluator with Params
       .map { case Row(score: Double, label: Double) =>
         (score, label)
       }
-    BinaryClassificationEvaluator.computeMetric(scoreAndLabels, map(metricName))
-  }
-
-}
-
-private[ml] object BinaryClassificationEvaluator {
-
-  def computeMetric(scoreAndLabels: RDD[(Double, Double)], metricName: String): Double = {
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
-    val metric = metricName match {
+    val metric = map(metricName) match {
       case "areaUnderROC" =>
         metrics.areaUnderROC()
       case "areaUnderPR" =>
@@ -77,5 +68,4 @@ private[ml] object BinaryClassificationEvaluator {
     metrics.unpersist()
     metric
   }
-
 }
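With the private computeMetric helper inlined, evaluate now selects the metric via map(metricName) and drives BinaryClassificationMetrics directly. The underlying mllib API is public, so the same numbers can be reproduced standalone; a minimal sketch (sc is an existing SparkContext):

    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

    // (score, label) pairs: raw model scores against 0.0/1.0 labels.
    val scoreAndLabels = sc.parallelize(Seq(
      (0.9, 1.0), (0.8, 1.0), (0.3, 0.0), (0.1, 0.0)))

    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auROC = metrics.areaUnderROC()  // area under the ROC curve
    val auPR = metrics.areaUnderPR()    // area under the precision-recall curve
    metrics.unpersist()                 // release cached intermediates, as evaluate does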
