Skip to content

Commit 4aef3aa

Browse files
committed
remove sort add comments
1 parent 41c4c12 commit 4aef3aa

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,18 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
105105
val spark = SparkSession.builder().getOrCreate()
106106
val sc = spark.sparkContext
107107

108+
// collect related params since paramMaps does not necessarily contain the same set of params.
108109
val tuningParamPairs = paramMaps.flatMap(map => map.toSeq)
109110
val tuningParams = tuningParamPairs.map(_.param.asInstanceOf[Param[Any]]).distinct
110-
.sortBy(_.name)
111111
val schema = new StructType(tuningParams.map(p => StructField(p.toString, StringType))
112112
++ Array(StructField(metricName, DoubleType)))
113+
114+
// get param values in paramMap, as well as the default values if not in paramMap.
113115
val rows = paramMaps.zip(metrics).map { case (pMap, metric) =>
114116
val est = $(estimator).copy(pMap)
115117
val values = tuningParams.map { param =>
116118
est match {
119+
// get param value in stages if est is a Pipeline.
117120
case pipeline: Pipeline =>
118121
val candidates = pipeline.getStages.flatMap { stage =>
119122
stage.extractParamMap().get(param)

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class CrossValidatorSuite
8989
.setNumFolds(3)
9090
val cvModel = cv.fit(dataset)
9191
val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
92-
Row.fromSeq(map.toSeq.sortBy(_.param.toString).map(_.value.toString) ++ Seq(metric))
92+
Row.fromSeq(map.toSeq.map(_.value.toString) ++ Seq(metric))
9393
}
9494
assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
9595
assert(cvModel.tuningSummary.columns.last === eval.getMetricName)
@@ -114,7 +114,7 @@ class CrossValidatorSuite
114114
.setNumFolds(3)
115115
val cvModel = cv.fit(dataset)
116116
val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
117-
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric))
117+
Row.fromSeq(map.toSeq.map(_.value.toString) ++ Seq(metric))
118118
}
119119
assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
120120
assert(cvModel.tuningSummary.columns.last === eval.getMetricName)

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class TrainValidationSplitSuite
8888
.setEvaluator(eval)
8989
val tvsModel = tvs.fit(dataset)
9090
val expected = lrParamMaps.zip(tvsModel.validationMetrics).map { case (map, metric) =>
91-
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric))
91+
Row.fromSeq(map.toSeq.map(_.value.toString) ++ Seq(metric))
9292
}
9393
assert(tvsModel.tuningSummary.collect().toSet === expected.toSet)
9494
assert(tvsModel.tuningSummary.columns.last === eval.getMetricName)

0 commit comments

Comments
 (0)