Skip to content

Commit dcdbf85

Browse files
committed
added save/load for decision tree but need to generalize it to ensembles
1 parent c4b1108 commit dcdbf85

File tree

4 files changed

+368
-4
lines changed

4 files changed

+368
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 254 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,17 @@
1717

1818
package org.apache.spark.mllib.tree.model
1919

20+
import scala.collection.mutable
21+
22+
import org.apache.spark.SparkContext
2023
import org.apache.spark.annotation.Experimental
2124
import org.apache.spark.api.java.JavaRDD
2225
import org.apache.spark.mllib.linalg.Vector
26+
import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
2327
import org.apache.spark.mllib.tree.configuration.Algo._
28+
import org.apache.spark.mllib.util.{Loader, Saveable}
2429
import org.apache.spark.rdd.RDD
30+
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
2531

2632
/**
2733
* :: Experimental ::
@@ -31,7 +37,7 @@ import org.apache.spark.rdd.RDD
3137
* @param algo algorithm type -- classification or regression
3238
*/
3339
@Experimental
34-
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
40+
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
3541

3642
/**
3743
* Predict values for a single data point using the model trained.
@@ -98,4 +104,251 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
98104
header + topNode.subtreeToString(2)
99105
}
100106

107+
/**
 * Persist this model by delegating to the version 1.0 writer in the companion object.
 *
 * @param sc Spark context handed to the writer
 * @param path location handed to the writer
 */
override def save(sc: SparkContext, path: String): Unit =
  DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
110+
111+
// Format version string reported by this model: "1.0".
override protected def formatVersion: String = "1.0"
112+
}
113+
114+
object DecisionTreeModel extends Loader[DecisionTreeModel] {
115+
116+
/**
 * Iterator which does a DFS traversal (left to right) of a decision tree.
 *
 * Note: This is private[mllib] to permit unit tests.
 *
 * @param model tree whose nodes are traversed, starting from `model.topNode`
 */
private[mllib] class NodeIterator(model: DecisionTreeModel) extends Iterator[Node] {

  /**
   * LIFO stack of Nodes during our DFS.
   * The top Node is returned by next().
   * Any Node on the stack is either a leaf or has children whom we have not yet visited.
   * This is empty once all Nodes have been traversed.
   */
  val nodeTrace: mutable.Stack[Node] = new mutable.Stack[Node]()

  nodeTrace.push(model.topNode)

  override def hasNext: Boolean = nodeTrace.nonEmpty

  /**
   * Produces the next element of this iterator.
   *
   * @return the node on top of the stack; its children (if any) are pushed for later visits
   * @throws NoSuchElementException if [[hasNext]] is false
   */
  override def next(): Node = {
    if (nodeTrace.isEmpty) {
      // NoSuchElementException is what the Iterator contract prescribes for an exhausted
      // iterator; it is still an Exception, so existing handlers keep working.
      throw new NoSuchElementException(
        "DecisionTreeModel.NodeIterator.next() was called, but no more elements remain.")
    }
    val n = nodeTrace.pop()
    if (!n.isLeaf) {
      // n is a parent.
      // push(a, b) pushes a first, so the left child ends up on top and is visited first.
      nodeTrace.push(n.rightNode.get, n.leftNode.get)
    }
    n
  }
}
152+
153+
private object SaveLoadV1_0 {
154+
155+
/** Format version string that `save` writes into the metadata. */
def thisFormatVersion: String = "1.0"

/** Class name string that `save` writes into the metadata. */
def thisClassName: String = "org.apache.spark.mllib.tree.DecisionTreeModel"
158+
159+
/** Storage-side mirror of [[Predict]], holding its `predict` and `prob` fields. */
private case class PredictData(predict: Double, prob: Double)

private object PredictData {
  /** Copies the `predict` and `prob` fields of `p` into a [[PredictData]]. */
  def apply(p: Predict): PredictData = new PredictData(predict = p.predict, prob = p.prob)
}
164+
165+
/** Storage-side mirror of [[InformationGainStats]], with both predictions as [[PredictData]]. */
private case class InformationGainStatsData(
    gain: Double,
    impurity: Double,
    leftImpurity: Double,
    rightImpurity: Double,
    leftPredict: PredictData,
    rightPredict: PredictData)

private object InformationGainStatsData {
  /** Copies every field of `i`, converting both predictions through [[PredictData]]. */
  def apply(i: InformationGainStats): InformationGainStatsData = {
    val left = PredictData(i.leftPredict)
    val right = PredictData(i.rightPredict)
    new InformationGainStatsData(
      gain = i.gain,
      impurity = i.impurity,
      leftImpurity = i.leftImpurity,
      rightImpurity = i.rightImpurity,
      leftPredict = left,
      rightPredict = right)
  }
}
179+
180+
/** Storage-side mirror of [[Split]]; the feature type is stored as its integer `id`. */
private case class SplitData(
    feature: Int,
    threshold: Double,
    featureType: Int,
    categories: Seq[Double]) // TODO: Change to List once SPARK-3365 is fixed

private object SplitData {
  /** Copies the fields of `s`, storing `s.featureType` as its integer `id`. */
  def apply(s: Split): SplitData =
    new SplitData(
      feature = s.feature,
      threshold = s.threshold,
      featureType = s.featureType.id,
      categories = s.categories)
}
191+
192+
/** Model data for model import/export: one record per tree node, children referenced by id. */
private case class NodeData(
    id: Int,
    predict: PredictData,
    impurity: Double,
    isLeaf: Boolean,
    split: Option[SplitData],
    leftNodeId: Option[Int],
    rightNodeId: Option[Int],
    stats: Option[InformationGainStatsData])

private object NodeData {
  /** Flattens `n` into a [[NodeData]], replacing its child nodes by their ids. */
  def apply(n: Node): NodeData = {
    val splitData = n.split.map(s => SplitData(s))
    val leftId = n.leftNode.map(child => child.id)
    val rightId = n.rightNode.map(child => child.id)
    val statsData = n.stats.map(g => InformationGainStatsData(g))
    new NodeData(n.id, PredictData(n.predict), n.impurity, n.isLeaf,
      splitData, leftId, rightId, statsData)
  }
}
209+
210+
/**
 * Writes `model` under `path`: a one-row JSON metadata file plus one Parquet record per node.
 *
 * @param sc Spark context used to build the RDDs and the SQLContext
 * @param path base path; metadata and data locations are derived from it via `Loader`
 * @param model tree to persist
 */
def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  // Create JSON metadata: a single row, written out as a single partition.
  val metadataRow = (thisClassName, thisFormatVersion, model.algo.toString, model.numNodes)
  val metadataDF = sc.parallelize(Seq(metadataRow))
    .toDataFrame("class", "version", "algo", "numNodes")
  metadataDF.toJSON.repartition(1).saveAsTextFile(Loader.metadataPath(path))

  // Create Parquet data: every node, in NodeIterator order, converted to NodeData.
  val treeNodes = new DecisionTreeModel.NodeIterator(model).toSeq
  val nodeDF: DataFrame = sc.parallelize(treeNodes).map(NodeData.apply)
  nodeDF.saveAsParquetFile(Loader.dataPath(path))
}
225+
226+
/**
 * Reads the Parquet node records under `path` and rebuilds the tree from them.
 *
 * @param sc Spark context used to create the SQLContext
 * @param path base path; the data location is derived via `Loader.dataPath`
 * @param algo algorithm name, converted with `Algo.fromString`
 * @param numNodes number of nodes the loaded data is asserted to contain
 * @return the reconstructed model, rooted at the node with id 1
 */
def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
  val datapath = Loader.dataPath(path)
  val sqlContext = new SQLContext(sc)
  // Load Parquet data.
  val dataRDD = sqlContext.parquetFile(datapath)
  // Check schema explicitly since erasure makes it hard to use match-case for checking.
  Loader.checkSchema[NodeData](dataRDD.schema)
  // TODO: Extract save/load for 1 tree so that it can be reused for ensembles?

  // Each projection below yields one element per stored record; a null first column is
  // mapped to None.
  val splitsRDD: RDD[Option[Split]] = dataRDD
    .select("split.feature", "split.threshold", "split.featureType", "split.categories")
    .map { row: Row =>
      row match {
        case absent if absent.isNullAt(0) => None
        case Row(feature: Int, threshold: Double, featureType: Int, categories: Seq[_]) =>
          // Note: The type cast for categories is safe since we checked the schema.
          val categoryList = categories.asInstanceOf[Seq[Double]].toList
          Some(Split(feature, threshold, FeatureType(featureType), categoryList))
      }
    }
  val lrChildNodesRDD: RDD[Option[(Int, Int)]] = dataRDD
    .select("leftNodeId", "rightNodeId")
    .map { row: Row =>
      row match {
        case absent if absent.isNullAt(0) => None
        case Row(leftNodeId: Int, rightNodeId: Int) => Some((leftNodeId, rightNodeId))
      }
    }
  val gainStatsRDD: RDD[Option[InformationGainStats]] = dataRDD
    .select(
      "stats.gain", "stats.impurity", "stats.leftImpurity", "stats.rightImpurity",
      "stats.leftPredict.predict", "stats.leftPredict.prob",
      "stats.rightPredict.predict", "stats.rightPredict.prob")
    .map { row: Row =>
      row match {
        case absent if absent.isNullAt(0) => None
        case Row(gain: Double, impurity: Double, leftImpurity: Double, rightImpurity: Double,
            leftPredictPredict: Double, leftPredictProb: Double,
            rightPredictPredict: Double, rightPredictProb: Double) =>
          val leftPredict = new Predict(leftPredictPredict, leftPredictProb)
          val rightPredict = new Predict(rightPredictPredict, rightPredictProb)
          Some(new InformationGainStats(
            gain, impurity, leftImpurity, rightImpurity, leftPredict, rightPredict))
      }
    }
  // nodesRDD stores (Node, leftChildId, rightChildId) where the child ids are only relevant if
  // Node.isLeaf == false
  val coreFieldsRDD =
    dataRDD.select("id", "predict.predict", "predict.prob", "impurity", "isLeaf").rdd
  val nodesRDD: RDD[(Node, Int, Int)] =
    coreFieldsRDD.zip(splitsRDD).zip(lrChildNodesRDD).zip(gainStatsRDD).map {
      case (((Row(id: Int, predictPredict: Double, predictProb: Double,
          impurity: Double, isLeaf: Boolean),
          split: Option[Split]), lrChildNodes: Option[(Int, Int)]),
          gainStats: Option[InformationGainStats]) =>
        // Children are attached later by linkSubtree, so both child slots start as None.
        val node = new Node(id, new Predict(predictPredict, predictProb), impurity, isLeaf,
          split, None, None, gainStats)
        lrChildNodes match {
          case Some((leftChildId, rightChildId)) => (node, leftChildId, rightChildId)
          case None => (node, -1, -1)
        }
    }
  // Collect tree nodes, and build them into a tree.
  // nodesMap: node id -> (node, leftChild, rightChild)
  val nodesMap: Map[Int, (Node, Int, Int)] =
    nodesRDD.collect().map(entry => entry._1.id -> entry).toMap
  assert(nodesMap.contains(1),
    s"DecisionTree missing root node (id = 1) after loading from: $datapath")
  val (rootNode, rootLeftId, rootRightId) = nodesMap(1)
  linkSubtree(rootNode, rootLeftId, rootRightId, nodesMap)
  assert(nodesMap.size == numNodes,
    s"Unable to load DecisionTreeModel data from: $datapath." +
    s" Expected $numNodes nodes but found ${nodesMap.size}")
  new DecisionTreeModel(rootNode, Algo.fromString(algo))
}
302+
}
303+
304+
/**
 * Attach the children of the given node (if it is not a leaf), then do the same for each
 * child, walking down the whole subtree.
 * @param node Node to link. Node.leftNode and Node.rightNode are set when node is not a leaf.
 * @param leftChildId Id of left child. Ignored if node is a leaf.
 * @param rightChildId Id of right child. Ignored if node is a leaf.
 * @param nodesMap All nodes keyed by id: node id -> (Node, leftChildId, rightChildId).
 */
private def linkSubtree(
    node: Node,
    leftChildId: Int,
    rightChildId: Int,
    nodesMap: Map[Int, (Node, Int, Int)]): Unit = {
  if (!node.isLeaf) {
    assert(nodesMap.contains(leftChildId),
      s"DecisionTreeModel.load could not find child (id=$leftChildId) of node ${node.id}.")
    assert(nodesMap.contains(rightChildId),
      s"DecisionTreeModel.load could not find child (id=$rightChildId) of node ${node.id}.")
    val (leftNode, leftOfLeft, rightOfLeft) = nodesMap(leftChildId)
    val (rightNode, leftOfRight, rightOfRight) = nodesMap(rightChildId)
    node.leftNode = Some(leftNode)
    node.rightNode = Some(rightNode)
    linkSubtree(leftNode, leftOfLeft, rightOfLeft, nodesMap)
    linkSubtree(rightNode, leftOfRight, rightOfRight, nodesMap)
  }
}
328+
329+
/**
 * Load a model from `path`, dispatching on the class name and format version found in the
 * stored metadata.
 *
 * @param sc Spark context used to read the metadata and the model data
 * @param path base path the model was saved under
 * @return the loaded model
 * @throws Exception if the metadata cannot be read (the underlying failure is attached as the
 *                   cause) or if the (class name, version) pair is not recognized
 */
override def load(sc: SparkContext, path: String): DecisionTreeModel = {
  val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
  val (algo: String, numNodes: Int) = try {
    val algoAndNumNodes = metadata.select("algo", "numNodes").collect()
    assert(algoAndNumNodes.length == 1)
    algoAndNumNodes(0) match {
      case Row(a: String, n: Int) => (a, n)
    }
  } catch {
    // NonFatal covers both the AssertionError and the Exceptions the checks above can throw,
    // while letting fatal JVM errors (e.g. OutOfMemoryError) propagate untouched.
    case scala.util.control.NonFatal(e) =>
      // Attach the original failure as the cause so its stack trace is not lost.
      throw new Exception(
        s"Unable to load DecisionTreeModel metadata from: ${Loader.metadataPath(path)}."
          + s" Error message: ${e.getMessage}", e)
  }
  val classNameV1_0 = SaveLoadV1_0.thisClassName
  (loadedClassName, version) match {
    case (className, "1.0") if className == classNameV1_0 =>
      SaveLoadV1_0.load(sc, path, algo, numNodes)
    case _ => throw new Exception(
      s"DecisionTreeModel.load did not recognize model with (className, format version):" +
      s"($loadedClassName, $version).  Supported:\n" +
      s"  ($classNameV1_0, 1.0)")
  }
}
101354
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class InformationGainStats(
4949
gain == other.gain &&
5050
impurity == other.impurity &&
5151
leftImpurity == other.leftImpurity &&
52-
rightImpurity == other.rightImpurity
52+
rightImpurity == other.rightImpurity &&
53+
leftPredict == other.leftPredict &&
54+
rightPredict == other.rightPredict
5355
}
5456
case _ => false
5557
}

mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,11 @@ class Predict(
3232
override def toString = {
3333
"predict = %f, prob = %f".format(predict, prob)
3434
}
35+
36+
/**
 * Two predictions are equal when both `predict` and `prob` are equal.
 */
override def equals(other: Any): Boolean = {
  other match {
    case p: Predict => predict == p.predict && prob == p.prob
    case _ => false
  }
}

/**
 * Hash code consistent with [[equals]]: derived from the same two fields, so that equal
 * predictions behave correctly as keys of hash-based collections.
 */
override def hashCode(): Int = (predict, prob).hashCode()
3542
}

0 commit comments

Comments
 (0)