@@ -19,32 +19,27 @@ package org.apache.spark.ml.tuning
1919
2020import com .github .fommil .netlib .F2jBLAS
2121import org .apache .hadoop .fs .Path
22- import org .json4s .{DefaultFormats , JObject }
23- import org .json4s .jackson .JsonMethods ._
22+ import org .json4s .DefaultFormats
2423
25- import org .apache .spark .SparkContext
2624import org .apache .spark .annotation .{Experimental , Since }
2725import org .apache .spark .internal .Logging
2826import org .apache .spark .ml ._
29- import org .apache .spark .ml .classification .OneVsRestParams
3027import org .apache .spark .ml .evaluation .Evaluator
31- import org .apache .spark .ml .feature .RFormulaModel
3228import org .apache .spark .ml .param ._
3329import org .apache .spark .ml .param .shared .HasSeed
3430import org .apache .spark .ml .util ._
35- import org .apache .spark .ml .util .DefaultParamsReader .Metadata
3631import org .apache .spark .mllib .util .MLUtils
3732import org .apache .spark .sql .DataFrame
3833import org .apache .spark .sql .types .StructType
3934
40-
4135/**
4236 * Params for [[CrossValidator ]] and [[CrossValidatorModel ]].
4337 */
4438private [ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
4539 /**
4640 * Param for number of folds for cross validation. Must be >= 2.
4741 * Default: 3
42+ *
4843 * @group param
4944 */
5045 val numFolds : IntParam = new IntParam (this , " numFolds" ,
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
163158
164159 private [CrossValidator ] class CrossValidatorWriter (instance : CrossValidator ) extends MLWriter {
165160
166- SharedReadWrite .validateParams(instance)
161+ ValidatorParams .validateParams(instance)
167162
168163 override protected def saveImpl (path : String ): Unit =
169- SharedReadWrite .saveImpl(path, instance, sc)
164+ ValidatorParams .saveImpl(path, instance, sc)
170165 }
171166
172167 private class CrossValidatorReader extends MLReader [CrossValidator ] {
@@ -175,132 +170,18 @@ object CrossValidator extends MLReadable[CrossValidator] {
175170 private val className = classOf [CrossValidator ].getName
176171
177172 override def load (path : String ): CrossValidator = {
178- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
179- SharedReadWrite .load(path, sc, className)
173+ implicit val format = DefaultFormats
174+
175+ val (metadata, estimator, evaluator, estimatorParamMaps) =
176+ ValidatorParams .loadImpl(path, sc, className)
177+ val numFolds = (metadata.params \ " numFolds" ).extract[Int ]
180178 new CrossValidator (metadata.uid)
181179 .setEstimator(estimator)
182180 .setEvaluator(evaluator)
183181 .setEstimatorParamMaps(estimatorParamMaps)
184182 .setNumFolds(numFolds)
185183 }
186184 }
187-
188- private object CrossValidatorReader {
189- /**
190- * Examine the given estimator (which may be a compound estimator) and extract a mapping
191- * from UIDs to corresponding [[Params ]] instances.
192- */
193- def getUidMap (instance : Params ): Map [String , Params ] = {
194- val uidList = getUidMapImpl(instance)
195- val uidMap = uidList.toMap
196- if (uidList.size != uidMap.size) {
197- throw new RuntimeException (" CrossValidator.load found a compound estimator with stages" +
198- s " with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(" , " )}" )
199- }
200- uidMap
201- }
202-
203- def getUidMapImpl (instance : Params ): List [(String , Params )] = {
204- val subStages : Array [Params ] = instance match {
205- case p : Pipeline => p.getStages.asInstanceOf [Array [Params ]]
206- case pm : PipelineModel => pm.stages.asInstanceOf [Array [Params ]]
207- case v : ValidatorParams => Array (v.getEstimator, v.getEvaluator)
208- case ovr : OneVsRestParams =>
209- // TODO: SPARK-11892: This case may require special handling.
210- throw new UnsupportedOperationException (" CrossValidator write will fail because it" +
211- " cannot yet handle an estimator containing type: ${ovr.getClass.getName}" )
212- case rformModel : RFormulaModel => Array (rformModel.pipelineModel)
213- case _ : Params => Array ()
214- }
215- val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List .empty[(String , Params )])(_ ++ _)
216- List ((instance.uid, instance)) ++ subStageMaps
217- }
218- }
219-
220- private [tuning] object SharedReadWrite {
221-
222- /**
223- * Check that [[CrossValidator.evaluator ]] and [[CrossValidator.estimator ]] are Writable.
224- * This does not check [[CrossValidator.estimatorParamMaps ]].
225- */
226- def validateParams (instance : ValidatorParams ): Unit = {
227- def checkElement (elem : Params , name : String ): Unit = elem match {
228- case stage : MLWritable => // good
229- case other =>
230- throw new UnsupportedOperationException (" CrossValidator write will fail " +
231- s " because it contains $name which does not implement Writable. " +
232- s " Non-Writable $name: ${other.uid} of type ${other.getClass}" )
233- }
234- checkElement(instance.getEvaluator, " evaluator" )
235- checkElement(instance.getEstimator, " estimator" )
236- // Check to make sure all Params apply to this estimator. Throw an error if any do not.
237- // Extraneous Params would cause problems when loading the estimatorParamMaps.
238- val uidToInstance : Map [String , Params ] = CrossValidatorReader .getUidMap(instance)
239- instance.getEstimatorParamMaps.foreach { case pMap : ParamMap =>
240- pMap.toSeq.foreach { case ParamPair (p, v) =>
241- require(uidToInstance.contains(p.parent), s " CrossValidator save requires all Params in " +
242- s " estimatorParamMaps to apply to this CrossValidator, its Estimator, or its " +
243- s " Evaluator. An extraneous Param was found: $p" )
244- }
245- }
246- }
247-
248- private [tuning] def saveImpl (
249- path : String ,
250- instance : CrossValidatorParams ,
251- sc : SparkContext ,
252- extraMetadata : Option [JObject ] = None ): Unit = {
253- import org .json4s .JsonDSL ._
254-
255- val estimatorParamMapsJson = compact(render(
256- instance.getEstimatorParamMaps.map { case paramMap =>
257- paramMap.toSeq.map { case ParamPair (p, v) =>
258- Map (" parent" -> p.parent, " name" -> p.name, " value" -> p.jsonEncode(v))
259- }
260- }.toSeq
261- ))
262- val jsonParams = List (
263- " numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
264- " estimatorParamMaps" -> parse(estimatorParamMapsJson)
265- )
266- DefaultParamsWriter .saveMetadata(instance, path, sc, extraMetadata, Some (jsonParams))
267-
268- val evaluatorPath = new Path (path, " evaluator" ).toString
269- instance.getEvaluator.asInstanceOf [MLWritable ].save(evaluatorPath)
270- val estimatorPath = new Path (path, " estimator" ).toString
271- instance.getEstimator.asInstanceOf [MLWritable ].save(estimatorPath)
272- }
273-
274- private [tuning] def load [M <: Model [M ]](
275- path : String ,
276- sc : SparkContext ,
277- expectedClassName : String ): (Metadata , Estimator [M ], Evaluator , Array [ParamMap ], Int ) = {
278-
279- val metadata = DefaultParamsReader .loadMetadata(path, sc, expectedClassName)
280-
281- implicit val format = DefaultFormats
282- val evaluatorPath = new Path (path, " evaluator" ).toString
283- val evaluator = DefaultParamsReader .loadParamsInstance[Evaluator ](evaluatorPath, sc)
284- val estimatorPath = new Path (path, " estimator" ).toString
285- val estimator = DefaultParamsReader .loadParamsInstance[Estimator [M ]](estimatorPath, sc)
286-
287- val uidToParams = Map (evaluator.uid -> evaluator) ++ CrossValidatorReader .getUidMap(estimator)
288-
289- val numFolds = (metadata.params \ " numFolds" ).extract[Int ]
290- val estimatorParamMaps : Array [ParamMap ] =
291- (metadata.params \ " estimatorParamMaps" ).extract[Seq [Seq [Map [String , String ]]]].map {
292- pMap =>
293- val paramPairs = pMap.map { case pInfo : Map [String , String ] =>
294- val est = uidToParams(pInfo(" parent" ))
295- val param = est.getParam(pInfo(" name" ))
296- val value = param.jsonDecode(pInfo(" value" ))
297- param -> value
298- }
299- ParamMap (paramPairs : _* )
300- }.toArray
301- (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
302- }
303- }
304185}
305186
306187/**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
346227@ Since (" 1.6.0" )
347228object CrossValidatorModel extends MLReadable [CrossValidatorModel ] {
348229
349- import CrossValidator .SharedReadWrite
350-
351230 @ Since (" 1.6.0" )
352231 override def read : MLReader [CrossValidatorModel ] = new CrossValidatorModelReader
353232
@@ -357,12 +236,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
357236 private [CrossValidatorModel ]
358237 class CrossValidatorModelWriter (instance : CrossValidatorModel ) extends MLWriter {
359238
360- SharedReadWrite .validateParams(instance)
239+ ValidatorParams .validateParams(instance)
361240
362241 override protected def saveImpl (path : String ): Unit = {
363242 import org .json4s .JsonDSL ._
364243 val extraMetadata = " avgMetrics" -> instance.avgMetrics.toSeq
365- SharedReadWrite .saveImpl(path, instance, sc, Some (extraMetadata))
244+ ValidatorParams .saveImpl(path, instance, sc, Some (extraMetadata))
366245 val bestModelPath = new Path (path, " bestModel" ).toString
367246 instance.bestModel.asInstanceOf [MLWritable ].save(bestModelPath)
368247 }
@@ -376,8 +255,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
376255 override def load (path : String ): CrossValidatorModel = {
377256 implicit val format = DefaultFormats
378257
379- val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
380- SharedReadWrite .load(path, sc, className)
258+ val (metadata, estimator, evaluator, estimatorParamMaps) =
259+ ValidatorParams .loadImpl(path, sc, className)
260+ val numFolds = (metadata.params \ " numFolds" ).extract[Int ]
381261 val bestModelPath = new Path (path, " bestModel" ).toString
382262 val bestModel = DefaultParamsReader .loadParamsInstance[Model [_]](bestModelPath, sc)
383263 val avgMetrics = (metadata.metadata \ " avgMetrics" ).extract[Seq [Double ]].toArray
0 commit comments