[SPARK-2550][MLLIB][APACHE SPARK] Support regularization and intercept in pyspark's linear methods #1775
Changes from all commits
File: mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

@@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
       .setNumIterations(numIterations)
       .setRegParam(regParam)
       .setStepSize(stepSize)
+      .setMiniBatchFraction(miniBatchFraction)
     if (regType == "l2") {
       lrAlg.optimizer.setUpdater(new SquaredL2Updater)
     } else if (regType == "l1") {
@@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
       stepSize: Double,
       regParam: Double,
       miniBatchFraction: Double,
-      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+      initialWeightsBA: Array[Byte],
+      regType: String,
+      intercept: Boolean): java.util.List[java.lang.Object] = {
+    val SVMAlg = new SVMWithSGD()
+    SVMAlg.setIntercept(intercept)
+    SVMAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setStepSize(stepSize)
Comment (Contributor): You forgot to set miniBatchFraction here.
Reply (Author): Thanks! I am fixing it right now!
+      .setMiniBatchFraction(miniBatchFraction)
+    if (regType == "l2") {
Comment (Contributor): Py4J will pass `None` through from Python, so you could handle it directly here.
+      SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      SVMAlg.optimizer.setUpdater(new L1Updater)
+    } else if (regType != "none") {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: [l1, l2, none].")
+    }
     trainRegressionModel(
       (data, initialWeights) =>
-        SVMWithSGD.train(
-          data,
-          numIterations,
-          stepSize,
-          regParam,
-          miniBatchFraction,
-          initialWeights),
+        SVMAlg.run(data, initialWeights),
       dataBytesJRDD,
       initialWeightsBA)
   }
@@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
       numIterations: Int,
       stepSize: Double,
       miniBatchFraction: Double,
-      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+      initialWeightsBA: Array[Byte],
+      regParam: Double,
+      regType: String,
+      intercept: Boolean): java.util.List[java.lang.Object] = {
+    val LogRegAlg = new LogisticRegressionWithSGD()
+    LogRegAlg.setIntercept(intercept)
+    LogRegAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setStepSize(stepSize)
Comment (Contributor): miniBatchFraction missing here too.
+      .setMiniBatchFraction(miniBatchFraction)
+    if (regType == "l2") {
+      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      LogRegAlg.optimizer.setUpdater(new L1Updater)
+    } else if (regType != "none") {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: [l1, l2, none].")
+    }
     trainRegressionModel(
       (data, initialWeights) =>
-        LogisticRegressionWithSGD.train(
-          data,
-          numIterations,
-          stepSize,
-          miniBatchFraction,
-          initialWeights),
+        LogRegAlg.run(data, initialWeights),
       dataBytesJRDD,
       initialWeightsBA)
   }
File: python/pyspark/mllib/classification.py
@@ -73,11 +73,36 @@ def predict(self, x):
 class LogisticRegressionWithSGD(object):
     @classmethod
-    def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
-        """Train a logistic regression model on the given data."""
+    def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
+              initialWeights=None, regParam=1.0, regType=None, intercept=False):
+        """
+        Train a logistic regression model on the given data.
+
+        @param data:              The training data.
+        @param iterations:        The number of iterations (default: 100).
+        @param step:              The step parameter used in SGD
+                                  (default: 1.0).
+        @param miniBatchFraction: Fraction of data to be used for each SGD
+                                  iteration.
+        @param initialWeights:    The initial weights (default: None).
+        @param regParam:          The regularizer parameter (default: 1.0).
+        @param regType:           The type of regularizer used for training
+                                  our model.
+                                  Allowed values: "l1" for using L1Updater,
+                                                  "l2" for using
+                                                       SquaredL2Updater,
+                                                  "none" for no regularizer.
+                                  (default: "none")
+        @param intercept:         Boolean parameter which indicates the use
+                                  or not of the augmented representation for
+                                  training data (i.e. whether bias features
+                                  are activated or not).
+        """
         sc = data.context
+        if regType is None:
Comment (Contributor): As per the above comment, you can just pass `None` from Python.
Reply (Author): Xiangrui suggested to keep the Scala code as simple as possible and only to throw the exception there.
Comment (Contributor): Ok, fair enough (@mengxr I wasn't suggesting enumerations, just a pattern match on the string).
Comment (Contributor): Same here. No strong feelings either way.
Comment (Contributor): Either way - this if-else code is now repeated 3 times (twice here, once for regression). Might be worth refactoring into a getUpdater function?
Reply (Author): Ok! I will update the code with the suggested approach of MLnick...
Comment (Contributor): @miccagiann I think it is fine to keep it as it is now. I'm merging this for v1.1 QA. We can update the code style later.
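To make the getUpdater idea concrete, here is a minimal sketch of what such a helper could look like on the Scala side; the helper name, the `Option` return type, and where it would live are assumptions made for illustration, not code from this PR:

```scala
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater, Updater}

// Hypothetical helper, not part of this PR: resolve the regType string sent from
// Python into an Updater once, instead of repeating the if/else chain in every
// train*ModelWithSGD method. Returning None for "none" mirrors the PR's behaviour
// of leaving the algorithm's default updater untouched.
private def getUpdaterFromString(regType: String): Option[Updater] = regType match {
  case "l2" => Some(new SquaredL2Updater)
  case "l1" => Some(new L1Updater)
  case "none" => None
  case other => throw new IllegalArgumentException(
    "Invalid value for 'regType' parameter: " + other
      + ". Can only be initialized using the following string values: [l1, l2, none].")
}
```

Each trainer could then replace its branch with something like `getUpdaterFromString(regType).foreach(u => SVMAlg.optimizer.setUpdater(u))`.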
+            regType = "none"
         train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
-            d._jrdd, iterations, step, miniBatchFraction, i)
+            d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
         return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
                                          initialWeights)
@@ -115,11 +140,35 @@ def predict(self, x):
 class SVMWithSGD(object):
     @classmethod
     def train(cls, data, iterations=100, step=1.0, regParam=1.0,
-              miniBatchFraction=1.0, initialWeights=None):
-        """Train a support vector machine on the given data."""
+              miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
+        """
+        Train a support vector machine on the given data.
+
+        @param data:              The training data.
+        @param iterations:        The number of iterations (default: 100).
+        @param step:              The step parameter used in SGD
+                                  (default: 1.0).
+        @param regParam:          The regularizer parameter (default: 1.0).
+        @param miniBatchFraction: Fraction of data to be used for each SGD
+                                  iteration.
+        @param initialWeights:    The initial weights (default: None).
+        @param regType:           The type of regularizer used for training
+                                  our model.
+                                  Allowed values: "l1" for using L1Updater,
+                                                  "l2" for using
+                                                       SquaredL2Updater,
+                                                  "none" for no regularizer.
+                                  (default: "none")
+        @param intercept:         Boolean parameter which indicates the use
+                                  or not of the augmented representation for
+                                  training data (i.e. whether bias features
+                                  are activated or not).
+        """
         sc = data.context
+        if regType is None:
+            regType = "none"
         train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
-            d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+            d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
         return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)
Comment: Also maybe prefer to do the pattern matching on `regType` before this, and do something like: […]
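One possible reading of that suggestion, sketched purely for illustration: pick the updater with a pattern match before the optimizer chain is built. The standalone wrapper function, its name, and the use of `Option(regType)` to absorb the `null` that Py4J sends for Python's `None` are assumptions, not the reviewer's actual snippet and not code from this PR.

```scala
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater, Updater}

// Illustrative sketch only: pick the updater up front, then build the optimizer chain.
// Option(regType) turns the null that Py4J passes for Python's None into Scala's None,
// so no "none" sentinel string would be strictly required on the Python side.
def configureSVM(
    numIterations: Int,
    stepSize: Double,
    regParam: Double,
    miniBatchFraction: Double,
    regType: String,
    intercept: Boolean): SVMWithSGD = {
  val maybeUpdater: Option[Updater] = Option(regType) match {
    case Some("l2") => Some(new SquaredL2Updater)
    case Some("l1") => Some(new L1Updater)
    case None | Some("none") => None  // keep the algorithm's default updater
    case Some(other) => throw new IllegalArgumentException(
      "Invalid value for 'regType' parameter: " + other
        + ". Can only be initialized using the following string values: [l1, l2, none].")
  }
  val alg = new SVMWithSGD()
  alg.setIntercept(intercept)
  alg.optimizer
    .setNumIterations(numIterations)
    .setRegParam(regParam)
    .setStepSize(stepSize)
    .setMiniBatchFraction(miniBatchFraction)
  maybeUpdater.foreach(u => alg.optimizer.setUpdater(u))
  alg
}
```

Whether this reads better than the merged if/else version is exactly the style question the reviewers deferred to a later cleanup.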