Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ class PythonMLLibAPI extends Serializable {
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
Expand Down Expand Up @@ -341,16 +342,27 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also maybe prefer to do the pattern matching on regType before this, and do something like:

val updater = Option(regType) match { 
...
}
optimizer
  .setUpdater(updater)
  .setNumIterations ...
}

.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You forgot to set miniBatchFraction here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I am fixing it right now!

.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Py4j will pass through Python None as null (at least it should, if I recall), so on the Java side you can wrap that in an Option instead of making it "none".

So you could do:

Option(regType) match {
  case Some("l1") => ...
  case Some("l2") => ...
  case Some(str) => throw Exception ...
  case None => ...
}

SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
SVMAlg.optimizer.setUpdater(new L1Updater)
} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
SVMWithSGD.train(
data,
numIterations,
stepSize,
regParam,
miniBatchFraction,
initialWeights),
SVMAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
Expand All @@ -363,15 +375,28 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
initialWeightsBA: Array[Byte],
regParam: Double,
regType: String,
intercept: Boolean): java.util.List[java.lang.Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
.setStepSize(stepSize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

miniBatchFraction missing here too

.setMiniBatchFraction(miniBatchFraction)
if (regType == "l2") {
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
LogRegAlg.optimizer.setUpdater(new L1Updater)
} else if (regType != "none") {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ " Can only be initialized using the following string values: [l1, l2, none].")
}
trainRegressionModel(
(data, initialWeights) =>
LogisticRegressionWithSGD.train(
data,
numIterations,
stepSize,
miniBatchFraction,
initialWeights),
LogRegAlg.run(data, initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
Expand Down
61 changes: 55 additions & 6 deletions python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,36 @@ def predict(self, x):

class LogisticRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
"""Train a logistic regression model on the given data."""
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=1.0, regType=None, intercept=False):
"""
Train a logistic regression model on the given data.

@param data: The training data.
@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
@param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
@param initialWeights: The initial weights (default: None).
@param regParam: The regularizer parameter (default: 1.0).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
sc = data.context
if regType is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per above comment, you can just pass regType straight through if you then wrap the null in Option on the Scala/Java side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Xiangrui suggested keeping the Scala code as simple as possible and only throwing the IllegalArgumentException from there. I tried pattern matching and creating enumerations; however, the result was complicated, and I ended up adding more classes to both the Scala and the Python code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok fair enough

(@mengxr I wasn't suggesting enumerations, just a pattern match on the Option[String] value as per comment. Don't believe this adds more code or complexity, but no strong feelings either way)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. No strong feelings either way.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either way - this if-else code is now repeated 3 times (twice here, once for regression). Might be worth refactoring into a getUpdater function?
Sent from Mailbox

On Tue, Aug 5, 2014 at 7:48 AM, Michael Giannakopoulos
[email protected] wrote:

     sc = data.context
  •    if regType is None:
    

    Xiangrui suggested keeping the Scala code as simple as possible and only throwing the IllegalArgumentException from there. I tried pattern matching and creating enumerations; however, the result was complicated, and I ended up adding more classes to both the Scala and the Python code.

    Reply to this email directly or view it on GitHub:
    https://github.com/apache/spark/pull/1775/files#r15795902

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok! I will update the code with the suggested approach of MLnick...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@miccagiann I think it is fine to keep it as it is now. I'm merging this for v1.1 QA. We can update the code style later.

regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
d._jrdd, iterations, step, miniBatchFraction, i)
d._jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
initialWeights)

Expand Down Expand Up @@ -115,11 +140,35 @@ def predict(self, x):
class SVMWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a support vector machine on the given data."""
miniBatchFraction=1.0, initialWeights=None, regType=None, intercept=False):
"""
Train a support vector machine on the given data.

@param data: The training data.
@param iterations: The number of iterations (default: 100).
@param step: The step parameter used in SGD
(default: 1.0).
@param regParam: The regularizer parameter (default: 1.0).
@param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
@param initialWeights: The initial weights (default: None).
@param regType: The type of regularizer used for training
our model.
Allowed values: "l1" for using L1Updater,
"l2" for using
SquaredL2Updater,
"none" for no regularizer.
(default: "none")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
sc = data.context
if regType is None:
regType = "none"
train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
d._jrdd, iterations, step, regParam, miniBatchFraction, i)
d._jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)


Expand Down