From 2f2ac8d635f63f3d54252c94097c7621dd3319bc Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Tue, 20 Oct 2015 17:40:00 -0400 Subject: [PATCH 1/8] removed partitionInfosDebug --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index a9094310da8c..172c82ad09e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.tree.impl -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.tree._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.tree.impl.TreeUtil._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -144,15 +142,13 @@ private[ml] object AltDT extends Logging { } // Group columns together into one array of columns per partition. // TODO: Test avoiding this grouping, and see if it matters. - val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator => - val groupedCols = new ArrayBuffer[FeatureVector] - iterator.foreach(groupedCols += _) - if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator() + val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator: Iterator[FeatureVector] => + if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() } groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) // Initialize partitions with 1 node (each instance at the root node). - var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols => + var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols => val initActive = new BitSet(1) initActive.set(0) new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive) @@ -165,16 +161,10 @@ private[ml] object AltDT extends Logging { var activeNodePeriphery: Array[LearningNode] = Array(rootNode) var numNodeOffsets: Int = 2 - val partitionInfosDebug = new scala.collection.mutable.ArrayBuffer[RDD[PartitionInfo]]() - partitionInfosDebug.append(partitionInfosA) - // Iteratively learn, one level of the tree at a time. var currentLevel = 0 var doneLearning = false while (currentLevel < strategy.maxDepth && !doneLearning) { - - val partitionInfos = partitionInfosDebug.last - // Compute best split for each active node. val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] = computeBestSplits(partitionInfos, labelsBc, metadata) @@ -208,12 +198,11 @@ private[ml] object AltDT extends Logging { // Broadcast aggregated bit vectors. On each partition, update instance--node map. val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors) - // partitionInfos = partitionInfos.map { partitionInfo => - val partitionInfosB = partitionInfos.map { partitionInfo => + val newPartitionInfos = partitionInfos.map { partitionInfo => partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets) } - partitionInfosB.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... - partitionInfosDebug.append(partitionInfosB) + newPartitionInfos.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... + partitionInfos = newPartitionInfos // TODO: unpersist aggBitVectorsBc after action. } From 158fb0073d23b80dfd73a17e85bae350cdb56e00 Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Fri, 30 Oct 2015 13:17:57 -0400 Subject: [PATCH 2/8] Several changes: 1) PartitionInfo.update has been rewritten so that there are fewer zipWithIndex operations and sortBys. We manually sort by bit, but we still use standard library functions to sort by value, then index. This is a further improvement that needs to be made. 2) Instead of using bitSubvectors to encode which instances split left or right, we now use a single bitVector to encode this information. This reduces some of the overhead in the communication costs, since this is broadcasted to all workers. This change was made to make it easier to look up the corresponding bit for each instance -- we now use the original index of the instance to look up the its bit in the bit vector. 3) We introduce a second bitVector -- called nodeSplitBitVector -- to encode whether the node was split or not. Previously we would determine this by examining the number of instances that split left for a given bitSubvector. Since we no longer use bitSubvector, we needed an alternate way of encoding this information. This increases the overhead of communcation between the master and the workers (now, we have to broadcast two bitVectors instead of one), but this should still be less than the communication cost we had previously. --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 201 +++++++++++------- .../spark/ml/tree/impl/BitSubvector.scala | 81 ------- .../spark/ml/tree/impl/AltDTSuite.scala | 57 ++--- .../ml/tree/impl/BitSubvectorSuite.scala | 69 ------ 4 files changed, 159 insertions(+), 249 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index 172c82ad09e9..d186aef7ace3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -192,19 +192,28 @@ private[ml] object AltDT extends Logging { doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0 if (!doneLearning) { + val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1) + // construct bit vector encoding which active nodes found a split + val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { (acc: BitSet, splitAndIdx: (Option[Split], Int)) => + if (splitAndIdx._1.isDefined) + acc.set(splitAndIdx._2) + acc + } + // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. - val aggBitVectors: Array[BitSubvector] = - collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1)) + val aggBitVector: BitSet = aggregateBitVector(partitionInfos, splits, numRows) // Broadcast aggregated bit vectors. On each partition, update instance--node map. - val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors) + val aggBitVectorBc = input.sparkContext.broadcast(aggBitVector) + val nodeSplitBitVectorBc = input.sparkContext.broadcast(nodeSplitBitVector) val newPartitionInfos = partitionInfos.map { partitionInfo => - partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets) + partitionInfo.update(aggBitVectorBc.value, nodeSplitBitVectorBc.value, numNodeOffsets) } newPartitionInfos.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... partitionInfos = newPartitionInfos - // TODO: unpersist aggBitVectorsBc after action. + aggBitVectorBc.unpersist() + nodeSplitBitVectorBc.unpersist() } currentLevel += 1 @@ -321,25 +330,27 @@ private[ml] object AltDT extends Logging { * @param bestSplits Split for each active node, or None if that node will not be split * @return Array of bit vectors, ordered by offset ranges */ - private[impl] def collectBitVectors( + private[impl] def aggregateBitVector( partitionInfos: RDD[PartitionInfo], - bestSplits: Array[Option[Split]]): Array[BitSubvector] = { + bestSplits: Array[Option[Split]], + numRows: Int): BitSet = { + val bestSplitsBc: Broadcast[Array[Option[Split]]] = partitionInfos.sparkContext.broadcast(bestSplits) - val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map { + val workerBitSubvectors: RDD[BitSet] = partitionInfos.map { case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => val localBestSplits: Array[Option[Split]] = bestSplitsBc.value // localFeatureIndex[feature index] = index into PartitionInfo.columns val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap - activeNodes.iterator.zip(localBestSplits.iterator).flatMap { + val bitSetForNodes: Iterator[BitSet] = activeNodes.iterator.zip(localBestSplits.iterator).flatMap { case (nodeIndexInLevel: Int, Some(split: Split)) => if (localFeatureIndex.contains(split.featureIndex)) { // This partition has the column (feature) used for this split. val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val colIndex: Int = localFeatureIndex(split.featureIndex) - Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split)) + Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows)) } else { Iterator() } @@ -347,11 +358,19 @@ private[ml] object AltDT extends Logging { // Do not create a BitSubvector when there is no split. // This requires PartitionInfo.update to handle missing BitSubvectors. Iterator() - }.toArray + } + if (bitSetForNodes.isEmpty) + new BitSet(0) + else + bitSetForNodes.reduce[BitSet] { (acc: BitSet, bitv: BitSet) => + acc | bitv + } + } + val aggBitVector: BitSet = workerBitSubvectors.reduce { (acc: BitSet, bitv: BitSet) => + acc | bitv } - val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge) bestSplitsBc.unpersist() - aggBitVectors + aggBitVector } /** @@ -653,19 +672,23 @@ private[ml] object AltDT extends Logging { * second by sorted row indices within the node's rows. * bit[index in sorted array of row indices] = false for left, true for right */ - private[impl] def bitSubvectorFromSplit( + private[impl] def bitVectorFromSplit( col: FeatureVector, fromOffset: Int, toOffset: Int, - split: Split): BitSubvector = { - val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset).toArray - val nodeRowValues = col.values.view.slice(fromOffset, toOffset).toArray - val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2) - val bitv = new BitSubvector(fromOffset, toOffset) - nodeRowValuesSortedByIndices.zipWithIndex.foreach { case (value, i) => + split: Split, + numRows: Int): BitSet = { + val nodeRowIndices = col.indices.slice(fromOffset, toOffset) + val nodeRowValues = col.values.slice(fromOffset, toOffset) + val bitv = new BitSet(numRows) + var i = 0 + while (i < nodeRowValues.length) { + val value = nodeRowValues(i) + val idx = nodeRowIndices(i) if (!split.shouldGoLeft(value)) { - bitv.set(fromOffset + i) + bitv.set(idx) } + i += 1 } bitv } @@ -717,69 +740,105 @@ private[ml] object AltDT extends Logging { * Update nodeOffsets, activeNodes: * Split offsets for nodes which split (which can be identified using the bit vector). * - * @param bitVectors Bit vectors encoding splits for the next level of the tree. + * @param instanceBitVector Bit vector encoding splits for the next level of the tree. * These must follow a 2-level ordering, where the first level is by node * and the second level is by row index. * bitVector(i) = false iff instance i goes to the left child. * For instances at inactive (leaf) nodes, the value can be arbitrary. - * When an active node is not split (e.g., because no good split was found), - * then the corresponding BitSubvector can be missing. + * @param nodeSplitBitVector Bit vector encoding whether an active node was split or not * @return Updated partition info */ - def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = { - val newColumns = columns.map { oldCol => - val col = oldCol.deepCopy() - var curBitVecIdx = 0 + def update(instanceBitVector: BitSet, nodeSplitBitVector: BitSet, newNumNodeOffsets: Int): PartitionInfo = { + // Create a 2-level representation of the new nodeOffsets (to be flattened). + // These 2 levels correspond to original nodes and their children (if split). + val newNodeOffsets = nodeOffsets.map(Array(_)) + + val newColumns = columns.map { col => activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - // TODO: Allow missing vectors when no split is chosen. - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) // If the current BitVector does not cover this node, then this node was not split, // so we do not need to update its part of the column. Otherwise, we update it. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Sort range [from, to) based on indices. This is required to match the bit vector - // across all workers. See [[bitSubvectorFromSplit]] for details. - val rangeIndices = col.indices.view.slice(from, to).toArray - val rangeValues = col.values.view.slice(from, to).toArray - val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1) - // Sort range [from, to) based on bit vector. - sortedRange.zipWithIndex.map { case ((idx, value), i) => - val bit = curBitVector.get(from + i) - // TODO: In-place merge, rather than general sort. - // TODO: We don't actually need to sort the categorical features using our approach. - (bit, value, idx) - }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) => - col.values(from + i) = value - col.indices(from + i) = idx - } - } - } - col - } + if (nodeSplitBitVector.get(nodeIdx)) { + val from = nodeOffsets(nodeIdx) + val to = nodeOffsets(nodeIdx + 1) + // Sort range [from, to) based on split, then value. This is required to match + // the bit vector across all workers. See [[bitVectorFromSplit]] for details. + // Within [from, to), we will have all "left child" instances (those that are false), + // then all "right child" instances. Then, within each child, we sort by value, so + // we can compute the best split for the next iteration. The corresponding index for + // an instance is used to look up the split value ("left" or "right") in the + // instanceBitVector, which is ordered by index. + val rangeIndices = col.indices.slice(from, to) + val rangeValues = col.values.slice(from, to) + + // BEGIN SORTING + var start = 0 + var end = rangeValues.length - 1 + // if this is the very first time we split + // we don't have to use the indices to figure + // out which bits are turned on + val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.cardinality + else rangeIndices.count(instanceBitVector.get) + val numBitsNotSet = to - from - numBitsSet - // Create a 2-level representation of the new nodeOffsets (to be flattened). - // These 2 levels correspond to original nodes and their children (if split). - val newNodeOffsets = nodeOffsets.map(Array(_)) - var curBitVecIdx = 0 - activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) - // If the current BitVector does not cover this node, then this node was not split, - // so we do not need to create a new node offset. Otherwise, we create an offset. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Count number of values splitting to left vs. right - val numRight = Range(from, to).count(curBitVector.get) - val numLeft = to - from - numRight - if (numLeft != 0 && numRight != 0) { - // node is split val oldOffset = newNodeOffsets(nodeIdx).head - newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft) + // numBitsNotSet == number of instances going to the left + // which is how big the offset should be + newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + + // first we move all of the values and indices that have + // zero-bits to the front + while (start < numBitsNotSet && start <= end) { + // "start <= end" isn't really necessary, but we + // include it anyways + val startBit = instanceBitVector.get(rangeIndices(start)) + val endBit = instanceBitVector.get(rangeIndices(end)) + // if the startBit is false, we increment and move on + if (!startBit) { + start += 1 + } + // if endBit is true, we decrement and move on + if (endBit) { + end -= 1 + // if startBit is true and endBit is false, we swap + } else if (startBit && !endBit) { + // swap both in rangeValues and in rangeIndices + // (this should be a separate helper function, + // but we want to avoid function calls) + val tempVal = rangeValues(start) + rangeValues(start) = rangeValues(end) + rangeValues(end) = tempVal + val tempIdx = rangeIndices(start) + rangeIndices(start) = rangeIndices(end) + rangeIndices(end) = tempIdx + // update both start and end + start += 1 + end -= 1 + } + } + // Now, we sort the sub-arrays from [0, numBitsNotSet) and [numBitsNotSet, rangeValues.length) + // TODO: implement our own sorting, so that we don't have to unnecessarily construct + // intermediate objects to sort + val leftValsAndIndices = rangeValues.slice(0, numBitsNotSet).zip(rangeIndices.slice(0, numBitsNotSet)).sorted + val rightValsAndIndices = rangeValues.slice(numBitsNotSet, rangeValues.length).zip(rangeIndices.slice(numBitsNotSet, rangeValues.length)).sorted + + val (sortedLeftRangeValues, sortedLeftRangeIndices) = leftValsAndIndices.unzip + val (sortedRightRangeValues, sortedRightRangeIndices) = rightValsAndIndices.unzip + + val sortedRangeValues = sortedLeftRangeValues.iterator ++ sortedRightRangeValues.iterator + val sortedRangeIndices = sortedLeftRangeIndices.iterator ++ sortedRightRangeIndices.iterator + // END SORTING + + // update the column values and indices + // with the corresponding indices + var i = 0 + while (i < rangeValues.length) { + col.values(from + i) = sortedRangeValues.next() + col.indices(from + i) = sortedRangeIndices.next() + i += 1 + } } } + col } assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala deleted file mode 100644 index fe9b171cfd1b..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.util.collection.BitSet - - -private[impl] class BitSubvector(val from: Int, val to: Int) extends Serializable { - - val numBits: Int = to - from - - /** Element i will be put at location i + offset in the BitSet */ - private val offset: Int = 64 - (numBits % 64) - - private val bits: BitSet = new BitSet(numBits + offset) - - def set(bit: Int): Unit = bits.set(bit + offset - from) - - def get(bit: Int): Boolean = bits.get(bit + offset - from) - - /** Get an iterator over the set bits. */ - def iterator: Iterator[Int] = new Iterator[Int] { - val iter = bits.iterator - override def hasNext: Boolean = iter.hasNext - override def next(): Int = iter.next() - offset + from - } -} - -private[impl] object BitSubvector { - - def merge(parts1: Array[BitSubvector], parts2: Array[BitSubvector]): Array[BitSubvector] = { - // Merge sorted parts1, parts2 - val sortedSubvectors = (parts1 ++ parts2).sortBy(_.from) - if (sortedSubvectors.nonEmpty) { - // Merge adjacent PartialBitVectors (for adjacent node ranges) - val newSubvectorRanges: Array[(Int, Int)] = { - val newSubvRanges = ArrayBuffer.empty[(Int, Int)] - var i = 1 - var currentFrom = sortedSubvectors.head.from - while (i < sortedSubvectors.length) { - if (sortedSubvectors(i - 1).to != sortedSubvectors(i).from) { - newSubvRanges.append((currentFrom, sortedSubvectors(i - 1).to)) - currentFrom = sortedSubvectors(i).from - } - i += 1 - } - newSubvRanges.append((currentFrom, sortedSubvectors.last.to)) - newSubvRanges.toArray - } - val newSubvectors = newSubvectorRanges.map { case (from, to) => new BitSubvector(from, to) } - var curNewSubvIdx = 0 - sortedSubvectors.foreach { subv => - if (subv.to > newSubvectors(curNewSubvIdx).to) curNewSubvIdx += 1 - val newSubv = newSubvectors(curNewSubvIdx) - // TODO: More efficient (word-level) copy. - subv.iterator.foreach(idx => newSubv.set(idx)) - } - assert(curNewSubvIdx + 1 == newSubvectors.length) // sanity check - newSubvectors - } else { - Array.empty[BitSubvector] - } - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala index 860c8ea78760..ceda0b39dc7e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.DecisionTreeRegressor import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo} -import org.apache.spark.ml.tree.impl.TreeUtil._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.ImpurityStats @@ -100,11 +99,12 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { // Create bitVector for splitting the 4 rows: L, R, L, R // New groups are {0, 2}, {1, 3} - val bitVector = new BitSubvector(0, numRows) + val bitVector = new BitSet(numRows) bitVector.set(1) bitVector.set(3) - val newInfo = info.update(Array(bitVector), newNumNodeOffsets = 3) + // for these tests, use the activeNodes for nodeSplitBitVector + val newInfo = info.update(bitVector, activeNodes, newNumNodeOffsets = 3) assert(newInfo.columns.length === 2) val expectedCol1a = @@ -117,12 +117,16 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(newInfo.activeNodes.iterator.toSet === Set(0, 1)) // Create 2 bitVectors for splitting into: 0, 2, 1, 3 - val bv2a = new BitSubvector(0, 2) - bv2a.set(1) - val bv2b = new BitSubvector(2, 4) - bv2b.set(3) + val bitVector2 = new BitSet(numRows) + bitVector2.set(2) // 2 goes to the right + bitVector2.set(3) // 3 goes to the right - val newInfo2 = newInfo.update(Array(bv2a, bv2b), newNumNodeOffsets = 5) + // both nodes find a split + val nodeSplits = new BitSet(2) + nodeSplits.set(0) + nodeSplits.set(1) + + val newInfo2 = newInfo.update(bitVector2, nodeSplits, newNumNodeOffsets = 5) assert(newInfo2.columns.length === 2) val expectedCol2a = @@ -286,31 +290,34 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7)) val fromOffset = 0 val toOffset = col.values.length + val numRows = col.values.length val split = new ContinuousSplit(0, threshold = 0.5) - val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) assert(bitv.iterator.toSet === Set(3, 4)) } test("bitSubvectorFromSplit: 2 nodes") { // Initially, 1 split: (0, 2, 4) | (1, 3) + // (0.4, 0.2, 0.1) | (0.6, 0.7) + // 0, 1, 2, 3, 4 + // 0.4, 0.6, 0.2, 0.7, 0.1 val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7), Array(4, 2, 0, 1, 3)) def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, expectedRight: Set[Int]): Unit = { val split = new ContinuousSplit(0, threshold) - val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) + val numRows = col.values.length + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) assert(bitv.iterator.toSet === expectedRight) } + // Left child node - checkSplit(0, 3, 0.15, Set(0, 1)) + checkSplit(0, 3, 0.05, Set(0, 2, 4)) + checkSplit(0, 3, 0.15, Set(0, 2)) checkSplit(0, 3, 0.2, Set(0)) checkSplit(0, 3, 0.5, Set()) // Right child node - checkSplit(3, 5, 0.1, Set(3, 4)) - checkSplit(3, 5, 0.65, Set(4)) + checkSplit(3, 5, 0.1, Set(1, 3)) + checkSplit(3, 5, 0.65, Set(3)) checkSplit(3, 5, 0.8, Set()) } @@ -323,11 +330,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = 0.5) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(3, 4)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.iterator.toSet === Set(3, 4)) } test("collectBitVectors with 1 vector, with tied threshold") { @@ -339,11 +343,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = -2.0) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(0, 1, 4, 5)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.iterator.toSet === Set(0, 1, 4, 5)) } //////////////////////////////// Active nodes ////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala deleted file mode 100644 index 6e7b67623851..000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext - -/** - * Test suite for [[BitSubvector]]. - */ -class BitSubvectorSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("BitSubvector basic ops") { - val from = 1 - val to = 4 - val bs = new BitSubvector(from, to) - assert(bs.numBits === to - from) - Range(from, to).foreach(x => assert(!bs.get(x))) - val setVals = Array(from, to - 1) - setVals.foreach { x => - bs.set(x) - assert(bs.get(x)) - } - assert(bs.iterator.toSet === setVals.toSet) - } - - test("BitSubvector merge") { - val b1 = new BitSubvector(0, 5) - b1.set(1) - val b2 = new BitSubvector(5, 7) - b2.set(5) - val b3 = new BitSubvector(9, 12) - b3.set(11) - val parts1 = Array(b1) - val parts2 = Array(b2, b3) - val newParts = BitSubvector.merge(parts1, parts2) - - val r1 = new BitSubvector(0, 7) - r1.set(1) - r1.set(5) - val r2 = new BitSubvector(9, 12) - r2.set(11) - val expectedParts = Array(r1, r2) - newParts.zip(expectedParts).foreach { case (x, y) => - assert(x.from === y.from) - assert(x.to === x.to) - assert(x.iterator.toSet === y.iterator.toSet) - } - } - - test("BitSubvector merge with empty BitSubvectors") { - val parts = BitSubvector.merge(Array.empty[BitSubvector], Array.empty[BitSubvector]) - } -} From 60f28f6b883fa802da6917fe227e467c9ab005ca Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Fri, 30 Oct 2015 18:46:51 -0400 Subject: [PATCH 3/8] Additional optimizations, fairly minor. Removed foreach calls, replaced them with while loops. Next optimization should be to replace zip + sort with our own custom sort --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index d186aef7ace3..6be4adfc3f33 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -375,7 +375,6 @@ private[ml] object AltDT extends Logging { /** * Choose the best split for a feature at a node. - * * TODO: Return null or None when the split is invalid, such as putting all instances on one * child node. * @@ -433,10 +432,15 @@ private[ml] object AltDT extends Logging { // TODO: Support high-arity features by using a single array to hold the stats. // aggStats(category) = label statistics for category - val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( + val aggStats: Array[ImpurityAggregatorSingle] = Array.tabulate[ImpurityAggregatorSingle](featureArity)( _ => metadata.createImpurityAggregator()) - values.zip(labels).foreach { case (cat, label) => + var i = 0 + val len = values.length + while (i < len) { + val cat = values(i) + val label = labels(i) aggStats(cat.toInt).update(label) + i += 1 } // Compute centroids. centroidsForCategories is a list: (category, centroid) @@ -484,9 +488,14 @@ private[ml] object AltDT extends Logging { val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1) // Cumulative sums of bin statistics for left, right parts of split. - val leftImpurityAgg = metadata.createImpurityAggregator() - val rightImpurityAgg = metadata.createImpurityAggregator() - aggStats.foreach(rightImpurityAgg.add) + val leftImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + val rightImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + var j = 0 + val length = aggStats.length + while (j < length) { + rightImpurityAgg.add(aggStats(j)) + j += 1 + } var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -558,7 +567,12 @@ private[ml] object AltDT extends Logging { val leftImpurityAgg = metadata.createImpurityAggregator() val rightImpurityAgg = metadata.createImpurityAggregator() - labels.foreach(rightImpurityAgg.update(_, 1.0)) + var i = 0 + val len = labels.length + while (i < len) { + rightImpurityAgg.update(labels(i), 1.0) + i += 1 + } var bestThreshold: Double = Double.NegativeInfinity val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -568,7 +582,11 @@ private[ml] object AltDT extends Logging { var rightCount: Double = rightImpurityAgg.getCount val fullCount: Double = rightCount var currentThreshold = values.headOption.getOrElse(bestThreshold) - values.zip(labels).foreach { case (value, label) => + var j = 0 + val length = values.length + while (j < length) { + val value = values(j) + val label = labels(j) if (value != currentThreshold) { // Check gain val leftWeight = leftCount / fullCount @@ -588,6 +606,7 @@ private[ml] object AltDT extends Logging { rightImpurityAgg.update(label, -1.0) leftCount += 1.0 rightCount -= 1.0 + j += 1 } val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) From 4a71a84345c59ce8031bd76a788ea02fd6c5fd36 Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Sun, 1 Nov 2015 22:51:49 -0500 Subject: [PATCH 4/8] removed unnecessary operations -- first(), zipWithIndex -- that introduced extra stages in the DAG --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 15 ++++++++------- .../org/apache/spark/ml/tree/impl/TreeUtil.scala | 8 +------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index 6be4adfc3f33..5b860ee63088 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -120,16 +120,17 @@ private[ml] object AltDT extends Logging { impurityCalculator) } + val numRows = { + val longNumRows: Long = input.count() + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } // Prepare column store. - // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. // TODO: Is this mapping from arrays to iterators to arrays (when constructing learningData)? // Or is the mapping implicit (i.e., not costly)? - val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features)) - val numRows: Int = colStoreInit.first()._2.size - val labels = new Array[Double](numRows) - input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) => - labels(rowIndex.toInt) = label - } + val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features), numRows) + val labels = input.map(_.label).collect() val labelsBc = input.sparkContext.broadcast(labels) // NOTE: Labels are not sorted with features since that would require 1 copy per feature, // rather than 1 copy per worker. This means a lot of random accesses. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala index fa6eb63685d2..5c78f40f942a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala @@ -48,14 +48,8 @@ private[tree] object TreeUtil { * (First collect stats to decide how to partition.) * TODO: Move elsewhere in MLlib. */ - def rowToColumnStoreDense(rowStore: RDD[Vector]): RDD[(Int, Vector)] = { + def rowToColumnStoreDense(rowStore: RDD[Vector], numRows: Int): RDD[(Int, Vector)] = { - val numRows = { - val longNumRows: Long = rowStore.count() - require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + - s" but can handle at most ${Int.MaxValue} rows") - longNumRows.toInt - } if (numRows == 0) { return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) } From 68d59ca9ccd4a887dbfcd183da2fd6562b697770 Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Mon, 2 Nov 2015 22:31:39 -0500 Subject: [PATCH 5/8] Forgot to include change to TreeUtilSuite --- .../scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala index 483fb8568f8b..abacb8b8b7d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala @@ -32,7 +32,7 @@ class TreeUtilSuite extends SparkFunSuite with MLlibTestSparkContext { private def checkDense(rows: Seq[Vector]): Unit = { val numRowPartitions = 2 val rowStore = sc.parallelize(rows, numRowPartitions) - val colStore = rowToColumnStoreDense(rowStore) + val colStore = rowToColumnStoreDense(rowStore, rowStore.count.toInt) val numColPartitions = colStore.partitions.length val cols: Map[Int, Vector] = colStore.collect().toMap val numRows = rows.size From d6e32cd6f230b6008b504e8068edb3ef24eff3f7 Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Tue, 3 Nov 2015 00:03:01 -0500 Subject: [PATCH 6/8] Sorting now handled using a custom class, DualPivotQuicksort. No longer using zip or zipWithIndex to sort features with indices. --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 36 +- .../ml/tree/impl/DualPivotQuicksort.java | 696 ++++++++++++++++++ .../spark/ml/tree/impl/AltDTSuite.scala | 35 + 3 files changed, 742 insertions(+), 25 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index 5b860ee63088..6df650eada4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -195,7 +195,8 @@ private[ml] object AltDT extends Logging { if (!doneLearning) { val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1) // construct bit vector encoding which active nodes found a split - val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { (acc: BitSet, splitAndIdx: (Option[Split], Int)) => + val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { + (acc: BitSet, splitAndIdx: (Option[Split], Int)) => if (splitAndIdx._1.isDefined) acc.set(splitAndIdx._2) acc @@ -203,18 +204,11 @@ private[ml] object AltDT extends Logging { // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. val aggBitVector: BitSet = aggregateBitVector(partitionInfos, splits, numRows) - - // Broadcast aggregated bit vectors. On each partition, update instance--node map. - val aggBitVectorBc = input.sparkContext.broadcast(aggBitVector) - val nodeSplitBitVectorBc = input.sparkContext.broadcast(nodeSplitBitVector) val newPartitionInfos = partitionInfos.map { partitionInfo => - partitionInfo.update(aggBitVectorBc.value, nodeSplitBitVectorBc.value, numNodeOffsets) + partitionInfo.update(aggBitVector, nodeSplitBitVector, numNodeOffsets) } newPartitionInfos.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... partitionInfos = newPartitionInfos - - aggBitVectorBc.unpersist() - nodeSplitBitVectorBc.unpersist() } currentLevel += 1 @@ -672,8 +666,10 @@ private[ml] object AltDT extends Logging { featureIndex: Int, featureArity: Int, featureVector: Vector): FeatureVector = { - val (values, indices) = featureVector.toArray.zipWithIndex.sorted.unzip - new FeatureVector(featureIndex, featureArity, values.toArray, indices.toArray) + val values = featureVector.toArray + val indices = Range(0, values.length).toArray + DualPivotQuicksort.sort(values, indices) + new FeatureVector(featureIndex, featureArity, values, indices) } } @@ -836,24 +832,16 @@ private[ml] object AltDT extends Logging { } } // Now, we sort the sub-arrays from [0, numBitsNotSet) and [numBitsNotSet, rangeValues.length) - // TODO: implement our own sorting, so that we don't have to unnecessarily construct - // intermediate objects to sort - val leftValsAndIndices = rangeValues.slice(0, numBitsNotSet).zip(rangeIndices.slice(0, numBitsNotSet)).sorted - val rightValsAndIndices = rangeValues.slice(numBitsNotSet, rangeValues.length).zip(rangeIndices.slice(numBitsNotSet, rangeValues.length)).sorted - - val (sortedLeftRangeValues, sortedLeftRangeIndices) = leftValsAndIndices.unzip - val (sortedRightRangeValues, sortedRightRangeIndices) = rightValsAndIndices.unzip - - val sortedRangeValues = sortedLeftRangeValues.iterator ++ sortedRightRangeValues.iterator - val sortedRangeIndices = sortedLeftRangeIndices.iterator ++ sortedRightRangeIndices.iterator + DualPivotQuicksort.sort(rangeValues, rangeIndices, 0, numBitsNotSet - 1) + DualPivotQuicksort.sort(rangeValues, rangeIndices, numBitsNotSet, rangeValues.length - 1) // END SORTING // update the column values and indices // with the corresponding indices var i = 0 while (i < rangeValues.length) { - col.values(from + i) = sortedRangeValues.next() - col.indices(from + i) = sortedRangeIndices.next() + col.values(from + i) = rangeValues(i) + col.indices(from + i) = rangeIndices(i) i += 1 } } @@ -877,9 +865,7 @@ private[ml] object AltDT extends Logging { newNodeOffsetsIdx += 1 } } - PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java new file mode 100644 index 000000000000..a781f50b4cd1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java @@ -0,0 +1,696 @@ +package org.apache.spark.ml.tree.impl; + +/** + * Created by fabuzaid21 on 11/2/15. + * Helper utility class for sorting + * Double arrays with corresponding Int indices in + * {@link org.apache.spark.ml.tree.impl.AltDT}. + * Provides a more efficient alternative to: + * + *
+ * {@code
+ * val doubles = Array.fill[Double](len)(Random.nextDouble)
+ * val idxs = Array.fill[Int](len)(Random.nextInt(len))
+ * val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip
+ * }
+ * 
+ * + * NOTE: This implementation is directly borrowed from + * Yaroslavskiy et al.'s implementation of the Dual-Pivot + * Quicksort Algorithm in OpenJDK 7, which can be found + * here. + */ +public class DualPivotQuicksort { + + /** + * Prevents instantiation. + */ + private DualPivotQuicksort() {} + + /** + * The maximum number of runs in merge sort. + */ + private static final int MAX_RUN_COUNT = 67; + + /** + * The maximum length of run in merge sort. + */ + private static final int MAX_RUN_LENGTH = 33; + + /** + * If the length of an array to be sorted is less than this + * constant, Quicksort is used in preference to merge sort. + */ + private static final int QUICKSORT_THRESHOLD = 286; + + /** + * If the length of an array to be sorted is less than this + * constant, insertion sort is used in preference to Quicksort. + */ + private static final int INSERTION_SORT_THRESHOLD = 47; + + /** + * Sorts the specified array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + */ + public static void sort(double[] arr, int[] idxs) { + sort(arr, idxs, 0, arr.length - 1); + } + + /** + * Sorts the specified range of the array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + */ + public static void sort(double[] arr, int[] idxs, int left, int right) { + /* + * Phase 1: Move NaNs to the end of the array. + */ + while (left <= right && Double.isNaN(arr[right])) { + --right; + } + for (int k = right; --k >= left; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak != ak) { // arr[k] is NaN + arr[k] = arr[right]; + arr[right] = ak; + idxs[k] = idxs[right]; + idxs[right] = bk; + --right; + } + } + + /* + * Phase 2: Sort everything except NaNs (which are already in place). + */ + doSort(arr, idxs, left, right); + + /* + * Phase 3: Place negative zeros before positive zeros. + */ + int hi = right; + + /* + * Find the first zero, or first positive, or last negative element. + */ + while (left < hi) { + int middle = (left + hi) >>> 1; + double middleValue = arr[middle]; + + if (middleValue < 0.0d) { + left = middle + 1; + } else { + hi = middle; + } + } + + /* + * Skip the last negative value (if any) or all leading negative zeros. + */ + while (left <= right && Double.doubleToRawLongBits(arr[left]) < 0) { + ++left; + } + + /* + * Move negative zeros to the beginning of the sub-range. + * + * Partitioning: + * + * +----------------------------------------------------+ + * | < 0.0 | -0.0 | 0.0 | ? ( >= 0.0 ) | + * +----------------------------------------------------+ + * ^ ^ ^ + * | | | + * left p k + * + * Invariants: + * + * all in (*, left) < 0.0 + * all in [left, p) == -0.0 + * all in [p, k) == 0.0 + * all in [k, right] >= 0.0 + * + * Pointer k is the first index of ?-part. + */ + for (int k = left, p = left - 1; ++k <= right; ) { + double ak = arr[k]; + if (ak != 0.0d) { + break; + } + if (Double.doubleToRawLongBits(ak) < 0) { // ak is -0.0d + arr[k] = 0.0d; + arr[++p] = -0.0d; + } + } + } + + /** + * Sorts the specified range of the array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + */ + private static void doSort(double[] arr, int[] idxs, int left, int right) { + // Use Quicksort on small arrays + if (right - left < QUICKSORT_THRESHOLD) { + sort(arr, idxs, left, right, true); + return; + } + + /* + * Index run[i] is the start of i-th run + * (ascending or descending sequence). + */ + int[] run = new int[MAX_RUN_COUNT + 1]; + int count = 0; run[0] = left; + + // Check if the array is nearly sorted + for (int k = left; k < right; run[count] = k) { + if (arr[k] < arr[k + 1]) { // ascending + while (++k <= right && arr[k - 1] <= arr[k]); + } else if (arr[k] > arr[k + 1]) { // descending + while (++k <= right && arr[k - 1] >= arr[k]); + for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) { + double t = arr[lo]; arr[lo] = arr[hi]; arr[hi] = t; + int v = idxs[lo]; idxs[lo] = idxs[hi]; idxs[hi] = v; + } + } else { // equal + for (int m = MAX_RUN_LENGTH; ++k <= right && arr[k - 1] == arr[k]; ) { + if (--m == 0) { + sort(arr, idxs, left, right, true); + return; + } + } + } + + /* + * The array is not highly structured, + * use Quicksort instead of merge sort. + */ + if (++count == MAX_RUN_COUNT) { + sort(arr, idxs, left, right, true); + return; + } + } + + // Check special cases + if (run[count] == right++) { // The last run contains one element + run[++count] = right; + } else if (count == 1) { // The array is already sorted + return; + } + + /* + * Create temporary array, which is used for merging. + * Implementation note: variable "right" is increased by 1. + */ + double[] tempArr; byte odd = 0; int[] tempIdxs; + for (int n = 1; (n <<= 1) < count; odd ^= 1); + + if (odd == 0) { + tempArr = arr; arr = new double[tempArr.length]; + tempIdxs = idxs; idxs = new int[tempIdxs.length]; + for (int i = left - 1; ++i < right; arr[i] = tempArr[i], idxs[i] = tempIdxs[i]); + } else { + tempArr = new double[arr.length]; + tempIdxs = new int[idxs.length]; + } + + // Merging + for (int last; count > 1; count = last) { + for (int k = (last = 0) + 2; k <= count; k += 2) { + int hi = run[k], mi = run[k - 1]; + for (int i = run[k - 2], p = i, q = mi; i < hi; ++i) { + if (q >= hi || p < mi && arr[p] <= arr[q]) { + tempIdxs[i] = idxs[p]; + tempArr[i] = arr[p++]; + } else { + tempIdxs[i] = idxs[q]; + tempArr[i] = arr[q++]; + } + } + run[++last] = hi; + } + if ((count & 1) != 0) { + for (int i = right, lo = run[count - 1]; --i >= lo; + tempArr[i] = arr[i], tempIdxs[i] = idxs[i] + ); + run[++last] = right; + } + double[] t = arr; arr = tempArr; tempArr = t; + int[] v = idxs; idxs = tempIdxs; tempIdxs = v; + } + } + + /** + * Sorts the specified range of the array by Dual-Pivot Quicksort. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + * @param leftmost indicates if this part is the leftmost in the range + */ + private static void sort(double[] arr, int[] idxs, int left, int right, boolean leftmost) { + int length = right - left + 1; + + // Use insertion sort on tiny arrays + if (length < INSERTION_SORT_THRESHOLD) { + if (leftmost) { + /* + * Traditional (without sentinel) insertion sort, + * optimized for server VM, is used in case of + * the leftmost part. + */ + for (int i = left, j = i; i < right; j = ++i) { + double ai = arr[i + 1]; + int bi = idxs[i + 1]; + while (ai < arr[j]) { + arr[j + 1] = arr[j]; + idxs[j + 1] = idxs[j]; + if (j-- == left) { + break; + } + } + arr[j + 1] = ai; + idxs[j + 1] = bi; + } + } else { + /* + * Skip the longest ascending sequence. + */ + do { + if (left >= right) { + return; + } + } while (arr[++left] >= arr[left - 1]); + + /* + * Every element from adjoining part plays the role + * of sentinel, therefore this allows us to avoid the + * left range check on each iteration. Moreover, we use + * the more optimized algorithm, so called pair insertion + * sort, which is faster (in the context of Quicksort) + * than traditional implementation of insertion sort. + */ + for (int k = left; ++left <= right; k = ++left) { + double a1 = arr[k], a2 = arr[left]; + int b1 = idxs[k], b2 = idxs[left]; + + if (a1 < a2) { + a2 = a1; a1 = arr[left]; + b2 = b1; b1 = idxs[left]; + } + while (a1 < arr[--k]) { + arr[k + 2] = arr[k]; + idxs[k + 2] = idxs[k]; + } + arr[++k + 1] = a1; + idxs[k + 1] = b1; + + while (a2 < arr[--k]) { + arr[k + 1] = arr[k]; + idxs[k + 1] = idxs[k]; + } + arr[k + 1] = a2; + idxs[k + 1] = b2; + } + double last = arr[right]; + int bLast = idxs[right]; + + while (last < arr[--right]) { + arr[right + 1] = arr[right]; + idxs[right + 1] = idxs[right]; + } + arr[right + 1] = last; + idxs[right + 1] = bLast; + } + return; + } + + // Inexpensive approximation of length / 7 + int seventh = (length >> 3) + (length >> 6) + 1; + + /* + * Sort five evenly spaced elements around (and including) the + * center element in the range. These elements will be used for + * pivot selection as described below. The choice for spacing + * these elements was empirically determined to work well on + * arr wide variety of inputs. + */ + int e3 = (left + right) >>> 1; // The midpoint + int e2 = e3 - seventh; + int e1 = e2 - seventh; + int e4 = e3 + seventh; + int e5 = e4 + seventh; + + // Sort these elements using insertion sort + if (arr[e2] < arr[e1]) { + double t = arr[e2]; + int v = idxs[e2]; + + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + + if (arr[e3] < arr[e2]) { + double t = arr[e3]; + int v = idxs[e3]; + + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + + if (arr[e4] < arr[e3]) { + double t = arr[e4]; + int v = idxs[e4]; + + arr[e4] = arr[e3]; + arr[e3] = t; + idxs[e4] = idxs[e3]; + idxs[e3] = v; + + if (t < arr[e2]) { + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + } + + if (arr[e5] < arr[e4]) { + double t = arr[e5]; + int v = idxs[e5]; + + arr[e5] = arr[e4]; + arr[e4] = t; + idxs[e5] = idxs[e4]; + idxs[e4] = v; + + if (t < arr[e3]) { + arr[e4] = arr[e3]; + arr[e3] = t; + idxs[e4] = idxs[e3]; + idxs[e3] = v; + + if (t < arr[e2]) { + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + } + } + + // Pointers + int less = left; // The index of the first element of center part + int great = right; // The index before the first element of right part + + if (arr[e1] != arr[e2] && arr[e2] != arr[e3] && arr[e3] != arr[e4] && arr[e4] != arr[e5]) { + /* + * Use the second and fourth of the five sorted elements as pivots. + * These values are inexpensive approximations of the first and + * second terciles of the array. Note that pivot1 <= pivot2. + */ + double pivot1 = arr[e2]; + double pivot2 = arr[e4]; + int idxPivot1 = idxs[e2]; + int idxPivot2 = idxs[e4]; + + /* + * The first and the last elements to be sorted are moved to the + * locations formerly occupied by the pivots. When partitioning + * is complete, the pivots are swapped back into their final + * positions, and excluded from subsequent sorting. + */ + arr[e2] = arr[left]; + arr[e4] = arr[right]; + idxs[e2] = idxs[left]; + idxs[e4] = idxs[right]; + + /* + * Skip elements, which are less or greater than pivot values. + */ + while (arr[++less] < pivot1); + while (arr[--great] > pivot2); + + /* + * Partitioning: + * + * left part center part right part + * +--------------------------------------------------------------+ + * | < pivot1 | pivot1 <= && <= pivot2 | ? | > pivot2 | + * +--------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (left, less) < pivot1 + * pivot1 <= all in [less, k) <= pivot2 + * all in (great, right) > pivot2 + * + * Pointer k is the first index of ?-part. + */ + outer: + for (int k = less - 1; ++k <= great; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak < pivot1) { // Move arr[k] to left part + arr[k] = arr[less]; + idxs[k] = idxs[less]; + /* + * Here and below we use "arr[i] = b; i++;" instead + * of "arr[i++] = b;" due to performance issue. + */ + arr[less] = ak; + idxs[less] = bk; + ++less; + } else if (ak > pivot2) { // Move arr[k] to right part + while (arr[great] > pivot2) { + if (great-- == k) { + break outer; + } + } + if (arr[great] < pivot1) { // arr[great] <= pivot2 + arr[k] = arr[less]; + arr[less] = arr[great]; + idxs[k] = idxs[less]; + idxs[less] = idxs[great]; + ++less; + } else { // pivot1 <= arr[great] <= pivot2 + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + /* + * Here and below we use "arr[i] = b; i--;" instead + * of "arr[i--] = b;" due to performance issue. + */ + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + + // Swap pivots into their final positions + arr[left] = arr[less - 1]; arr[less - 1] = pivot1; + arr[right] = arr[great + 1]; arr[great + 1] = pivot2; + idxs[left] = idxs[less - 1]; idxs[less - 1] = idxPivot1; + idxs[right] = idxs[great + 1]; idxs[great + 1] = idxPivot2; + + // Sort left and right parts recursively, excluding known pivots + sort(arr, idxs, left, less - 2, leftmost); + sort(arr, idxs, great + 2, right, false); + + /* + * If center part is too large (comprises > 4/7 of the array), + * swap internal pivot values to ends. + */ + if (less < e1 && e5 < great) { + /* + * Skip elements, which are equal to pivot values. + */ + while (arr[less] == pivot1) { + ++less; + } + + while (arr[great] == pivot2) { + --great; + } + + /* + * Partitioning: + * + * left part center part right part + * +----------------------------------------------------------+ + * | == pivot1 | pivot1 < && < pivot2 | ? | == pivot2 | + * +----------------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (*, less) == pivot1 + * pivot1 < all in [less, k) < pivot2 + * all in (great, *) == pivot2 + * + * Pointer k is the first index of ?-part. + */ + outer: + for (int k = less - 1; ++k <= great; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak == pivot1) { // Move arr[k] to left part + arr[k] = arr[less]; + arr[less] = ak; + idxs[k] = idxs[less]; + idxs[less] = bk; + ++less; + } else if (ak == pivot2) { // Move arr[k] to right part + while (arr[great] == pivot2) { + if (great-- == k) { + break outer; + } + } + if (arr[great] == pivot1) { // arr[great] < pivot2 + arr[k] = arr[less]; + idxs[k] = idxs[less]; + /* + * Even though arr[great] equals to pivot1, the + * assignment arr[less] = pivot1 may be incorrect, + * if arr[great] and pivot1 are floating-point zeros + * of different signs. Therefore in float and + * double sorting methods we have to use more + * accurate assignment arr[less] = arr[great]. + */ + arr[less] = arr[great]; + idxs[less] = idxs[great]; + ++less; + } else { // pivot1 < arr[great] < pivot2 + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + } + + // Sort center part recursively + sort(arr, idxs, less, great, false); + + } else { // Partitioning with one pivot + /* + * Use the third of the five sorted elements as pivot. + * This value is inexpensive approximation of the median. + */ + double pivot = arr[e3]; + + /* + * Partitioning degenerates to the traditional 3-way + * (or "Dutch National Flag") schema: + * + * left part center part right part + * +-------------------------------------------------+ + * | < pivot | == pivot | ? | > pivot | + * +-------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (left, less) < pivot + * all in [less, k) == pivot + * all in (great, right) > pivot + * + * Pointer k is the first index of ?-part. + */ + for (int k = less; k <= great; ++k) { + if (arr[k] == pivot) { + continue; + } + double ak = arr[k]; + int bk = idxs[k]; + if (ak < pivot) { // Move arr[k] to left part + arr[k] = arr[less]; + arr[less] = ak; + idxs[k] = idxs[less]; + idxs[less] = bk; + ++less; + } else { // arr[k] > pivot - Move arr[k] to right part + while (arr[great] > pivot) { + --great; + } + if (arr[great] < pivot) { // arr[great] <= pivot + arr[k] = arr[less]; + arr[less] = arr[great]; + idxs[k] = idxs[less]; + idxs[less] = idxs[great]; + ++less; + } else { // arr[great] == pivot + /* + * Even though arr[great] equals to pivot, the + * assignment arr[k] = pivot may be incorrect, + * if arr[great] and pivot are floating-point + * zeros of different signs. Therefore in float + * and double sorting methods we have to use + * more accurate assignment arr[k] = arr[great]. + */ + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + + /* + * Sort left and right parts recursively. + * All elements from center part are equal + * and, therefore, already sorted. + */ + sort(arr, idxs, left, less - 1, leftmost); + sort(arr, idxs, great + 1, right, false); + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala index ceda0b39dc7e..1c24dcd73608 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala @@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.collection.BitSet +import scala.util.Random + /** * Test suite for [[AltDT]]. */ @@ -71,6 +73,39 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { //////////////////////////////// Helper classes ////////////////////////////////// + test("DualPivotQuicksort: sort array") { + val r = new Random() + val len = 1000 + val doubles = Array.fill[Double](len)(r.nextDouble) + val idxs = Array.fill[Int](len)(r.nextInt(len)) + + val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip + DualPivotQuicksort.sort(doubles, idxs) + + assert(sortedDoubles.sameElements(doubles)) + assert(sortedIdxs.sameElements(idxs)) + } + + test("DualPivotQuicksort: sort sub-arrays") { + val r = new Random() + val len = 1000 + val partition = r.nextInt(len) // partition point between the two sub-arrays + val doubles = Array.fill[Double](len)(r.nextDouble) + val idxs = Array.fill[Int](len)(r.nextInt(len)) + + val (sortedLeftDoubles, sortedLeftIdxs) = doubles.slice(0, partition).zip(idxs.slice(0, partition)).sorted.unzip + val (sortedRightDoubles, sortedRightIdxs) = doubles.slice(partition, len).zip(idxs.slice(partition, len)).sorted.unzip + + val sortedDoubles = sortedLeftDoubles ++ sortedRightDoubles + val sortedIdxs = sortedLeftIdxs ++ sortedRightIdxs + + DualPivotQuicksort.sort(doubles, idxs, 0, partition - 1) + DualPivotQuicksort.sort(doubles, idxs, partition, len - 1) + + assert(sortedDoubles.sameElements(doubles)) + assert(sortedIdxs.sameElements(idxs)) + } + test("FeatureVector") { val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0)) From 6dd5e676badb8a55a358e39d8b899aba70153a9c Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Thu, 12 Nov 2015 17:41:40 -0500 Subject: [PATCH 7/8] couple more minor improvements --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 21 ++++++++++++------- .../apache/spark/ml/tree/impl/TreeUtil.scala | 11 +++++++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index 6df650eada4f..aba7645c3182 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -256,15 +256,21 @@ private[ml] object AltDT extends Logging { case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => val localLabels = labelsBc.value // Iterate over the active nodes in the current level. - activeNodes.iterator.map { nodeIndexInLevel: Int => + val toReturn = new Array[(Option[Split], ImpurityStats)](activeNodes.cardinality()) + val iter: Iterator[Int] = activeNodes.iterator + var i = 0 + while (iter.hasNext) { + val nodeIndexInLevel = iter.next val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val splitsAndStats = columns.map { col => chooseSplit(col, localLabels, fromOffset, toOffset, metadata) } - splitsAndStats.maxBy(_._2.gain) - }.toArray + toReturn(i) = splitsAndStats.maxBy(_._2.gain) + i += 1 + } + toReturn } // TODO: treeReduce @@ -799,7 +805,10 @@ private[ml] object AltDT extends Logging { val oldOffset = newNodeOffsets(nodeIdx).head // numBitsNotSet == number of instances going to the left // which is how big the offset should be - newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + if (numBitsNotSet == 0) + newNodeOffsets(nodeIdx) = Array(oldOffset) + else + newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) // first we move all of the values and indices that have // zero-bits to the front @@ -849,10 +858,6 @@ private[ml] object AltDT extends Logging { col } - assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets, - s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," + - s" newNumNodeOffsets: $newNumNodeOffsets") - // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. val newActiveNodes = new BitSet(newNumNodeOffsets - 1) var newNodeOffsetsIdx = 0 diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala index 5c78f40f942a..710ec48f72cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala @@ -83,12 +83,16 @@ private[tree] object TreeUtil { // = column values for each instance in sourcePartitionIndex, // where colIdx is a 0-based index for columns for groupIndex val columnSets = new Array[Array[ArrayBuffer[Double]]](numTargetPartitions) - Range(0, numTargetPartitions).foreach { groupIndex => + var groupIndex = 0 + while(groupIndex < numTargetPartitions) { columnSets(groupIndex) = Array.fill[ArrayBuffer[Double]](getNumColsInGroup(groupIndex))(ArrayBuffer[Double]()) + groupIndex += 1 } - iterator.foreach { row => - Range(0, numTargetPartitions).foreach { groupIndex => + while (iterator.hasNext) { + val row: Vector = iterator.next() + var groupIndex = 0 + while (groupIndex < numTargetPartitions) { val fromCol = groupIndex * maxColumnsPerPartition val numColsInTargetPartition = getNumColsInGroup(groupIndex) // TODO: match-case here on row as Dense or Sparse Vector (for speed) @@ -97,6 +101,7 @@ private[tree] object TreeUtil { columnSets(groupIndex)(colIdx) += row(fromCol + colIdx) colIdx += 1 } + groupIndex += 1 } } Range(0, numTargetPartitions).map { groupIndex => From c228f7ff0e2f1d41d8e12c9348d7fd6dc2a8635b Mon Sep 17 00:00:00 2001 From: Firas Abuzaid Date: Sun, 22 Nov 2015 16:27:46 -0500 Subject: [PATCH 8/8] Now conforming to Spark style guidelines --- .../org/apache/spark/ml/tree/impl/AltDT.scala | 48 +++++++++++-------- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index aba7645c3182..d1d19fe3635e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -87,7 +87,8 @@ private[ml] object AltDT extends Logging { } private[impl] object AltDTMetadata { - def fromStrategy(strategy: Strategy) = new AltDTMetadata(strategy.numClasses, strategy.maxBins, + def fromStrategy(strategy: Strategy): AltDTMetadata = + new AltDTMetadata(strategy.numClasses, strategy.maxBins, strategy.minInfoGain, strategy.impurity) } @@ -143,7 +144,8 @@ private[ml] object AltDT extends Logging { } // Group columns together into one array of columns per partition. // TODO: Test avoiding this grouping, and see if it matters. - val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator: Iterator[FeatureVector] => + val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { + iterator: Iterator[FeatureVector] => if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() } groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) @@ -197,8 +199,9 @@ private[ml] object AltDT extends Logging { // construct bit vector encoding which active nodes found a split val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { (acc: BitSet, splitAndIdx: (Option[Split], Int)) => - if (splitAndIdx._1.isDefined) + if (splitAndIdx._1.isDefined) { acc.set(splitAndIdx._2) + } acc } @@ -207,7 +210,9 @@ private[ml] object AltDT extends Logging { val newPartitionInfos = partitionInfos.map { partitionInfo => partitionInfo.update(aggBitVector, nodeSplitBitVector, numNodeOffsets) } - newPartitionInfos.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... + // TODO: remove. For some reason, this is needed to make things work. + // Probably messing up somewhere above... + newPartitionInfos.cache().count() partitionInfos = newPartitionInfos } @@ -253,7 +258,8 @@ private[ml] object AltDT extends Logging { // for each active node, best split + info gain, // where the best split is None if no useful split exists val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map { - case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => + case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], + activeNodes: BitSet) => val localLabels = labelsBc.value // Iterate over the active nodes in the current level. val toReturn = new Array[(Option[Split], ImpurityStats)](activeNodes.cardinality()) @@ -344,7 +350,8 @@ private[ml] object AltDT extends Logging { val localBestSplits: Array[Option[Split]] = bestSplitsBc.value // localFeatureIndex[feature index] = index into PartitionInfo.columns val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap - val bitSetForNodes: Iterator[BitSet] = activeNodes.iterator.zip(localBestSplits.iterator).flatMap { + val bitSetForNodes: Iterator[BitSet] = activeNodes.iterator.zip(localBestSplits.iterator). + flatMap { case (nodeIndexInLevel: Int, Some(split: Split)) => if (localFeatureIndex.contains(split.featureIndex)) { // This partition has the column (feature) used for this split. @@ -360,12 +367,11 @@ private[ml] object AltDT extends Logging { // This requires PartitionInfo.update to handle missing BitSubvectors. Iterator() } - if (bitSetForNodes.isEmpty) - new BitSet(0) - else - bitSetForNodes.reduce[BitSet] { (acc: BitSet, bitv: BitSet) => - acc | bitv - } + if (bitSetForNodes.isEmpty) { + new BitSet(0) + } else { + bitSetForNodes.reduce[BitSet]((acc: BitSet, bitv: BitSet) => acc | bitv) + } } val aggBitVector: BitSet = workerBitSubvectors.reduce { (acc: BitSet, bitv: BitSet) => acc | bitv @@ -433,8 +439,8 @@ private[ml] object AltDT extends Logging { // TODO: Support high-arity features by using a single array to hold the stats. // aggStats(category) = label statistics for category - val aggStats: Array[ImpurityAggregatorSingle] = Array.tabulate[ImpurityAggregatorSingle](featureArity)( - _ => metadata.createImpurityAggregator()) + val aggStats: Array[ImpurityAggregatorSingle] = Array.tabulate[ImpurityAggregatorSingle]( + featureArity)(_ => metadata.createImpurityAggregator()) var i = 0 val len = values.length while (i < len) { @@ -770,7 +776,8 @@ private[ml] object AltDT extends Logging { * @param nodeSplitBitVector Bit vector encoding whether an active node was split or not * @return Updated partition info */ - def update(instanceBitVector: BitSet, nodeSplitBitVector: BitSet, newNumNodeOffsets: Int): PartitionInfo = { + def update(instanceBitVector: BitSet, nodeSplitBitVector: BitSet, newNumNodeOffsets: Int): + PartitionInfo = { // Create a 2-level representation of the new nodeOffsets (to be flattened). // These 2 levels correspond to original nodes and their children (if split). val newNodeOffsets = nodeOffsets.map(Array(_)) @@ -805,10 +812,11 @@ private[ml] object AltDT extends Logging { val oldOffset = newNodeOffsets(nodeIdx).head // numBitsNotSet == number of instances going to the left // which is how big the offset should be - if (numBitsNotSet == 0) + if (numBitsNotSet == 0) { newNodeOffsets(nodeIdx) = Array(oldOffset) - else + } else { newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + } // first we move all of the values and indices that have // zero-bits to the front @@ -840,9 +848,11 @@ private[ml] object AltDT extends Logging { end -= 1 } } - // Now, we sort the sub-arrays from [0, numBitsNotSet) and [numBitsNotSet, rangeValues.length) + // Now, we sort the sub-arrays from [0, numBitsNotSet) and + // [numBitsNotSet, rangeValues.length) DualPivotQuicksort.sort(rangeValues, rangeIndices, 0, numBitsNotSet - 1) - DualPivotQuicksort.sort(rangeValues, rangeIndices, numBitsNotSet, rangeValues.length - 1) + DualPivotQuicksort.sort(rangeValues, rangeIndices, numBitsNotSet, + rangeValues.length - 1) // END SORTING // update the column values and indices diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index d1e624fb8d6c..862e29bd1093 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -160,7 +160,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten def prob(label: Double): Double = -1 /** Get [[Predict]] struct. */ - def getPredict = { + def getPredict: Predict = { val pred = this.predict new Predict(predict = pred, prob = this.prob(pred)) }