File tree Expand file tree Collapse file tree 3 files changed +7
-4
lines changed
main/scala/org/apache/spark/ml/tuning
test/scala/org/apache/spark/ml/tuning Expand file tree Collapse file tree 3 files changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -105,15 +105,18 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
105105 val spark = SparkSession .builder().getOrCreate()
106106 val sc = spark.sparkContext
107107
108+ // Collect the distinct params across all paramMaps, since individual paramMaps do not necessarily contain the same set of params.
108109 val tuningParamPairs = paramMaps.flatMap(map => map.toSeq)
109110 val tuningParams = tuningParamPairs.map(_.param.asInstanceOf [Param [Any ]]).distinct
110- .sortBy(_.name)
111111 val schema = new StructType (tuningParams.map(p => StructField (p.toString, StringType ))
112112 ++ Array (StructField (metricName, DoubleType )))
113+
114+ // For each paramMap, get the value of every tuning param, using the default value when the paramMap does not set it.
113115 val rows = paramMaps.zip(metrics).map { case (pMap, metric) =>
114116 val est = $(estimator).copy(pMap)
115117 val values = tuningParams.map { param =>
116118 est match {
119+ // If est is a Pipeline, look up the param value in its stages.
117120 case pipeline : Pipeline =>
118121 val candidates = pipeline.getStages.flatMap { stage =>
119122 stage.extractParamMap().get(param)
Original file line number Diff line number Diff line change @@ -89,7 +89,7 @@ class CrossValidatorSuite
8989 .setNumFolds(3 )
9090 val cvModel = cv.fit(dataset)
9191 val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
92- Row .fromSeq(map.toSeq.sortBy(_.param.toString). map(_.value.toString) ++ Seq (metric))
92+ Row .fromSeq(map.toSeq.map(_.value.toString) ++ Seq (metric))
9393 }
9494 assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
9595 assert(cvModel.tuningSummary.columns.last === eval.getMetricName)
@@ -114,7 +114,7 @@ class CrossValidatorSuite
114114 .setNumFolds(3 )
115115 val cvModel = cv.fit(dataset)
116116 val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
117- Row .fromSeq(map.toSeq.sortBy(_.param.name). map(_.value.toString) ++ Seq (metric))
117+ Row .fromSeq(map.toSeq.map(_.value.toString) ++ Seq (metric))
118118 }
119119 assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
120120 assert(cvModel.tuningSummary.columns.last === eval.getMetricName)
Original file line number Diff line number Diff line change @@ -88,7 +88,7 @@ class TrainValidationSplitSuite
8888 .setEvaluator(eval)
8989 val tvsModel = tvs.fit(dataset)
9090 val expected = lrParamMaps.zip(tvsModel.validationMetrics).map { case (map, metric) =>
91- Row .fromSeq(map.toSeq.sortBy(_.param.name). map(_.value.toString) ++ Seq (metric))
91+ Row .fromSeq(map.toSeq.map(_.value.toString) ++ Seq (metric))
9292 }
9393 assert(tvsModel.tuningSummary.collect().toSet === expected.toSet)
9494 assert(tvsModel.tuningSummary.columns.last === eval.getMetricName)
You can’t perform that action at this time.
0 commit comments