Skip to content

Commit f549e34

Browse files
committed
Updates based on code review. Major ones are:
* Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
1 parent 343e7bd commit f549e34

File tree

8 files changed

+65
-72
lines changed

8 files changed

+65
-72
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ object CrossValidatorExample {
104104
.select('id, 'text, 'probability, 'prediction)
105105
.collect()
106106
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
107-
println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
107+
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
108108
}
109109

110110
sc.stop()

examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ import org.apache.spark.{SparkConf, SparkContext}
2121
import org.apache.spark.SparkContext._
2222
import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
2323
import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
24-
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
24+
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
2525
import org.apache.spark.mllib.regression.LabeledPoint
26-
import org.apache.spark.sql.{DataType, SchemaRDD, Row, SQLContext}
26+
import org.apache.spark.sql.{SchemaRDD, Row, SQLContext}
2727

2828
/**
2929
* A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -85,7 +85,14 @@ object DeveloperApiExample {
8585
*/
8686
private trait MyLogisticRegressionParams extends ClassifierParams {
8787

88-
/** param for max number of iterations */
88+
/**
89+
* Param for max number of iterations
90+
*
91+
* NOTE: The usual way to add a parameter to a model or algorithm is to include:
92+
* - val myParamName: ParamType
93+
* - def getMyParamName
94+
* - def setMyParamName
95+
*/
8996
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
9097
def getMaxIter: Int = get(maxIter)
9198
}
@@ -101,40 +108,23 @@ private class MyLogisticRegression
101108

102109
setMaxIter(100) // Initialize
103110

111+
// The parameter setter is in this class since it should return type MyLogisticRegression.
104112
def setMaxIter(value: Int): this.type = set(maxIter, value)
105113

106-
override def fit(dataset: SchemaRDD, paramMap: ParamMap): MyLogisticRegressionModel = {
107-
// Check schema (types). This allows early failure before running the algorithm.
108-
transformSchema(dataset.schema, paramMap, logging = true)
109-
114+
// This method is used by fit()
115+
override protected def train(
116+
dataset: SchemaRDD,
117+
paramMap: ParamMap): MyLogisticRegressionModel = {
110118
// Extract columns from data using helper method.
111119
val oldDataset = extractLabeledPoints(dataset, paramMap)
112120

113-
// Combine given parameters with the embedded parameters, where the given paramMap overrides
114-
// any embedded settings.
115-
val map = this.paramMap ++ paramMap
116-
117121
// Do learning to estimate the weight vector.
118122
val numFeatures = oldDataset.take(1)(0).features.size
119123
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
120124

121-
// Create a model to return.
122-
val lrm = new MyLogisticRegressionModel(this, map, weights)
123-
124-
// Copy model params.
125-
// An Estimator stores the parameters for the Model it produces, and this copies any relevant
126-
// parameters to the model.
127-
Params.inheritValues(map, this, lrm)
128-
129-
// Return the learned model.
130-
lrm
125+
// Create a model, and return it.
126+
new MyLogisticRegressionModel(this, paramMap, weights)
131127
}
132-
133-
/**
134-
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
135-
* This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
136-
*/
137-
override protected def featuresDataType: DataType = new VectorUDT
138128
}
139129

140130
/**
@@ -186,10 +176,4 @@ private class MyLogisticRegressionModel(
186176
Params.inheritValues(this.paramMap, this, m)
187177
m
188178
}
189-
190-
/**
191-
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
192-
* This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
193-
*/
194-
override protected def featuresDataType: DataType = new VectorUDT
195179
}

examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ object SimpleParamsExample {
9494
.select('features, 'label, 'myProbability, 'prediction)
9595
.collect()
9696
.foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
97-
println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
97+
println("($features, $label) -> prob=$prob, prediction=$prediction")
9898
}
9999

100100
sc.stop()

examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ object SimpleTextClassificationPipeline {
8383
.select('id, 'text, 'probability, 'prediction)
8484
.collect()
8585
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
86-
println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
86+
println("($id, $text) --> prob=$prob, prediction=$prediction")
8787
}
8888

8989
sc.stop()

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
2020
import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.param._
2222
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
23-
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
23+
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
2424
import org.apache.spark.sql._
2525
import org.apache.spark.sql.Dsl._
2626
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -52,13 +52,9 @@ class LogisticRegression
5252
def setMaxIter(value: Int): this.type = set(maxIter, value)
5353
def setThreshold(value: Double): this.type = set(threshold, value)
5454

55-
override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
56-
// Check schema
57-
transformSchema(dataset.schema, paramMap, logging = true)
58-
55+
override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
5956
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
6057
val oldDataset = extractLabeledPoints(dataset, paramMap)
61-
val map = this.paramMap ++ paramMap
6258
val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
6359
if (handlePersistence) {
6460
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -67,21 +63,16 @@ class LogisticRegression
6763
// Train model
6864
val lr = new LogisticRegressionWithLBFGS
6965
lr.optimizer
70-
.setRegParam(map(regParam))
71-
.setNumIterations(map(maxIter))
66+
.setRegParam(paramMap(regParam))
67+
.setNumIterations(paramMap(maxIter))
7268
val oldModel = lr.run(oldDataset)
73-
val lrm = new LogisticRegressionModel(this, map, oldModel.weights, oldModel.intercept)
69+
val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
7470

7571
if (handlePersistence) {
7672
oldDataset.unpersist()
7773
}
78-
79-
// copy model params
80-
Params.inheritValues(map, this, lrm)
8174
lrm
8275
}
83-
84-
override protected def featuresDataType: DataType = new VectorUDT
8576
}
8677

8778

@@ -215,6 +206,4 @@ class LogisticRegressionModel private[ml] (
215206
Params.inheritValues(this.paramMap, this, m)
216207
m
217208
}
218-
219-
override protected def featuresDataType: DataType = new VectorUDT
220209
}

mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.impl.estimator
2020
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
2121
import org.apache.spark.ml.{Estimator, Model}
2222
import org.apache.spark.ml.param._
23-
import org.apache.spark.mllib.linalg.Vector
23+
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
2424
import org.apache.spark.mllib.regression.LabeledPoint
2525
import org.apache.spark.rdd.RDD
2626
import org.apache.spark.sql._
@@ -84,16 +84,43 @@ abstract class Predictor[
8484
def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
8585
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
8686

87+
override def fit(dataset: SchemaRDD, paramMap: ParamMap): M = {
88+
// This handles a few items such as schema validation.
89+
// Developers only need to implement train().
90+
transformSchema(dataset.schema, paramMap, logging = true)
91+
val map = this.paramMap ++ paramMap
92+
val model = train(dataset, map)
93+
Params.inheritValues(map, this, model) // copy params to model
94+
model
95+
}
96+
97+
/**
98+
* :: DeveloperApi ::
99+
*
100+
* Train a model using the given dataset and parameters.
101+
* Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
102+
* and copying parameters into the model.
103+
*
104+
* @param dataset Training dataset
105+
* @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already
106+
* been combined with the embedded ParamMap.
107+
* @return Fitted model
108+
*/
109+
@DeveloperApi
110+
protected def train(dataset: SchemaRDD, paramMap: ParamMap): M
111+
87112
/**
88113
* :: DeveloperApi ::
89114
*
90115
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
91116
*
92117
* This is used by [[validateAndTransformSchema()]].
93118
* This workaround is needed since SQL has different APIs for Scala and Java.
119+
*
120+
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
94121
*/
95122
@DeveloperApi
96-
protected def featuresDataType: DataType
123+
protected def featuresDataType: DataType = new VectorUDT
97124

98125
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
99126
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
@@ -138,9 +165,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
138165
*
139166
* This is used by [[validateAndTransformSchema()]].
140167
* This workaround is needed since SQL has different APIs for Scala and Java.
168+
*
169+
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
141170
*/
142171
@DeveloperApi
143-
protected def featuresDataType: DataType
172+
protected def featuresDataType: DataType = new VectorUDT
144173

145174
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
146175
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)

mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
package org.apache.spark.ml.param
1919

2020
/* NOTE TO DEVELOPERS:
21-
* If you add these parameter traits into your algorithm, you need to add a setter method as well.
21+
* If you mix these parameter traits into your algorithm, please add a setter method as well
22+
* so that users may use a builder pattern:
23+
* val myLearner = new MyLearner().setParam1(x).setParam2(y)...
2224
*/
2325

2426
private[ml] trait HasRegParam extends Params {

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression
1919

2020
import org.apache.spark.annotation.AlphaComponent
2121
import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
22-
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector}
22+
import org.apache.spark.mllib.linalg.{BLAS, Vector}
2323
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
2424
import org.apache.spark.sql._
2525
import org.apache.spark.storage.StorageLevel
@@ -45,13 +45,9 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
4545
def setRegParam(value: Double): this.type = set(regParam, value)
4646
def setMaxIter(value: Int): this.type = set(maxIter, value)
4747

48-
override def fit(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = {
49-
// Check schema
50-
transformSchema(dataset.schema, paramMap, logging = true)
51-
48+
override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = {
5249
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
5350
val oldDataset = extractLabeledPoints(dataset, paramMap)
54-
val map = this.paramMap ++ paramMap
5551
val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
5652
if (handlePersistence) {
5753
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -60,21 +56,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
6056
// Train model
6157
val lr = new LinearRegressionWithSGD()
6258
lr.optimizer
63-
.setRegParam(map(regParam))
64-
.setNumIterations(map(maxIter))
59+
.setRegParam(paramMap(regParam))
60+
.setNumIterations(paramMap(maxIter))
6561
val model = lr.run(oldDataset)
66-
val lrm = new LinearRegressionModel(this, map, model.weights, model.intercept)
62+
val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
6763

6864
if (handlePersistence) {
6965
oldDataset.unpersist()
7066
}
71-
72-
// copy model params
73-
Params.inheritValues(map, this, lrm)
7467
lrm
7568
}
76-
77-
override protected def featuresDataType: DataType = new VectorUDT
7869
}
7970

8071
/**
@@ -100,6 +91,4 @@ class LinearRegressionModel private[ml] (
10091
Params.inheritValues(this.paramMap, this, m)
10192
m
10293
}
103-
104-
override protected def featuresDataType: DataType = new VectorUDT
10594
}

0 commit comments

Comments
 (0)