
package org.apache.spark.mllib.tree.model

+ import scala.collection.mutable
+
+ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+ import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
import org.apache.spark.mllib.tree.configuration.Algo._
+ import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+ import org.apache.spark.sql.{DataFrame, Row, SQLContext}

/**
 * :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
 * @param algo algorithm type -- classification or regression
 */
@Experimental
- class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {

  /**
   * Predict values for a single data point using the model trained.
@@ -98,4 +104,251 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    header + topNode.subtreeToString(2)
  }

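+   /**
+    * Save this model to the given path. A sketch of what the code below writes:
+    * JSON metadata (class name, format version, algo, node count) under
+    * Loader.metadataPath(path), and one Parquet row per tree node under
+    * Loader.dataPath(path). See SaveLoadV1_0.save for the details.
+    */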
+   override def save(sc: SparkContext, path: String): Unit = {
+     DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+   }
+
+   override protected def formatVersion: String = "1.0"
+ }
+
+ object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+   /**
+    * Iterator which does a DFS traversal (left to right) of a decision tree.
+    *
+    * Note: This is private[mllib] to permit unit tests.
+    */
+   private[mllib] class NodeIterator(model: DecisionTreeModel) extends Iterator[Node] {
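+     // Example: for a root with left child A and right child B, the traversal
+     // visits root, then A and A's subtree, then B and B's subtree (pre-order).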
+
+     /**
+      * LIFO stack of Nodes during our DFS.
+      * The top Node is returned by next().
+      * Any Node on the stack is either a leaf or has children that we have not yet visited.
+      * This is empty once all Nodes have been traversed.
+      */
+     val nodeTrace: mutable.Stack[Node] = new mutable.Stack[Node]()
+
+     nodeTrace.push(model.topNode)
+
+     override def hasNext: Boolean = nodeTrace.nonEmpty
+
+     /**
+      * Produces the next element of this iterator.
+      * If [[hasNext]] is false, then this throws an exception.
+      */
+     override def next(): Node = {
+       if (nodeTrace.isEmpty) {
+         throw new Exception(
+           "DecisionTreeModel.NodeIterator.next() was called, but no more elements remain.")
+       }
+       val n = nodeTrace.pop()
+       if (!n.isLeaf) {
+         // n is a parent: push right before left so the left child is popped (visited) first.
+         nodeTrace.push(n.rightNode.get, n.leftNode.get)
+       }
+       n
+     }
+   }
+
+   private object SaveLoadV1_0 {
+
+     def thisFormatVersion = "1.0"
+
+     def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
+     private case class PredictData(predict: Double, prob: Double)
+
+     private object PredictData {
+       def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+     }
+
+     private case class InformationGainStatsData(
+         gain: Double,
+         impurity: Double,
+         leftImpurity: Double,
+         rightImpurity: Double,
+         leftPredict: PredictData,
+         rightPredict: PredictData)
+
+     private object InformationGainStatsData {
+       def apply(i: InformationGainStats): InformationGainStatsData = {
+         InformationGainStatsData(i.gain, i.impurity, i.leftImpurity, i.rightImpurity,
+           PredictData(i.leftPredict), PredictData(i.rightPredict))
+       }
+     }
+
+     private case class SplitData(
+         feature: Int,
+         threshold: Double,
+         featureType: Int,
+         categories: Seq[Double])  // TODO: Change to List once SPARK-3365 is fixed
+
+     private object SplitData {
+       def apply(s: Split): SplitData = {
+         SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+       }
+     }
+
+     /** Model data for model import/export */
+     private case class NodeData(
+         id: Int,
+         predict: PredictData,
+         impurity: Double,
+         isLeaf: Boolean,
+         split: Option[SplitData],
+         leftNodeId: Option[Int],
+         rightNodeId: Option[Int],
+         stats: Option[InformationGainStatsData])
+
+     private object NodeData {
+       def apply(n: Node): NodeData = {
+         NodeData(n.id, PredictData(n.predict), n.impurity, n.isLeaf, n.split.map(SplitData.apply),
+           n.leftNode.map(_.id), n.rightNode.map(_.id), n.stats.map(InformationGainStatsData.apply))
+       }
+     }
+
+     def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+       val sqlContext = new SQLContext(sc)
+       import sqlContext.implicits._
+
+       // Create JSON metadata.
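+       // The metadata is a single row (class, version, algo, numNodes);
+       // repartition(1) below keeps it in a single output file.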
+       val metadataRDD =
+         sc.parallelize(Seq((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)))
+           .toDataFrame("class", "version", "algo", "numNodes")
+       metadataRDD.toJSON.repartition(1).saveAsTextFile(Loader.metadataPath(path))
+
+       // Create Parquet data.
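+       // Each tree node becomes one NodeData row (child Node references flattened
+       // to ids), written in the NodeIterator's DFS order.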
+       val nodeIterator = new DecisionTreeModel.NodeIterator(model)
+       val dataRDD: DataFrame = sc.parallelize(nodeIterator.toSeq).map(NodeData.apply).toDataFrame()
+       dataRDD.saveAsParquetFile(Loader.dataPath(path))
+     }
+
+     def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+       val datapath = Loader.dataPath(path)
+       val sqlContext = new SQLContext(sc)
+       // Load Parquet data.
+       val dataRDD = sqlContext.parquetFile(datapath)
+       // Check schema explicitly since erasure makes it hard to use match-case for checking.
+       Loader.checkSchema[NodeData](dataRDD.schema)
+       // TODO: Extract save/load for 1 tree so that it can be reused for ensembles?
+       val splitsRDD: RDD[Option[Split]] =
+         dataRDD.select("split.feature", "split.threshold", "split.featureType", "split.categories")
+           .map { row: Row =>
+             if (row.isNullAt(0)) {
+               None
+             } else {
+               row match {
+                 case Row(feature: Int, threshold: Double, featureType: Int, categories: Seq[_]) =>
+                   // Note: The type cast for categories is safe since we checked the schema.
+                   Some(Split(feature, threshold, FeatureType(featureType),
+                     categories.asInstanceOf[Seq[Double]].toList))
+               }
+             }
+           }
+       val lrChildNodesRDD: RDD[Option[(Int, Int)]] =
+         dataRDD.select("leftNodeId", "rightNodeId").map { row: Row =>
+           if (row.isNullAt(0)) {
+             None
+           } else {
+             row match {
+               case Row(leftNodeId: Int, rightNodeId: Int) =>
+                 Some((leftNodeId, rightNodeId))
+             }
+           }
+         }
+       val gainStatsRDD: RDD[Option[InformationGainStats]] = dataRDD.select(
+         "stats.gain", "stats.impurity", "stats.leftImpurity", "stats.rightImpurity",
+         "stats.leftPredict.predict", "stats.leftPredict.prob",
+         "stats.rightPredict.predict", "stats.rightPredict.prob").map { row: Row =>
+         if (row.isNullAt(0)) {
+           None
+         } else {
+           row match {
+             case Row(gain: Double, impurity: Double, leftImpurity: Double, rightImpurity: Double,
+                 leftPredictPredict: Double, leftPredictProb: Double,
+                 rightPredictPredict: Double, rightPredictProb: Double) =>
+               Some(new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
+                 new Predict(leftPredictPredict, leftPredictProb),
+                 new Predict(rightPredictPredict, rightPredictProb)))
+           }
+         }
+       }
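+       // Note: the zips below assume that each select() above returns rows in the
+       // same order, so the per-column RDDs line up node by node.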
+       // nodesRDD stores (Node, leftChildId, rightChildId), where the child ids are
+       // only relevant if Node.isLeaf == false.
+       val nodesRDD: RDD[(Node, Int, Int)] =
+         dataRDD.select("id", "predict.predict", "predict.prob", "impurity", "isLeaf").rdd
+           .zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
+             case (((Row(id: Int, predictPredict: Double, predictProb: Double,
+                 impurity: Double, isLeaf: Boolean),
+                 split: Option[Split]), lrChildNodes: Option[(Int, Int)]),
+                 gainStats: Option[InformationGainStats]) =>
+               val (leftChildId, rightChildId) = lrChildNodes.getOrElse((-1, -1))
+               (new Node(id, new Predict(predictPredict, predictProb), impurity, isLeaf,
+                 split, None, None, gainStats),
+                 leftChildId, rightChildId)
+           }
+       // Collect tree nodes, and build them into a tree.
+       // nodesMap: node id -> (node, leftChildId, rightChildId)
+       val nodesMap: Map[Int, (Node, Int, Int)] = nodesRDD.collect().map(n => n._1.id -> n).toMap
+       assert(nodesMap.contains(1),
+         s"DecisionTree missing root node (id = 1) after loading from: $datapath")
+       val topNode = nodesMap(1)
+       linkSubtree(topNode._1, topNode._2, topNode._3, nodesMap)
+       assert(nodesMap.size == numNodes,
+         s"Unable to load DecisionTreeModel data from: $datapath. " +
+         s"Expected $numNodes nodes but found ${nodesMap.size}")
+       new DecisionTreeModel(topNode._1, Algo.fromString(algo))
+     }
+   }
+
+   /**
+    * Link the given node to its children (if any), and recurse down the subtree.
+    * @param node Node to link. Node.leftNode and Node.rightNode will be set if there are children.
+    * @param leftChildId Id of left child. Ignored if node is a leaf.
+    * @param rightChildId Id of right child. Ignored if node is a leaf.
+    * @param nodesMap Map storing all nodes as a map: node id -> (Node, leftChildId, rightChildId).
+    */
+   private def linkSubtree(
+       node: Node,
+       leftChildId: Int,
+       rightChildId: Int,
+       nodesMap: Map[Int, (Node, Int, Int)]): Unit = {
+     if (node.isLeaf) return
+     assert(nodesMap.contains(leftChildId),
+       s"DecisionTreeModel.load could not find child (id=$leftChildId) of node ${node.id}.")
+     assert(nodesMap.contains(rightChildId),
+       s"DecisionTreeModel.load could not find child (id=$rightChildId) of node ${node.id}.")
+     val leftChild = nodesMap(leftChildId)
+     val rightChild = nodesMap(rightChildId)
+     node.leftNode = Some(leftChild._1)
+     node.rightNode = Some(rightChild._1)
+     linkSubtree(leftChild._1, leftChild._2, leftChild._3, nodesMap)
+     linkSubtree(rightChild._1, rightChild._2, rightChild._3, nodesMap)
+   }
+
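+   /**
+    * Load a model from the given path (as written by [[save]]).
+    *
+    * A usage sketch, assuming `sc: SparkContext` and a `path` previously passed to
+    * save(); `features: Vector` here is a hypothetical input point:
+    * {{{
+    *   val model = DecisionTreeModel.load(sc, path)
+    *   val prediction = model.predict(features)
+    * }}}
+    */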
+   override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+     val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+     val (algo: String, numNodes: Int) = try {
+       val algo_numNodes = metadata.select("algo", "numNodes").collect()
+       assert(algo_numNodes.length == 1)
+       algo_numNodes(0) match {
+         case Row(a: String, n: Int) => (a, n)
+       }
+     } catch {
+       // Catch both Error and Exception since the checks above can throw either.
+       case e: Throwable =>
+         throw new Exception(
+           s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}." +
+           s"  Error message: ${e.getMessage}")
+     }
+     val classNameV1_0 = SaveLoadV1_0.thisClassName
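+     // Version dispatch: match on (className, format version) so that additional
+     // SaveLoadVX_Y loaders could be handled here in the future.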
+     (loadedClassName, version) match {
+       case (className, "1.0") if className == classNameV1_0 =>
+         SaveLoadV1_0.load(sc, path, algo, numNodes)
+       case _ => throw new Exception(
+         s"DecisionTreeModel.load did not recognize model with (className, format version): " +
+         s"($loadedClassName, $version).  Supported:\n" +
+         s"  ($classNameV1_0, 1.0)")
+     }
+   }
}