1717
1818package org .apache .spark .mllib .tree .model
1919
20+ import scala .collection .mutable
21+
22+ import org .apache .spark .SparkContext
2023import org .apache .spark .annotation .Experimental
2124import org .apache .spark .api .java .JavaRDD
2225import org .apache .spark .mllib .linalg .Vector
26+ import org .apache .spark .mllib .tree .configuration .{Algo , FeatureType }
2327import org .apache .spark .mllib .tree .configuration .Algo ._
28+ import org .apache .spark .mllib .util .{Loader , Saveable }
2429import org .apache .spark .rdd .RDD
30+ import org .apache .spark .sql .{DataFrame , Row , SQLContext }
2531
2632/**
2733 * :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
3137 * @param algo algorithm type -- classification or regression
3238 */
3339@ Experimental
34- class DecisionTreeModel (val topNode : Node , val algo : Algo ) extends Serializable {
40+ class DecisionTreeModel (val topNode : Node , val algo : Algo ) extends Serializable with Saveable {
3541
3642 /**
3743 * Predict values for a single data point using the model trained.
@@ -98,4 +104,193 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
98104 header + topNode.subtreeToString(2 )
99105 }
100106
107+ override def save (sc : SparkContext , path : String ): Unit = {
108+ DecisionTreeModel .SaveLoadV1_0 .save(sc, path, this )
109+ }
110+
111+ override protected def formatVersion : String = " 1.0"
112+ }
113+
114+ object DecisionTreeModel extends Loader [DecisionTreeModel ] {
115+
116+ private [tree] object SaveLoadV1_0 {
117+
118+ def thisFormatVersion = " 1.0"
119+
120+ // Hard-code class name string in case it changes in the future
121+ def thisClassName = " org.apache.spark.mllib.tree.DecisionTreeModel"
122+
123+ case class PredictData (predict : Double , prob : Double ) {
124+ def toPredict : Predict = new Predict (predict, prob)
125+ }
126+
127+ object PredictData {
128+ def apply (p : Predict ): PredictData = PredictData (p.predict, p.prob)
129+
130+ def apply (r : Row ): PredictData = PredictData (r.getDouble(0 ), r.getDouble(1 ))
131+ }
132+
133+ case class SplitData (
134+ feature : Int ,
135+ threshold : Double ,
136+ featureType : Int ,
137+ categories : Seq [Double ]) { // TODO: Change to List once SPARK-3365 is fixed
138+ def toSplit : Split = {
139+ new Split (feature, threshold, FeatureType (featureType), categories.toList)
140+ }
141+ }
142+
143+ object SplitData {
144+ def apply (s : Split ): SplitData = {
145+ SplitData (s.feature, s.threshold, s.featureType.id, s.categories)
146+ }
147+
148+ def apply (r : Row ): SplitData = {
149+ SplitData (r.getInt(0 ), r.getDouble(1 ), r.getInt(2 ), r.getAs[Seq [Double ]](3 ))
150+ }
151+ }
152+
153+ /** Model data for model import/export */
154+ case class NodeData (
155+ treeId : Int ,
156+ nodeId : Int ,
157+ predict : PredictData ,
158+ impurity : Double ,
159+ isLeaf : Boolean ,
160+ split : Option [SplitData ],
161+ leftNodeId : Option [Int ],
162+ rightNodeId : Option [Int ],
163+ infoGain : Option [Double ])
164+
165+ object NodeData {
166+ def apply (treeId : Int , n : Node ): NodeData = {
167+ NodeData (treeId, n.id, PredictData (n.predict), n.impurity, n.isLeaf,
168+ n.split.map(SplitData .apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
169+ n.stats.map(_.gain))
170+ }
171+
172+ def apply (r : Row ): NodeData = {
173+ val split = if (r.isNullAt(5 )) None else Some (SplitData (r.getStruct(5 )))
174+ val leftNodeId = if (r.isNullAt(6 )) None else Some (r.getInt(6 ))
175+ val rightNodeId = if (r.isNullAt(7 )) None else Some (r.getInt(7 ))
176+ val infoGain = if (r.isNullAt(8 )) None else Some (r.getDouble(8 ))
177+ NodeData (r.getInt(0 ), r.getInt(1 ), PredictData (r.getStruct(2 )), r.getDouble(3 ),
178+ r.getBoolean(4 ), split, leftNodeId, rightNodeId, infoGain)
179+ }
180+ }
181+
182+ def save (sc : SparkContext , path : String , model : DecisionTreeModel ): Unit = {
183+ val sqlContext = new SQLContext (sc)
184+ import sqlContext .implicits ._
185+
186+ // Create JSON metadata.
187+ val metadataRDD = sc.parallelize(
188+ Seq ((thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)), 1 )
189+ .toDataFrame(" class" , " version" , " algo" , " numNodes" )
190+ metadataRDD.toJSON.saveAsTextFile(Loader .metadataPath(path))
191+
192+ // Create Parquet data.
193+ val nodes = model.topNode.subtreeIterator.toSeq
194+ val dataRDD : DataFrame = sc.parallelize(nodes)
195+ .map(NodeData .apply(0 , _))
196+ .toDataFrame
197+ dataRDD.saveAsParquetFile(Loader .dataPath(path))
198+ }
199+
200+ def load (sc : SparkContext , path : String , algo : String , numNodes : Int ): DecisionTreeModel = {
201+ val datapath = Loader .dataPath(path)
202+ val sqlContext = new SQLContext (sc)
203+ // Load Parquet data.
204+ val dataRDD = sqlContext.parquetFile(datapath)
205+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
206+ Loader .checkSchema[NodeData ](dataRDD.schema)
207+ val nodes = dataRDD.map(NodeData .apply)
208+ // Build node data into a tree.
209+ val trees = constructTrees(nodes)
210+ assert(trees.size == 1 ,
211+ " Decision tree should contain exactly one tree but got ${trees.size} trees." )
212+ val model = new DecisionTreeModel (trees(0 ), Algo .fromString(algo))
213+ assert(model.numNodes == numNodes, s " Unable to load DecisionTreeModel data from: $datapath. " +
214+ s " Expected $numNodes nodes but found ${model.numNodes}" )
215+ model
216+ }
217+
218+ def constructTrees (nodes : RDD [NodeData ]): Array [Node ] = {
219+ val trees = nodes
220+ .groupBy(_.treeId)
221+ .mapValues(_.toArray)
222+ .collect()
223+ .map { case (treeId, data) =>
224+ (treeId, constructTree(data))
225+ }.sortBy(_._1)
226+ val numTrees = trees.size
227+ val treeIndices = trees.map(_._1).toSeq
228+ assert(treeIndices == (0 until numTrees),
229+ s " Tree indices must start from 0 and increment by 1, but we found $treeIndices. " )
230+ trees.map(_._2)
231+ }
232+
233+ /**
234+ * Given a list of nodes from a tree, construct the tree.
235+ * @param data array of all node data in a tree.
236+ */
237+ def constructTree (data : Array [NodeData ]): Node = {
238+ val dataMap : Map [Int , NodeData ] = data.map(n => n.nodeId -> n).toMap
239+ assert(dataMap.contains(1 ),
240+ s " DecisionTree missing root node (id = 1). " )
241+ constructNode(1 , dataMap, mutable.Map .empty)
242+ }
243+
244+ /**
245+ * Builds a node from the node data map and adds new nodes to the input nodes map.
246+ */
247+ private def constructNode (
248+ id : Int ,
249+ dataMap : Map [Int , NodeData ],
250+ nodes : mutable.Map [Int , Node ]): Node = {
251+ if (nodes.contains(id)) {
252+ return nodes(id)
253+ }
254+ val data = dataMap(id)
255+ val node =
256+ if (data.isLeaf) {
257+ Node (data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
258+ } else {
259+ val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
260+ val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
261+ val stats = new InformationGainStats (data.infoGain.get, data.impurity, leftNode.impurity,
262+ rightNode.impurity, leftNode.predict, rightNode.predict)
263+ new Node (data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
264+ data.split.map(_.toSplit), Some (leftNode), Some (rightNode), Some (stats))
265+ }
266+ nodes += node.id -> node
267+ node
268+ }
269+ }
270+
271+ override def load (sc : SparkContext , path : String ): DecisionTreeModel = {
272+ val (loadedClassName, version, metadata) = Loader .loadMetadata(sc, path)
273+ val (algo : String , numNodes : Int ) = try {
274+ val algo_numNodes = metadata.select(" algo" , " numNodes" ).collect()
275+ assert(algo_numNodes.length == 1 )
276+ algo_numNodes(0 ) match {
277+ case Row (a : String , n : Int ) => (a, n)
278+ }
279+ } catch {
280+ // Catch both Error and Exception since the checks above can throw either.
281+ case e : Throwable =>
282+ throw new Exception (
283+ s " Unable to load DecisionTreeModel metadata from: ${Loader .metadataPath(path)}. "
284+ + s " Error message: ${e.getMessage}" )
285+ }
286+ val classNameV1_0 = SaveLoadV1_0 .thisClassName
287+ (loadedClassName, version) match {
288+ case (className, " 1.0" ) if className == classNameV1_0 =>
289+ SaveLoadV1_0 .load(sc, path, algo, numNodes)
290+ case _ => throw new Exception (
291+ s " DecisionTreeModel.load did not recognize model with (className, format version): " +
292+ s " ( $loadedClassName, $version). Supported: \n " +
293+ s " ( $classNameV1_0, 1.0) " )
294+ }
295+ }
101296}
0 commit comments