@@ -37,7 +37,6 @@ class DecisionTree(val strategy : Strategy) {
3737 val (splits, bins) = DecisionTree .find_splits_bins(input, strategy)
3838
3939 // TODO: Level-wise training of tree and obtain Decision Tree model
40-
4140 val maxDepth = strategy.maxDepth
4241
4342 val maxNumNodes = scala.math.pow(2 ,maxDepth).toInt - 1
@@ -55,8 +54,20 @@ class DecisionTree(val strategy : Strategy) {
5554
5655}
5756
58- object DecisionTree extends Logging {
57+ object DecisionTree extends Serializable {
58+
59+ /*
60+ Returns an Array[Split] of optimal splits for all nodes at a given level
61+
62+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
63+ @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
64+ @param level Level of the tree
65+ @param filters Filter for all nodes at a given level
66+ @param splits possible splits for all features
67+ @param bins possible bins for all features
5968
69+ @return Array[Split] instance for best splits for all nodes at a given level.
70+ */
6071 def findBestSplits (
6172 input : RDD [LabeledPoint ],
6273 strategy : Strategy ,
@@ -65,6 +76,16 @@ object DecisionTree extends Logging {
6576 splits : Array [Array [Split ]],
6677 bins : Array [Array [Bin ]]) : Array [Split ] = {
6778
79+ // TODO: Move these calculations outside
80+ val numNodes = scala.math.pow(2 , level).toInt
81+ println(" numNodes = " + numNodes)
82+ // Find the number of features by looking at the first sample
83+ val numFeatures = input.take(1 )(0 ).features.length
84+ println(" numFeatures = " + numFeatures)
85+ val numSplits = strategy.numSplits
86+ println(" numSplits = " + numSplits)
87+
88+ /* Find the filters applied before reaching the current node. */
6889 def findParentFilters (nodeIndex : Int ): List [Filter ] = {
6990 if (level == 0 ) {
7091 List [Filter ]()
@@ -75,6 +96,10 @@ object DecisionTree extends Logging {
7596 }
7697 }
7798
99+ /* Find whether the sample is valid input for the current node.
100+
101+ In other words, does it pass through all the filters for the current node.
102+ */
78103 def isSampleValid (parentFilters : List [Filter ], labeledPoint : LabeledPoint ): Boolean = {
79104
80105 for (filter <- parentFilters) {
@@ -91,79 +116,130 @@ object DecisionTree extends Logging {
91116 true
92117 }
93118
/*
 * Finds the bin index for the given feature of a sample by scanning the
 * precomputed bins for that feature. A sample falls into bin i when
 * lowSplit.threshold <= featureValue < highSplit.threshold.
 *
 * @param featureIndex index of the feature whose bin is sought
 * @param labeledPoint the sample being binned
 * @return index of the matching bin
 * @throws UnknownError if no bin contains the feature value
 */
def findBin(featureIndex: Int, labeledPoint: LabeledPoint): Int = {
  // TODO: Replace the linear scan with binary search over the sorted thresholds.
  // Hoist the feature value lookup out of the loop — it is loop-invariant.
  val featureValue = labeledPoint.features(featureIndex)
  val binIndex = (0 until strategy.numSplits).indexWhere { i =>
    val bin = bins(featureIndex)(i)
    // Use short-circuit && (the original non-short-circuit & is a Scala smell).
    bin.lowSplit.threshold <= featureValue && bin.highSplit.threshold > featureValue
  }
  if (binIndex == -1) {
    throw new UnknownError("no bin was found.")
  }
  binIndex
}
112- def findBinsForLevel : Array [Double ] = {
113135
114- val numNodes = scala.math.pow(2 , level).toInt
115- // Find the number of features by looking at the first sample
116- val numFeatures = input.take(1 )(0 ).features.length
/*
 * Finds bin indices for one sample, for all nodes (and all features) at the
 * current level.
 *
 * Storage layout for k features and l nodes:
 *   label, b11, b12, ..., b1k, b21, ..., b2k, ..., bl1, ..., blk
 * A sample that does not reach a node (rejected by the node's parent filters)
 * is flagged by storing -1 in that node's FIRST feature slot; downstream
 * aggregation (binSeqOp) inspects only that first slot to decide validity,
 * so marking one slot is sufficient.
 *
 * @param labeledPoint the sample to bin
 * @return array of size 1 + numFeatures * numNodes as described above
 */
def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
  // One slot for the label, then one bin index per (node, feature) pair.
  val arr = new Array[Double](1 + (numFeatures * numNodes))
  arr(0) = labeledPoint.label
  for (nodeIndex <- 0 until numNodes) {
    val parentFilters = findParentFilters(nodeIndex)
    // Find out whether the sample qualifies for this particular node.
    val sampleValid = isSampleValid(parentFilters, labeledPoint)
    val shift = 1 + numFeatures * nodeIndex
    if (!sampleValid) {
      // Marking only the first feature slot is sufficient (resolves the old
      // TODO): the aggregation step checks only this slot for validity.
      arr(shift) = -1
    } else {
      for (featureIndex <- 0 until numFeatures) {
        arr(shift + featureIndex) = findBin(featureIndex, labeledPoint)
      }
    }
  }
  arr
}
142168
143- val binMappedRDD = input.map(labeledPoint => findBinsForLevel)
/*
 * Performs a sequential aggregation over a partition.
 *
 * For every node at this level, if the sample reached the node (first feature
 * slot != -1), increments the count of the (node, feature, bin) cell the
 * sample falls into.
 *
 * @param agg aggregate accumulator of size numSplits*numFeatures*numNodes for
 *            classification (3x that for regression)
 * @param arr per-sample bin array of size 1 + (numFeatures*numNodes),
 *            produced by findBinsForLevel
 * @return the updated accumulator (mutated in place)
 */
def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
  for (node <- 0 until numNodes) {
    // First feature slot doubles as the validity signal for this node.
    val arrShift = 1 + numFeatures * node
    if (arr(arrShift) != -1) {
      // Hoisted out of the feature loop: aggShift is loop-invariant.
      val aggShift = numSplits * numFeatures * node
      for (feature <- 0 until numFeatures) {
        val aggIndex = aggShift + feature * numSplits + arr(arrShift + feature).toInt
        agg(aggIndex) += 1
      }
    }
  }
  agg
}
194+
/*
 * Combines two partition-level aggregates by element-wise addition.
 *
 * BUG FIX: the previous implementation returned par1 unchanged, silently
 * discarding par2 — every partition's counts except one were lost in the
 * RDD aggregate. Both arrays have identical layout and length
 * (numSplits*numFeatures*numNodes), so summing slot-by-slot is correct.
 *
 * @param par1 first partial aggregate (mutated in place and returned)
 * @param par2 second partial aggregate
 * @return par1 after accumulating par2 into it
 */
def binCombOp(par1: Array[Double], par2: Array[Double]): Array[Double] = {
  for (i <- par1.indices) {
    par1(i) += par2(i)
  }
  par1
}
198+
199+ println(" input = " + input.count)
200+ val binMappedRDD = input.map(x => findBinsForLevel(x))
201+ println(" binMappedRDD.count = " + binMappedRDD.count)
144202 // calculate bin aggregates
203+
204+ val binAggregates = binMappedRDD.aggregate(Array .fill[Double ](numSplits* numFeatures* numNodes)(0 ))(binSeqOp,binCombOp)
205+
145206 // find best split
207+ println(" binAggregates.length = " + binAggregates.length)
146208
147209
148- Array [Split ]()
210+ val bestSplits = new Array [Split ](numNodes)
211+ for (node <- 0 until numNodes){
212+ val binsForNode = binAggregates.slice(node,numSplits* node)
213+ }
214+
215+ bestSplits
149216 }
150217
218+ /*
219+ Returns split and bins for decision tree calculation.
220+
221+ @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data for DecisionTree
222+ @param strategy [[org.apache.spark.mllib.tree.Strategy]] instance containing parameters for constructing the DecisionTree
223+ @return a tuple of (splits,bins) where Split is an Array[Array[Split]] of size (numFeatures,numSplits-1) and bins is an
224+ Array[Array[Bin]] of size (numFeatures,numSplits+1)
225+ */
151226 def find_splits_bins (input : RDD [LabeledPoint ], strategy : Strategy ) : (Array [Array [Split ]], Array [Array [Bin ]]) = {
152227
153228 val numSplits = strategy.numSplits
154- logDebug (" numSplits = " + numSplits)
229+ println (" numSplits = " + numSplits)
155230
156231 // Calculate the number of sample for approximate quantile calculation
157232 // TODO: Justify this calculation
158233 val requiredSamples = numSplits* numSplits
159234 val count = input.count()
160235 val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
161- logDebug (" fraction of data used for calculating quantiles = " + fraction)
236+ println (" fraction of data used for calculating quantiles = " + fraction)
162237
163238 // sampled input for RDD calculation
164239 val sampledInput = input.sample(false , fraction, 42 ).collect()
165240 val numSamples = sampledInput.length
166241
242+ // TODO: Remove this requirement
167243 require(numSamples > numSplits, " length of input samples should be greater than numSplits" )
168244
169245 // Find the number of features by looking at the first sample
0 commit comments