Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
final val family: Param[String] = new Param(this, "family",
"The name of family which is a description of the label distribution to be used in the " +
s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
ParamValidators.inArray[String](supportedFamilyNames))
(value: String) => supportedFamilyNames.contains(value.toLowerCase(Locale.ROOT)))

/** @group getParam */
@Since("2.1.0")
Expand Down Expand Up @@ -526,7 +526,7 @@ class LogisticRegression @Since("1.2.0") (
case None => histogram.length
}

val isMultinomial = $(family) match {
val isMultinomial = getFamily.toLowerCase(Locale.ROOT) match {
case "binomial" =>
require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
s"outcome classes but found $numClasses.")
Expand Down
30 changes: 15 additions & 15 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
@Since("1.6.0")
final val optimizer = new Param[String](this, "optimizer", "Optimizer or inference" +
" algorithm used to estimate the LDA model. Supported: " + supportedOptimizers.mkString(", "),
(o: String) =>
ParamValidators.inArray(supportedOptimizers).apply(o.toLowerCase(Locale.ROOT)))
(value: String) => supportedOptimizers.contains(value.toLowerCase(Locale.ROOT)))

/** @group getParam */
@Since("1.6.0")
Expand Down Expand Up @@ -325,7 +324,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" ${getDocConcentration.length}, but k = $getK. docConcentration must be an array of" +
s" length either 1 (scalar) or k (num topics).")
}
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getDocConcentration.forall(_ >= 0),
"For Online LDA optimizer, docConcentration values must be >= 0. Found values: " +
Expand All @@ -337,7 +336,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
}
}
if (isSet(topicConcentration)) {
getOptimizer match {
getOptimizer.toLowerCase(Locale.ROOT) match {
case "online" =>
require(getTopicConcentration >= 0, s"For Online LDA optimizer, topicConcentration" +
s" must be >= 0. Found value: $getTopicConcentration")
Expand All @@ -350,17 +349,18 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

/**
 * Builds the spark.mllib (old API) optimizer corresponding to the `optimizer` param.
 *
 * The match is performed on the lower-cased param value so that user-supplied
 * strings such as "Online", "EM", or "oNLinE" resolve correctly; the param
 * validator accepts any letter case, so the stored value may be mixed-case.
 * `Locale.ROOT` is used to make lower-casing independent of the JVM's default
 * locale (e.g. the Turkish dotless-i problem).
 */
private[clustering] def getOldOptimizer: OldLDAOptimizer =
  getOptimizer.toLowerCase(Locale.ROOT) match {
    case "online" =>
      new OldOnlineLDAOptimizer()
        .setTau0($(learningOffset))
        .setKappa($(learningDecay))
        .setMiniBatchFraction($(subsamplingRate))
        .setOptimizeDocConcentration($(optimizeDocConcentration))
    case "em" =>
      new OldEMLDAOptimizer()
        .setKeepLastCheckpoint($(keepLastCheckpoint))
  }
}

private object LDAParams {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2582,6 +2582,17 @@ class LogisticRegressionSuite
assert(expected.coefficients.toArray === actual.coefficients.toArray)
}
}

test("string params should be case-insensitive") {
  // A mixed-case family name must round-trip through the param unchanged
  // (getFamily returns it verbatim) and still be accepted by fit().
  val estimator = new LogisticRegression()
  val cases = Seq(
    "AuTo" -> smallBinaryDataset,
    "biNoMial" -> smallBinaryDataset,
    "mulTinomIAl" -> smallMultinomialDataset)
  for ((familyName, data) <- cases) {
    estimator.setFamily(familyName)
    assert(estimator.getFamily === familyName)
    val fitted = estimator.fit(data)
    assert(fitted.getFamily === familyName)
  }
}
}

object LogisticRegressionSuite {
Expand Down
10 changes: 10 additions & 0 deletions mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -313,4 +313,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead

assert(model.getCheckpointFiles.isEmpty)
}

test("string params should be case-insensitive") {
  // Optimizer names in any letter case must be stored verbatim by the param
  // and still be accepted by fit().
  val lda = new LDA()
  for (optimizerName <- Seq("eM", "oNLinE")) {
    lda.setOptimizer(optimizerName)
    assert(lda.getOptimizer === optimizerName)
    val fitted = lda.fit(dataset)
    assert(fitted.getOptimizer === optimizerName)
  }
}
}