Skip to content

Commit 72aea62

Browse files
committed
remove TuningSummary
1 parent b6a7c53 commit 72aea62

File tree

8 files changed

+47
-124
lines changed

8 files changed

+47
-124
lines changed

examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ object ModelSelectionViaCrossValidationExample {
112112
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
113113
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
114114
}
115-
cvModel.summary.trainingMetrics.show()
115+
cvModel.tuningSummary.show()
116116
// $example off$
117117

118118
spark.stop()

examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ object ModelSelectionViaTrainValidationSplitExample {
7474
model.transform(test)
7575
.select("features", "label", "prediction")
7676
.show()
77-
model.summary.trainingMetrics.show()
77+
model.tuningSummary.show()
7878
// $example off$
7979

8080
spark.stop()

mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import com.github.fommil.netlib.F2jBLAS
2525
import org.apache.hadoop.fs.Path
2626
import org.json4s.DefaultFormats
2727

28-
import org.apache.spark.SparkException
2928
import org.apache.spark.annotation.Since
3029
import org.apache.spark.internal.Logging
3130
import org.apache.spark.ml._
@@ -134,10 +133,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
134133
logInfo(s"Best cross-validation metric: $bestMetric.")
135134
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
136135
instr.logSuccess(bestModel)
137-
val model = new CrossValidatorModel(uid, bestModel, metrics).setParent(this)
138-
val summary = new TuningSummary(epm, metrics, bestIndex)
139-
model.setSummary(Some(summary))
140-
copyValues(model)
136+
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
141137
}
142138

143139
@Since("1.4.0")
@@ -233,28 +229,12 @@ class CrossValidatorModel private[ml] (
233229
bestModel.transformSchema(schema)
234230
}
235231

236-
private var trainingSummary: Option[TuningSummary] = None
237-
238-
private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
239-
this.trainingSummary = summary
240-
this
241-
}
242-
243-
/**
244-
* Return true if there exists summary of model.
245-
*/
246-
@Since("2.3.0")
247-
def hasSummary: Boolean = trainingSummary.nonEmpty
248-
249232
/**
250-
* Gets summary of model on training set. An exception is
251-
* thrown if `trainingSummary == None`.
233+
* Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
234+
* paramMap and the corresponding metric of the trained model.
252235
*/
253236
@Since("2.3.0")
254-
def summary: TuningSummary = trainingSummary.getOrElse {
255-
throw new SparkException(
256-
s"No training summary available for the ${this.getClass.getSimpleName}")
257-
}
237+
lazy val tuningSummary: DataFrame = this.getTuningSummaryDF(avgMetrics)
258238

259239
@Since("1.4.0")
260240
override def copy(extra: ParamMap): CrossValidatorModel = {

mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import scala.language.existentials
2525
import org.apache.hadoop.fs.Path
2626
import org.json4s.DefaultFormats
2727

28-
import org.apache.spark.SparkException
2928
import org.apache.spark.annotation.Since
3029
import org.apache.spark.internal.Logging
3130
import org.apache.spark.ml.{Estimator, Model}
@@ -129,10 +128,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
129128
logInfo(s"Best train validation split metric: $bestMetric.")
130129
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
131130
instr.logSuccess(bestModel)
132-
val model = copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
133-
val summary = new TuningSummary(epm, metrics, bestIndex)
134-
model.setSummary(Some(summary))
135-
model
131+
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
136132
}
137133

138134
@Since("1.5.0")
@@ -224,6 +220,13 @@ class TrainValidationSplitModel private[ml] (
224220
bestModel.transformSchema(schema)
225221
}
226222

223+
/**
224+
* Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
225+
* paramMap and the corresponding metric of the trained model.
226+
*/
227+
@Since("2.3.0")
228+
lazy val tuningSummary: DataFrame = this.getTuningSummaryDF(validationMetrics)
229+
227230
@Since("1.5.0")
228231
override def copy(extra: ParamMap): TrainValidationSplitModel = {
229232
val copied = new TrainValidationSplitModel (
@@ -235,29 +238,6 @@ class TrainValidationSplitModel private[ml] (
235238

236239
@Since("2.0.0")
237240
override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
238-
239-
private var trainingSummary: Option[TuningSummary] = None
240-
241-
private[tuning] def setSummary(summary: Option[TuningSummary]): this.type = {
242-
this.trainingSummary = summary
243-
this
244-
}
245-
246-
/**
247-
* Return true if there exists summary of model.
248-
*/
249-
@Since("2.3.0")
250-
def hasSummary: Boolean = trainingSummary.nonEmpty
251-
252-
/**
253-
* Gets summary of model on training set. An exception is
254-
* thrown if `trainingSummary == None`.
255-
*/
256-
@Since("2.3.0")
257-
def summary: TuningSummary = trainingSummary.getOrElse {
258-
throw new SparkException(
259-
s"No training summary available for the ${this.getClass.getSimpleName}")
260-
}
261241
}
262242

263243
@Since("2.0.0")

mllib/src/main/scala/org/apache/spark/ml/tuning/TuningSummary.scala

Lines changed: 0 additions & 58 deletions
This file was deleted.

mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ import org.json4s.jackson.JsonMethods._
2323

2424
import org.apache.spark.SparkContext
2525
import org.apache.spark.ml.{Estimator, Model}
26-
import org.apache.spark.ml.evaluation.Evaluator
26+
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, MulticlassClassificationEvaluator, RegressionEvaluator}
2727
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
2828
import org.apache.spark.ml.param.shared.HasSeed
2929
import org.apache.spark.ml.util._
3030
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
31-
import org.apache.spark.sql.types.StructType
31+
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
32+
import org.apache.spark.sql.types.{StringType, StructField, StructType}
3233

3334
/**
3435
* Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
@@ -85,6 +86,32 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
8586
instrumentation.logNamedValue("evaluator", $(evaluator).getClass.getCanonicalName)
8687
instrumentation.logNamedValue("estimatorParamMapsLength", $(estimatorParamMaps).length)
8788
}
89+
90+
91+
/**
92+
* Summary of grid search tuning in the format of DataFrame. Each row contains one candidate
93+
* paramMap and the corresponding metric of the trained model.
94+
*/
95+
protected def getTuningSummaryDF(metrics: Array[Double]): DataFrame = {
96+
val params = $(estimatorParamMaps)
97+
require(params.nonEmpty, "estimator param maps should not be empty")
98+
require(params.length == metrics.length, "estimator param maps number should match metrics")
99+
val metricName = $(evaluator) match {
100+
case b: BinaryClassificationEvaluator => b.getMetricName
101+
case m: MulticlassClassificationEvaluator => m.getMetricName
102+
case r: RegressionEvaluator => r.getMetricName
103+
case _ => "metrics"
104+
}
105+
val spark = SparkSession.builder().getOrCreate()
106+
val sc = spark.sparkContext
107+
val fields = params(0).toSeq.sortBy(_.param.name).map(_.param.name) ++ Seq(metricName)
108+
val schema = new StructType(fields.map(name => StructField(name, StringType)).toArray)
109+
val rows = sc.parallelize(params.zip(metrics)).map { case (param, metric) =>
110+
val values = param.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString)
111+
Row.fromSeq(values)
112+
}
113+
spark.createDataFrame(rows, schema)
114+
}
88115
}
89116

90117
private[ml] object ValidatorParams {

mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,11 @@ class CrossValidatorSuite
7979
.setEvaluator(eval)
8080
.setNumFolds(3)
8181
val cvModel = cv.fit(dataset)
82-
assert(cvModel.hasSummary)
83-
assert(cvModel.summary.params === lrParamMaps)
84-
assert(cvModel.summary.trainingMetrics.count() === lrParamMaps.length)
85-
8682
val expected = lrParamMaps.zip(cvModel.avgMetrics).map { case (map, metric) =>
8783
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
8884
}
89-
assert(cvModel.summary.trainingMetrics.collect().toSet === expected.toSet)
85+
assert(cvModel.tuningSummary.collect().toSet === expected.toSet)
86+
assert(cvModel.tuningSummary.columns.last === eval.getMetricName)
9087
}
9188

9289
test("cross validation with linear regression") {

mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,11 @@ class TrainValidationSplitSuite
7272
.setEstimatorParamMaps(lrParamMaps)
7373
.setEvaluator(eval)
7474
val tvsModel = tvs.fit(dataset)
75-
assert(tvsModel.hasSummary)
76-
assert(tvsModel.summary.params === lrParamMaps)
77-
assert(tvsModel.summary.trainingMetrics.count() === lrParamMaps.length)
78-
7975
val expected = lrParamMaps.zip(tvsModel.validationMetrics).map { case (map, metric) =>
8076
Row.fromSeq(map.toSeq.sortBy(_.param.name).map(_.value.toString) ++ Seq(metric.toString))
8177
}
82-
assert(tvsModel.summary.trainingMetrics.collect().toSet === expected.toSet)
78+
assert(tvsModel.tuningSummary.collect().toSet === expected.toSet)
79+
assert(tvsModel.tuningSummary.columns.last === eval.getMetricName)
8380
}
8481

8582
test("train validation with linear regression") {

0 commit comments

Comments (0)