diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index a9094310da8c..d1d19fe3635e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.tree.impl -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.tree._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.tree.impl.TreeUtil._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -89,7 +87,8 @@ private[ml] object AltDT extends Logging { } private[impl] object AltDTMetadata { - def fromStrategy(strategy: Strategy) = new AltDTMetadata(strategy.numClasses, strategy.maxBins, + def fromStrategy(strategy: Strategy): AltDTMetadata = + new AltDTMetadata(strategy.numClasses, strategy.maxBins, strategy.minInfoGain, strategy.impurity) } @@ -122,16 +121,17 @@ private[ml] object AltDT extends Logging { impurityCalculator) } + val numRows = { + val longNumRows: Long = input.count() + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } // Prepare column store. - // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. // TODO: Is this mapping from arrays to iterators to arrays (when constructing learningData)? // Or is the mapping implicit (i.e., not costly)? 
- val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features)) - val numRows: Int = colStoreInit.first()._2.size - val labels = new Array[Double](numRows) - input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) => - labels(rowIndex.toInt) = label - } + val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features), numRows) + val labels = input.map(_.label).collect() val labelsBc = input.sparkContext.broadcast(labels) // NOTE: Labels are not sorted with features since that would require 1 copy per feature, // rather than 1 copy per worker. This means a lot of random accesses. @@ -144,15 +144,14 @@ private[ml] object AltDT extends Logging { } // Group columns together into one array of columns per partition. // TODO: Test avoiding this grouping, and see if it matters. - val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator => - val groupedCols = new ArrayBuffer[FeatureVector] - iterator.foreach(groupedCols += _) - if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator() + val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { + iterator: Iterator[FeatureVector] => + if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() } groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) // Initialize partitions with 1 node (each instance at the root node). 
- var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols => + var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols => val initActive = new BitSet(1) initActive.set(0) new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive) @@ -165,16 +164,10 @@ private[ml] object AltDT extends Logging { var activeNodePeriphery: Array[LearningNode] = Array(rootNode) var numNodeOffsets: Int = 2 - val partitionInfosDebug = new scala.collection.mutable.ArrayBuffer[RDD[PartitionInfo]]() - partitionInfosDebug.append(partitionInfosA) - // Iteratively learn, one level of the tree at a time. var currentLevel = 0 var doneLearning = false while (currentLevel < strategy.maxDepth && !doneLearning) { - - val partitionInfos = partitionInfosDebug.last - // Compute best split for each active node. val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] = computeBestSplits(partitionInfos, labelsBc, metadata) @@ -202,20 +195,25 @@ private[ml] object AltDT extends Logging { doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0 if (!doneLearning) { - // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. - val aggBitVectors: Array[BitSubvector] = - collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1)) - - // Broadcast aggregated bit vectors. On each partition, update instance--node map. 
- val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors) - // partitionInfos = partitionInfos.map { partitionInfo => - val partitionInfosB = partitionInfos.map { partitionInfo => - partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets) + val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1) + // construct bit vector encoding which active nodes found a split + val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { + (acc: BitSet, splitAndIdx: (Option[Split], Int)) => + if (splitAndIdx._1.isDefined) { + acc.set(splitAndIdx._2) + } + acc } - partitionInfosB.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... - partitionInfosDebug.append(partitionInfosB) - // TODO: unpersist aggBitVectorsBc after action. + // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. + val aggBitVector: BitSet = aggregateBitVector(partitionInfos, splits, numRows) + val newPartitionInfos = partitionInfos.map { partitionInfo => + partitionInfo.update(aggBitVector, nodeSplitBitVector, numNodeOffsets) + } + // TODO: remove. For some reason, this is needed to make things work. + // Probably messing up somewhere above... + newPartitionInfos.cache().count() + partitionInfos = newPartitionInfos } currentLevel += 1 @@ -260,18 +258,25 @@ private[ml] object AltDT extends Logging { // for each active node, best split + info gain, // where the best split is None if no useful split exists val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map { - case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => + case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], + activeNodes: BitSet) => val localLabels = labelsBc.value // Iterate over the active nodes in the current level. 
- activeNodes.iterator.map { nodeIndexInLevel: Int => + val toReturn = new Array[(Option[Split], ImpurityStats)](activeNodes.cardinality()) + val iter: Iterator[Int] = activeNodes.iterator + var i = 0 + while (iter.hasNext) { + val nodeIndexInLevel = iter.next val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val splitsAndStats = columns.map { col => chooseSplit(col, localLabels, fromOffset, toOffset, metadata) } - splitsAndStats.maxBy(_._2.gain) - }.toArray + toReturn(i) = splitsAndStats.maxBy(_._2.gain) + i += 1 + } + toReturn } // TODO: treeReduce @@ -332,25 +337,28 @@ private[ml] object AltDT extends Logging { * @param bestSplits Split for each active node, or None if that node will not be split * @return Array of bit vectors, ordered by offset ranges */ - private[impl] def collectBitVectors( + private[impl] def aggregateBitVector( partitionInfos: RDD[PartitionInfo], - bestSplits: Array[Option[Split]]): Array[BitSubvector] = { + bestSplits: Array[Option[Split]], + numRows: Int): BitSet = { + val bestSplitsBc: Broadcast[Array[Option[Split]]] = partitionInfos.sparkContext.broadcast(bestSplits) - val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map { + val workerBitSubvectors: RDD[BitSet] = partitionInfos.map { case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => val localBestSplits: Array[Option[Split]] = bestSplitsBc.value // localFeatureIndex[feature index] = index into PartitionInfo.columns val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap - activeNodes.iterator.zip(localBestSplits.iterator).flatMap { + val bitSetForNodes: Iterator[BitSet] = activeNodes.iterator.zip(localBestSplits.iterator). + flatMap { case (nodeIndexInLevel: Int, Some(split: Split)) => if (localFeatureIndex.contains(split.featureIndex)) { // This partition has the column (feature) used for this split. 
val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val colIndex: Int = localFeatureIndex(split.featureIndex) - Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split)) + Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows)) } else { Iterator() } @@ -358,16 +366,22 @@ private[ml] object AltDT extends Logging { // Do not create a BitSubvector when there is no split. // This requires PartitionInfo.update to handle missing BitSubvectors. Iterator() - }.toArray + } + if (bitSetForNodes.isEmpty) { + new BitSet(0) + } else { + bitSetForNodes.reduce[BitSet]((acc: BitSet, bitv: BitSet) => acc | bitv) + } + } + val aggBitVector: BitSet = workerBitSubvectors.reduce { (acc: BitSet, bitv: BitSet) => + acc | bitv } - val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge) bestSplitsBc.unpersist() - aggBitVectors + aggBitVector } /** * Choose the best split for a feature at a node. - * * TODO: Return null or None when the split is invalid, such as putting all instances on one * child node. * @@ -425,10 +439,15 @@ private[ml] object AltDT extends Logging { // TODO: Support high-arity features by using a single array to hold the stats. // aggStats(category) = label statistics for category - val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( - _ => metadata.createImpurityAggregator()) - values.zip(labels).foreach { case (cat, label) => + val aggStats: Array[ImpurityAggregatorSingle] = Array.tabulate[ImpurityAggregatorSingle]( + featureArity)(_ => metadata.createImpurityAggregator()) + var i = 0 + val len = values.length + while (i < len) { + val cat = values(i) + val label = labels(i) aggStats(cat.toInt).update(label) + i += 1 } // Compute centroids. 
centroidsForCategories is a list: (category, centroid) @@ -476,9 +495,14 @@ private[ml] object AltDT extends Logging { val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1) // Cumulative sums of bin statistics for left, right parts of split. - val leftImpurityAgg = metadata.createImpurityAggregator() - val rightImpurityAgg = metadata.createImpurityAggregator() - aggStats.foreach(rightImpurityAgg.add) + val leftImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + val rightImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + var j = 0 + val length = aggStats.length + while (j < length) { + rightImpurityAgg.add(aggStats(j)) + j += 1 + } var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -550,7 +574,12 @@ private[ml] object AltDT extends Logging { val leftImpurityAgg = metadata.createImpurityAggregator() val rightImpurityAgg = metadata.createImpurityAggregator() - labels.foreach(rightImpurityAgg.update(_, 1.0)) + var i = 0 + val len = labels.length + while (i < len) { + rightImpurityAgg.update(labels(i), 1.0) + i += 1 + } var bestThreshold: Double = Double.NegativeInfinity val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -560,7 +589,11 @@ private[ml] object AltDT extends Logging { var rightCount: Double = rightImpurityAgg.getCount val fullCount: Double = rightCount var currentThreshold = values.headOption.getOrElse(bestThreshold) - values.zip(labels).foreach { case (value, label) => + var j = 0 + val length = values.length + while (j < length) { + val value = values(j) + val label = labels(j) if (value != currentThreshold) { // Check gain val leftWeight = leftCount / fullCount @@ -580,6 +613,7 @@ private[ml] object AltDT extends Logging { rightImpurityAgg.update(label, -1.0) leftCount += 1.0 rightCount -= 1.0 + j += 1 } val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) @@ 
-644,8 +678,10 @@ private[ml] object AltDT extends Logging { featureIndex: Int, featureArity: Int, featureVector: Vector): FeatureVector = { - val (values, indices) = featureVector.toArray.zipWithIndex.sorted.unzip - new FeatureVector(featureIndex, featureArity, values.toArray, indices.toArray) + val values = featureVector.toArray + val indices = Range(0, values.length).toArray + DualPivotQuicksort.sort(values, indices) + new FeatureVector(featureIndex, featureArity, values, indices) } } @@ -664,19 +700,23 @@ private[ml] object AltDT extends Logging { * second by sorted row indices within the node's rows. * bit[index in sorted array of row indices] = false for left, true for right */ - private[impl] def bitSubvectorFromSplit( + private[impl] def bitVectorFromSplit( col: FeatureVector, fromOffset: Int, toOffset: Int, - split: Split): BitSubvector = { - val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset).toArray - val nodeRowValues = col.values.view.slice(fromOffset, toOffset).toArray - val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2) - val bitv = new BitSubvector(fromOffset, toOffset) - nodeRowValuesSortedByIndices.zipWithIndex.foreach { case (value, i) => + split: Split, + numRows: Int): BitSet = { + val nodeRowIndices = col.indices.slice(fromOffset, toOffset) + val nodeRowValues = col.values.slice(fromOffset, toOffset) + val bitv = new BitSet(numRows) + var i = 0 + while (i < nodeRowValues.length) { + val value = nodeRowValues(i) + val idx = nodeRowIndices(i) if (!split.shouldGoLeft(value)) { - bitv.set(fromOffset + i) + bitv.set(idx) } + i += 1 } bitv } @@ -728,75 +768,106 @@ private[ml] object AltDT extends Logging { * Update nodeOffsets, activeNodes: * Split offsets for nodes which split (which can be identified using the bit vector). * - * @param bitVectors Bit vectors encoding splits for the next level of the tree. + * @param instanceBitVector Bit vector encoding splits for the next level of the tree. 
* These must follow a 2-level ordering, where the first level is by node * and the second level is by row index. * bitVector(i) = false iff instance i goes to the left child. * For instances at inactive (leaf) nodes, the value can be arbitrary. - * When an active node is not split (e.g., because no good split was found), - * then the corresponding BitSubvector can be missing. + * @param nodeSplitBitVector Bit vector encoding whether an active node was split or not * @return Updated partition info */ - def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = { - val newColumns = columns.map { oldCol => - val col = oldCol.deepCopy() - var curBitVecIdx = 0 + def update(instanceBitVector: BitSet, nodeSplitBitVector: BitSet, newNumNodeOffsets: Int): + PartitionInfo = { + // Create a 2-level representation of the new nodeOffsets (to be flattened). + // These 2 levels correspond to original nodes and their children (if split). + val newNodeOffsets = nodeOffsets.map(Array(_)) + + val newColumns = columns.map { col => activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - // TODO: Allow missing vectors when no split is chosen. - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) // If the current BitVector does not cover this node, then this node was not split, // so we do not need to update its part of the column. Otherwise, we update it. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Sort range [from, to) based on indices. This is required to match the bit vector - // across all workers. See [[bitSubvectorFromSplit]] for details. - val rangeIndices = col.indices.view.slice(from, to).toArray - val rangeValues = col.values.view.slice(from, to).toArray - val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1) - // Sort range [from, to) based on bit vector. 
- sortedRange.zipWithIndex.map { case ((idx, value), i) => - val bit = curBitVector.get(from + i) - // TODO: In-place merge, rather than general sort. - // TODO: We don't actually need to sort the categorical features using our approach. - (bit, value, idx) - }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) => - col.values(from + i) = value - col.indices(from + i) = idx - } - } - } - col - } + if (nodeSplitBitVector.get(nodeIdx)) { + val from = nodeOffsets(nodeIdx) + val to = nodeOffsets(nodeIdx + 1) + // Sort range [from, to) based on split, then value. This is required to match + // the bit vector across all workers. See [[bitVectorFromSplit]] for details. + // Within [from, to), we will have all "left child" instances (those that are false), + // then all "right child" instances. Then, within each child, we sort by value, so + // we can compute the best split for the next iteration. The corresponding index for + // an instance is used to look up the split value ("left" or "right") in the + // instanceBitVector, which is ordered by index. + val rangeIndices = col.indices.slice(from, to) + val rangeValues = col.values.slice(from, to) + + // BEGIN SORTING + var start = 0 + var end = rangeValues.length - 1 + // if this is the very first time we split + // we don't have to use the indices to figure + // out which bits are turned on + val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.cardinality + else rangeIndices.count(instanceBitVector.get) + val numBitsNotSet = to - from - numBitsSet - // Create a 2-level representation of the new nodeOffsets (to be flattened). - // These 2 levels correspond to original nodes and their children (if split). 
- val newNodeOffsets = nodeOffsets.map(Array(_)) - var curBitVecIdx = 0 - activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) - // If the current BitVector does not cover this node, then this node was not split, - // so we do not need to create a new node offset. Otherwise, we create an offset. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Count number of values splitting to left vs. right - val numRight = Range(from, to).count(curBitVector.get) - val numLeft = to - from - numRight - if (numLeft != 0 && numRight != 0) { - // node is split val oldOffset = newNodeOffsets(nodeIdx).head - newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft) + // numBitsNotSet == number of instances going to the left + // which is how big the offset should be + if (numBitsNotSet == 0) { + newNodeOffsets(nodeIdx) = Array(oldOffset) + } else { + newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + } + + // first we move all of the values and indices that have + // zero-bits to the front + while (start < numBitsNotSet && start <= end) { + // "start <= end" isn't really necessary, but we + // include it anyways + val startBit = instanceBitVector.get(rangeIndices(start)) + val endBit = instanceBitVector.get(rangeIndices(end)) + // if the startBit is false, we increment and move on + if (!startBit) { + start += 1 + } + // if endBit is true, we decrement and move on + if (endBit) { + end -= 1 + // if startBit is true and endBit is false, we swap + } else if (startBit && !endBit) { + // swap both in rangeValues and in rangeIndices + // (this should be a separate helper function, + // but we want to avoid function calls) + val tempVal = rangeValues(start) + rangeValues(start) = rangeValues(end) + rangeValues(end) = tempVal + val tempIdx = rangeIndices(start) + rangeIndices(start) = 
rangeIndices(end) + rangeIndices(end) = tempIdx + // update both start and end + start += 1 + end -= 1 + } + } + // Now, we sort the sub-arrays from [0, numBitsNotSet) and + // [numBitsNotSet, rangeValues.length) + DualPivotQuicksort.sort(rangeValues, rangeIndices, 0, numBitsNotSet - 1) + DualPivotQuicksort.sort(rangeValues, rangeIndices, numBitsNotSet, + rangeValues.length - 1) + // END SORTING + + // update the column values and indices + // with the corresponding indices + var i = 0 + while (i < rangeValues.length) { + col.values(from + i) = rangeValues(i) + col.indices(from + i) = rangeIndices(i) + i += 1 + } } } + col } - assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets, - s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," + - s" newNumNodeOffsets: $newNumNodeOffsets") - // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. val newActiveNodes = new BitSet(newNumNodeOffsets - 1) var newNodeOffsetsIdx = 0 @@ -809,9 +880,7 @@ private[ml] object AltDT extends Logging { newNodeOffsetsIdx += 1 } } - PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala deleted file mode 100644 index fe9b171cfd1b..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.util.collection.BitSet - - -private[impl] class BitSubvector(val from: Int, val to: Int) extends Serializable { - - val numBits: Int = to - from - - /** Element i will be put at location i + offset in the BitSet */ - private val offset: Int = 64 - (numBits % 64) - - private val bits: BitSet = new BitSet(numBits + offset) - - def set(bit: Int): Unit = bits.set(bit + offset - from) - - def get(bit: Int): Boolean = bits.get(bit + offset - from) - - /** Get an iterator over the set bits. 
*/ - def iterator: Iterator[Int] = new Iterator[Int] { - val iter = bits.iterator - override def hasNext: Boolean = iter.hasNext - override def next(): Int = iter.next() - offset + from - } -} - -private[impl] object BitSubvector { - - def merge(parts1: Array[BitSubvector], parts2: Array[BitSubvector]): Array[BitSubvector] = { - // Merge sorted parts1, parts2 - val sortedSubvectors = (parts1 ++ parts2).sortBy(_.from) - if (sortedSubvectors.nonEmpty) { - // Merge adjacent PartialBitVectors (for adjacent node ranges) - val newSubvectorRanges: Array[(Int, Int)] = { - val newSubvRanges = ArrayBuffer.empty[(Int, Int)] - var i = 1 - var currentFrom = sortedSubvectors.head.from - while (i < sortedSubvectors.length) { - if (sortedSubvectors(i - 1).to != sortedSubvectors(i).from) { - newSubvRanges.append((currentFrom, sortedSubvectors(i - 1).to)) - currentFrom = sortedSubvectors(i).from - } - i += 1 - } - newSubvRanges.append((currentFrom, sortedSubvectors.last.to)) - newSubvRanges.toArray - } - val newSubvectors = newSubvectorRanges.map { case (from, to) => new BitSubvector(from, to) } - var curNewSubvIdx = 0 - sortedSubvectors.foreach { subv => - if (subv.to > newSubvectors(curNewSubvIdx).to) curNewSubvIdx += 1 - val newSubv = newSubvectors(curNewSubvIdx) - // TODO: More efficient (word-level) copy. - subv.iterator.foreach(idx => newSubv.set(idx)) - } - assert(curNewSubvIdx + 1 == newSubvectors.length) // sanity check - newSubvectors - } else { - Array.empty[BitSubvector] - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java new file mode 100644 index 000000000000..a781f50b4cd1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java @@ -0,0 +1,696 @@ +package org.apache.spark.ml.tree.impl; + +/** + * Created by fabuzaid21 on 11/2/15. 
+ * Helper utility class for sorting + * Double arrays with corresponding Int indices in + * {@link org.apache.spark.ml.tree.impl.AltDT}. + * Provides a more efficient alternative to: + * + *
+ * {@code
+ * val doubles = Array.fill[Double](len)(Random.nextDouble)
+ * val idxs = Array.fill[Int](len)(Random.nextInt(len))
+ * val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip
+ * }
+ *
+ *
+ * NOTE: This implementation is adapted directly from
+ * Yaroslavskiy et al.'s implementation of the Dual-Pivot
+ * Quicksort algorithm in OpenJDK 7 (see the JDK class
+ * {@code java.util.DualPivotQuicksort}).
+ */
+public class DualPivotQuicksort {
+
+ /**
+ * Prevents instantiation; the sort entry points of this class are static.
+ */
+ private DualPivotQuicksort() {}
+
+ /**
+ * The maximum number of runs in merge sort. Once run detection in doSort has
+ * counted this many runs, the range is sorted by Quicksort instead.
+ */
+ private static final int MAX_RUN_COUNT = 67;
+
+ /**
+ * The maximum length of a run of equal elements tolerated by run detection
+ * in doSort; a longer run of equal elements makes doSort switch to Quicksort.
+ */
+ private static final int MAX_RUN_LENGTH = 33;
+
+ /**
+ * If the length of an array to be sorted is less than this
+ * constant, Quicksort is used in preference to merge sort.
+ * (doSort compares this constant against right - left.)
+ */
+ private static final int QUICKSORT_THRESHOLD = 286;
+
+ /**
+ * If the length of an array to be sorted is less than this
+ * constant, insertion sort is used in preference to Quicksort.
+ */
+ private static final int INSERTION_SORT_THRESHOLD = 47;
+
+ /**
+  * Sorts all of {@code arr} into ascending order and reorders {@code idxs} alongside it.
+  * This is the whole-array form of the ranged overload: it sorts the inclusive
+  * range from 0 to {@code arr.length - 1}.
+  *
+  * @param arr the array to be sorted
+  * @param idxs the corresponding indices array, updated to match arr's sorted order
+  */
+ public static void sort(double[] arr, int[] idxs) {
+     final int lastIndex = arr.length - 1;
+     sort(arr, idxs, 0, lastIndex);
+ }
+
+ /**
+  * Sorts the specified range of {@code arr} into ascending order and applies the
+  * same permutation to {@code idxs}, so that every index stays paired with the
+  * value it accompanied on entry.
+  *
+  * NaNs are moved to the end of the range and {@code -0.0} is placed before
+  * {@code 0.0}. A range with {@code right < left} is left untouched.
+  *
+  * @param arr the array to be sorted
+  * @param idxs the corresponding indices array, updated to match arr's sorted order
+  * @param left the index of the first element, inclusive, to be sorted
+  * @param right the index of the last element, inclusive, to be sorted
+  */
+ public static void sort(double[] arr, int[] idxs, int left, int right) {
+     /*
+      * Phase 1: Move NaNs to the end of the range, carrying their indices along.
+      */
+     while (left <= right && Double.isNaN(arr[right])) {
+         --right;
+     }
+     for (int k = right; --k >= left; ) {
+         double ak = arr[k];
+         int bk = idxs[k];
+         if (ak != ak) { // arr[k] is NaN
+             arr[k] = arr[right];
+             arr[right] = ak;
+             idxs[k] = idxs[right];
+             idxs[right] = bk;
+             --right;
+         }
+     }
+
+     /*
+      * Phase 2: Sort everything except NaNs (which are already in place).
+      */
+     doSort(arr, idxs, left, right);
+
+     /*
+      * Phase 3: Place negative zeros before positive zeros.
+      * The comparison-based sort treats -0.0 and 0.0 as equal, so they can
+      * be interleaved at this point.
+      */
+     int hi = right;
+
+     /*
+      * Find the first zero, or first positive, or last negative element.
+      */
+     while (left < hi) {
+         int middle = (left + hi) >>> 1;
+         double middleValue = arr[middle];
+
+         if (middleValue < 0.0d) {
+             left = middle + 1;
+         } else {
+             hi = middle;
+         }
+     }
+
+     /*
+      * Skip the last negative value (if any) or all leading negative zeros.
+      */
+     while (left <= right && Double.doubleToRawLongBits(arr[left]) < 0) {
+         ++left;
+     }
+
+     /*
+      * Move negative zeros to the beginning of the sub-range.
+      *
+      * Partitioning:
+      *
+      * +----------------------------------------------------+
+      * |   < 0.0   |   -0.0   |   0.0   |   ?  ( >= 0.0 )   |
+      * +----------------------------------------------------+
+      *              ^          ^         ^
+      *              |          |         |
+      *             left        p         k
+      *
+      * Invariants (p is the index of the last -0.0 moved so far):
+      *
+      *   everything before left is negative or -0.0
+      *   all in [left, p]  == -0.0
+      *   all in (p, k)     ==  0.0
+      *   all in [k, right] >=  0.0
+      *
+      * Pointer k is the first index of ?-part.
+      */
+     for (int k = left, p = left - 1; ++k <= right; ) {
+         double ak = arr[k];
+         if (ak != 0.0d) {
+             break;
+         }
+         if (Double.doubleToRawLongBits(ak) < 0) { // ak is -0.0d
+             // arr[p + 1] holds the first 0.0 of the zero block. Exchange it with
+             // the -0.0 at k, and exchange the paired indices as well so that
+             // idxs keeps following arr (p < k always holds here).
+             arr[k] = 0.0d;
+             arr[++p] = -0.0d;
+             int bk = idxs[k];
+             idxs[k] = idxs[p];
+             idxs[p] = bk;
+         }
+     }
+ }
+
+ /**
+ * Sorts the specified range of the array, permuting idxs in step with arr.
+ *
+ * Ranges with right - left below QUICKSORT_THRESHOLD are handed straight to the
+ * Dual-Pivot Quicksort overload. Otherwise the range is scanned for ascending,
+ * descending and equal runs; if that scan completes without falling back to
+ * Quicksort, the runs are merged pairwise through temporary arrays.
+ *
+ * @param arr the array to be sorted
+ * @param idxs the corresponding indices array, updated to match arr's sorted order
+ * @param left the index of the first element, inclusive, to be sorted
+ * @param right the index of the last element, inclusive, to be sorted
+ */
+ private static void doSort(double[] arr, int[] idxs, int left, int right) {
+ // Use Quicksort on small arrays
+ if (right - left < QUICKSORT_THRESHOLD) {
+ sort(arr, idxs, left, right, true);
+ return;
+ }
+
+ /*
+ * Index run[i] is the start of i-th run
+ * (ascending or descending sequence).
+ * count is the number of runs recorded so far.
+ */
+ int[] run = new int[MAX_RUN_COUNT + 1];
+ int count = 0; run[0] = left;
+
+ // Check if the array is nearly sorted
+ for (int k = left; k < right; run[count] = k) { // the update records where the next run starts
+ if (arr[k] < arr[k + 1]) { // ascending
+ while (++k <= right && arr[k - 1] <= arr[k]);
+ } else if (arr[k] > arr[k + 1]) { // descending
+ while (++k <= right && arr[k - 1] >= arr[k]);
+ for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) { // reverse the run in place
+ double t = arr[lo]; arr[lo] = arr[hi]; arr[hi] = t;
+ int v = idxs[lo]; idxs[lo] = idxs[hi]; idxs[hi] = v; // keep indices paired with values
+ }
+ } else { // equal
+ for (int m = MAX_RUN_LENGTH; ++k <= right && arr[k - 1] == arr[k]; ) {
+ if (--m == 0) { // too many equal neighbours: give up on merge sort
+ sort(arr, idxs, left, right, true);
+ return;
+ }
+ }
+ }
+
+ /*
+ * The array is not highly structured,
+ * use Quicksort instead of merge sort.
+ */
+ if (++count == MAX_RUN_COUNT) {
+ sort(arr, idxs, left, right, true);
+ return;
+ }
+ }
+
+ // Check special cases
+ if (run[count] == right++) { // The last run contains one element
+ run[++count] = right;
+ } else if (count == 1) { // The array is already sorted
+ return;
+ }
+
+ /*
+ * Create temporary array, which is used for merging.
+ * Implementation note: variable "right" is increased by 1.
+ * The roles of the working and temporary arrays are chosen from the
+ * parity computed below so that, after the buffers are swapped at the
+ * end of every merge pass, the final pass writes into the arrays that
+ * were passed in by the caller.
+ */
+ double[] tempArr; byte odd = 0; int[] tempIdxs;
+ for (int n = 1; (n <<= 1) < count; odd ^= 1); // odd flips once per doubling of n below count
+
+ if (odd == 0) {
+ tempArr = arr; arr = new double[tempArr.length];
+ tempIdxs = idxs; idxs = new int[tempIdxs.length];
+ for (int i = left - 1; ++i < right; arr[i] = tempArr[i], idxs[i] = tempIdxs[i]);
+ } else {
+ tempArr = new double[arr.length];
+ tempIdxs = new int[idxs.length];
+ }
+
+ // Merging: each pass merges adjacent pairs of runs from arr/idxs into tempArr/tempIdxs
+ for (int last; count > 1; count = last) {
+ for (int k = (last = 0) + 2; k <= count; k += 2) {
+ int hi = run[k], mi = run[k - 1];
+ for (int i = run[k - 2], p = i, q = mi; i < hi; ++i) {
+ if (q >= hi || p < mi && arr[p] <= arr[q]) { // take from the left run (ties included)
+ tempIdxs[i] = idxs[p];
+ tempArr[i] = arr[p++];
+ } else {
+ tempIdxs[i] = idxs[q];
+ tempArr[i] = arr[q++];
+ }
+ }
+ run[++last] = hi;
+ }
+ if ((count & 1) != 0) { // odd number of runs: copy the unpaired last run across unchanged
+ for (int i = right, lo = run[count - 1]; --i >= lo;
+ tempArr[i] = arr[i], tempIdxs[i] = idxs[i]
+ );
+ run[++last] = right;
+ }
+ double[] t = arr; arr = tempArr; tempArr = t;
+ int[] v = idxs; idxs = tempIdxs; tempIdxs = v;
+ }
+ }
+
+ /**
+ * Sorts the specified range of the array by Dual-Pivot Quicksort.
+ *
+ * @param arr the array to be sorted
+ * @param idxs the corresponding indices array, updated to match arr's sorted order
+ * @param left the index of the first element, inclusive, to be sorted
+ * @param right the index of the last element, inclusive, to be sorted
+ * @param leftmost indicates if this part is the leftmost in the range
+ */
+ private static void sort(double[] arr, int[] idxs, int left, int right, boolean leftmost) {
+ int length = right - left + 1;
+
+ // Use insertion sort on tiny arrays
+ if (length < INSERTION_SORT_THRESHOLD) {
+ if (leftmost) {
+ /*
+ * Traditional (without sentinel) insertion sort,
+ * optimized for server VM, is used in case of
+ * the leftmost part.
+ */
+ for (int i = left, j = i; i < right; j = ++i) {
+ double ai = arr[i + 1];
+ int bi = idxs[i + 1];
+ while (ai < arr[j]) {
+ arr[j + 1] = arr[j];
+ idxs[j + 1] = idxs[j];
+ if (j-- == left) {
+ break;
+ }
+ }
+ arr[j + 1] = ai;
+ idxs[j + 1] = bi;
+ }
+ } else {
+ /*
+ * Skip the longest ascending sequence.
+ */
+ do {
+ if (left >= right) {
+ return;
+ }
+ } while (arr[++left] >= arr[left - 1]);
+
+ /*
+ * Every element from adjoining part plays the role
+ * of sentinel, therefore this allows us to avoid the
+ * left range check on each iteration. Moreover, we use
+ * the more optimized algorithm, so called pair insertion
+ * sort, which is faster (in the context of Quicksort)
+ * than traditional implementation of insertion sort.
+ */
+ for (int k = left; ++left <= right; k = ++left) {
+ double a1 = arr[k], a2 = arr[left];
+ int b1 = idxs[k], b2 = idxs[left];
+
+ if (a1 < a2) {
+ a2 = a1; a1 = arr[left];
+ b2 = b1; b1 = idxs[left];
+ }
+ while (a1 < arr[--k]) {
+ arr[k + 2] = arr[k];
+ idxs[k + 2] = idxs[k];
+ }
+ arr[++k + 1] = a1;
+ idxs[k + 1] = b1;
+
+ while (a2 < arr[--k]) {
+ arr[k + 1] = arr[k];
+ idxs[k + 1] = idxs[k];
+ }
+ arr[k + 1] = a2;
+ idxs[k + 1] = b2;
+ }
+ double last = arr[right];
+ int bLast = idxs[right];
+
+ while (last < arr[--right]) {
+ arr[right + 1] = arr[right];
+ idxs[right + 1] = idxs[right];
+ }
+ arr[right + 1] = last;
+ idxs[right + 1] = bLast;
+ }
+ return;
+ }
+
+ // Inexpensive approximation of length / 7
+ int seventh = (length >> 3) + (length >> 6) + 1;
+
+ /*
+ * Sort five evenly spaced elements around (and including) the
+ * center element in the range. These elements will be used for
+ * pivot selection as described below. The choice for spacing
+ * these elements was empirically determined to work well on
+ * a wide variety of inputs.
+ */
+ int e3 = (left + right) >>> 1; // The midpoint
+ int e2 = e3 - seventh;
+ int e1 = e2 - seventh;
+ int e4 = e3 + seventh;
+ int e5 = e4 + seventh;
+
+ // Sort these elements using insertion sort
+ if (arr[e2] < arr[e1]) {
+ double t = arr[e2];
+ int v = idxs[e2];
+
+ arr[e2] = arr[e1];
+ arr[e1] = t;
+ idxs[e2] = idxs[e1];
+ idxs[e1] = v;
+ }
+
+ if (arr[e3] < arr[e2]) {
+ double t = arr[e3];
+ int v = idxs[e3];
+
+ arr[e3] = arr[e2];
+ arr[e2] = t;
+ idxs[e3] = idxs[e2];
+ idxs[e2] = v;
+
+ if (t < arr[e1]) {
+ arr[e2] = arr[e1];
+ arr[e1] = t;
+ idxs[e2] = idxs[e1];
+ idxs[e1] = v;
+ }
+ }
+
+ if (arr[e4] < arr[e3]) {
+ double t = arr[e4];
+ int v = idxs[e4];
+
+ arr[e4] = arr[e3];
+ arr[e3] = t;
+ idxs[e4] = idxs[e3];
+ idxs[e3] = v;
+
+ if (t < arr[e2]) {
+ arr[e3] = arr[e2];
+ arr[e2] = t;
+ idxs[e3] = idxs[e2];
+ idxs[e2] = v;
+
+ if (t < arr[e1]) {
+ arr[e2] = arr[e1];
+ arr[e1] = t;
+ idxs[e2] = idxs[e1];
+ idxs[e1] = v;
+ }
+ }
+ }
+
+ if (arr[e5] < arr[e4]) {
+ double t = arr[e5];
+ int v = idxs[e5];
+
+ arr[e5] = arr[e4];
+ arr[e4] = t;
+ idxs[e5] = idxs[e4];
+ idxs[e4] = v;
+
+ if (t < arr[e3]) {
+ arr[e4] = arr[e3];
+ arr[e3] = t;
+ idxs[e4] = idxs[e3];
+ idxs[e3] = v;
+
+ if (t < arr[e2]) {
+ arr[e3] = arr[e2];
+ arr[e2] = t;
+ idxs[e3] = idxs[e2];
+ idxs[e2] = v;
+
+ if (t < arr[e1]) {
+ arr[e2] = arr[e1];
+ arr[e1] = t;
+ idxs[e2] = idxs[e1];
+ idxs[e1] = v;
+ }
+ }
+ }
+ }
+
+ // Pointers
+ int less = left; // The index of the first element of center part
+ int great = right; // The index before the first element of right part
+
+ if (arr[e1] != arr[e2] && arr[e2] != arr[e3] && arr[e3] != arr[e4] && arr[e4] != arr[e5]) {
+ /*
+ * Use the second and fourth of the five sorted elements as pivots.
+ * These values are inexpensive approximations of the first and
+ * second terciles of the array. Note that pivot1 <= pivot2.
+ */
+ double pivot1 = arr[e2];
+ double pivot2 = arr[e4];
+ int idxPivot1 = idxs[e2];
+ int idxPivot2 = idxs[e4];
+
+ /*
+ * The first and the last elements to be sorted are moved to the
+ * locations formerly occupied by the pivots. When partitioning
+ * is complete, the pivots are swapped back into their final
+ * positions, and excluded from subsequent sorting.
+ */
+ arr[e2] = arr[left];
+ arr[e4] = arr[right];
+ idxs[e2] = idxs[left];
+ idxs[e4] = idxs[right];
+
+ /*
+ * Skip elements, which are less or greater than pivot values.
+ */
+ while (arr[++less] < pivot1);
+ while (arr[--great] > pivot2);
+
+ /*
+ * Partitioning:
+ *
+ * left part center part right part
+ * +--------------------------------------------------------------+
+ * | < pivot1 | pivot1 <= && <= pivot2 | ? | > pivot2 |
+ * +--------------------------------------------------------------+
+ * ^ ^ ^
+ * | | |
+ * less k great
+ *
+ * Invariants:
+ *
+ * all in (left, less) < pivot1
+ * pivot1 <= all in [less, k) <= pivot2
+ * all in (great, right) > pivot2
+ *
+ * Pointer k is the first index of ?-part.
+ */
+ outer:
+ for (int k = less - 1; ++k <= great; ) {
+ double ak = arr[k];
+ int bk = idxs[k];
+ if (ak < pivot1) { // Move arr[k] to left part
+ arr[k] = arr[less];
+ idxs[k] = idxs[less];
+ /*
+ * Here and below we use "arr[i] = b; i++;" instead
+ * of "arr[i++] = b;" due to performance issue.
+ */
+ arr[less] = ak;
+ idxs[less] = bk;
+ ++less;
+ } else if (ak > pivot2) { // Move arr[k] to right part
+ while (arr[great] > pivot2) {
+ if (great-- == k) {
+ break outer;
+ }
+ }
+ if (arr[great] < pivot1) { // arr[great] <= pivot2
+ arr[k] = arr[less];
+ arr[less] = arr[great];
+ idxs[k] = idxs[less];
+ idxs[less] = idxs[great];
+ ++less;
+ } else { // pivot1 <= arr[great] <= pivot2
+ arr[k] = arr[great];
+ idxs[k] = idxs[great];
+ }
+ /*
+ * Here and below we use "arr[i] = b; i--;" instead
+ * of "arr[i--] = b;" due to performance issue.
+ */
+ arr[great] = ak;
+ idxs[great] = bk;
+ --great;
+ }
+ }
+
+ // Swap pivots into their final positions
+ arr[left] = arr[less - 1]; arr[less - 1] = pivot1;
+ arr[right] = arr[great + 1]; arr[great + 1] = pivot2;
+ idxs[left] = idxs[less - 1]; idxs[less - 1] = idxPivot1;
+ idxs[right] = idxs[great + 1]; idxs[great + 1] = idxPivot2;
+
+ // Sort left and right parts recursively, excluding known pivots
+ sort(arr, idxs, left, less - 2, leftmost);
+ sort(arr, idxs, great + 2, right, false);
+
+ /*
+ * If center part is too large (comprises > 4/7 of the array),
+ * swap internal pivot values to ends.
+ */
+ if (less < e1 && e5 < great) {
+ /*
+ * Skip elements, which are equal to pivot values.
+ */
+ while (arr[less] == pivot1) {
+ ++less;
+ }
+
+ while (arr[great] == pivot2) {
+ --great;
+ }
+
+ /*
+ * Partitioning:
+ *
+ * left part center part right part
+ * +----------------------------------------------------------+
+ * | == pivot1 | pivot1 < && < pivot2 | ? | == pivot2 |
+ * +----------------------------------------------------------+
+ * ^ ^ ^
+ * | | |
+ * less k great
+ *
+ * Invariants:
+ *
+ * all in (*, less) == pivot1
+ * pivot1 < all in [less, k) < pivot2
+ * all in (great, *) == pivot2
+ *
+ * Pointer k is the first index of ?-part.
+ */
+ outer:
+ for (int k = less - 1; ++k <= great; ) {
+ double ak = arr[k];
+ int bk = idxs[k];
+ if (ak == pivot1) { // Move arr[k] to left part
+ arr[k] = arr[less];
+ arr[less] = ak;
+ idxs[k] = idxs[less];
+ idxs[less] = bk;
+ ++less;
+ } else if (ak == pivot2) { // Move arr[k] to right part
+ while (arr[great] == pivot2) {
+ if (great-- == k) {
+ break outer;
+ }
+ }
+ if (arr[great] == pivot1) { // arr[great] < pivot2
+ arr[k] = arr[less];
+ idxs[k] = idxs[less];
+ /*
+ * Even though arr[great] equals to pivot1, the
+ * assignment arr[less] = pivot1 may be incorrect,
+ * if arr[great] and pivot1 are floating-point zeros
+ * of different signs. Therefore in float and
+ * double sorting methods we have to use more
+ * accurate assignment arr[less] = arr[great].
+ */
+ arr[less] = arr[great];
+ idxs[less] = idxs[great];
+ ++less;
+ } else { // pivot1 < arr[great] < pivot2
+ arr[k] = arr[great];
+ idxs[k] = idxs[great];
+ }
+ arr[great] = ak;
+ idxs[great] = bk;
+ --great;
+ }
+ }
+ }
+
+ // Sort center part recursively
+ sort(arr, idxs, less, great, false);
+
+ } else { // Partitioning with one pivot
+ /*
+ * Use the third of the five sorted elements as pivot.
+ * This value is inexpensive approximation of the median.
+ */
+ double pivot = arr[e3];
+
+ /*
+ * Partitioning degenerates to the traditional 3-way
+ * (or "Dutch National Flag") schema:
+ *
+ * left part center part right part
+ * +-------------------------------------------------+
+ * | < pivot | == pivot | ? | > pivot |
+ * +-------------------------------------------------+
+ * ^ ^ ^
+ * | | |
+ * less k great
+ *
+ * Invariants:
+ *
+ * all in (left, less) < pivot
+ * all in [less, k) == pivot
+ * all in (great, right) > pivot
+ *
+ * Pointer k is the first index of ?-part.
+ */
+ for (int k = less; k <= great; ++k) {
+ if (arr[k] == pivot) {
+ continue;
+ }
+ double ak = arr[k];
+ int bk = idxs[k];
+ if (ak < pivot) { // Move arr[k] to left part
+ arr[k] = arr[less];
+ arr[less] = ak;
+ idxs[k] = idxs[less];
+ idxs[less] = bk;
+ ++less;
+ } else { // arr[k] > pivot - Move arr[k] to right part
+ while (arr[great] > pivot) {
+ --great;
+ }
+ if (arr[great] < pivot) { // arr[great] <= pivot
+ arr[k] = arr[less];
+ arr[less] = arr[great];
+ idxs[k] = idxs[less];
+ idxs[less] = idxs[great];
+ ++less;
+ } else { // arr[great] == pivot
+ /*
+ * Even though arr[great] equals to pivot, the
+ * assignment arr[k] = pivot may be incorrect,
+ * if arr[great] and pivot are floating-point
+ * zeros of different signs. Therefore in float
+ * and double sorting methods we have to use
+ * more accurate assignment arr[k] = arr[great].
+ */
+ arr[k] = arr[great];
+ idxs[k] = idxs[great];
+ }
+ arr[great] = ak;
+ idxs[great] = bk;
+ --great;
+ }
+ }
+
+ /*
+ * Sort left and right parts recursively.
+ * All elements from center part are equal
+ * and, therefore, already sorted.
+ */
+ sort(arr, idxs, left, less - 1, leftmost);
+ sort(arr, idxs, great + 1, right, false);
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala
index fa6eb63685d2..710ec48f72cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala
@@ -48,14 +48,8 @@ private[tree] object TreeUtil {
* (First collect stats to decide how to partition.)
* TODO: Move elsewhere in MLlib.
*/
- def rowToColumnStoreDense(rowStore: RDD[Vector]): RDD[(Int, Vector)] = {
+ def rowToColumnStoreDense(rowStore: RDD[Vector], numRows: Int): RDD[(Int, Vector)] = {
- val numRows = {
- val longNumRows: Long = rowStore.count()
- require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," +
- s" but can handle at most ${Int.MaxValue} rows")
- longNumRows.toInt
- }
if (numRows == 0) {
return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)])
}
@@ -89,12 +83,16 @@ private[tree] object TreeUtil {
// = column values for each instance in sourcePartitionIndex,
// where colIdx is a 0-based index for columns for groupIndex
val columnSets = new Array[Array[ArrayBuffer[Double]]](numTargetPartitions)
- Range(0, numTargetPartitions).foreach { groupIndex =>
+ var groupIndex = 0
+ while(groupIndex < numTargetPartitions) {
columnSets(groupIndex) =
Array.fill[ArrayBuffer[Double]](getNumColsInGroup(groupIndex))(ArrayBuffer[Double]())
+ groupIndex += 1
}
- iterator.foreach { row =>
- Range(0, numTargetPartitions).foreach { groupIndex =>
+ while (iterator.hasNext) {
+ val row: Vector = iterator.next()
+ var groupIndex = 0
+ while (groupIndex < numTargetPartitions) {
val fromCol = groupIndex * maxColumnsPerPartition
val numColsInTargetPartition = getNumColsInGroup(groupIndex)
// TODO: match-case here on row as Dense or Sparse Vector (for speed)
@@ -103,6 +101,7 @@ private[tree] object TreeUtil {
columnSets(groupIndex)(colIdx) += row(fromCol + colIdx)
colIdx += 1
}
+ groupIndex += 1
}
}
Range(0, numTargetPartitions).map { groupIndex =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index d1e624fb8d6c..862e29bd1093 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -160,7 +160,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten
def prob(label: Double): Double = -1
/** Get [[Predict]] struct. */
- def getPredict = {
+ def getPredict: Predict = {
val pred = this.predict
new Predict(predict = pred, prob = this.prob(pred))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala
index 860c8ea78760..1c24dcd73608 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala
@@ -21,14 +21,15 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.regression.DecisionTreeRegressor
import org.apache.spark.ml.tree._
import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo}
-import org.apache.spark.ml.tree.impl.TreeUtil._
-import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.BitSet
+import scala.util.Random
+
/**
* Test suite for [[AltDT]].
*/
@@ -72,6 +73,39 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
//////////////////////////////// Helper classes //////////////////////////////////
+ test("DualPivotQuicksort: sort array") {
+ val r = new Random()
+ val len = 1000
+ val doubles = Array.fill[Double](len)(r.nextDouble)
+ val idxs = Array.fill[Int](len)(r.nextInt(len))
+
+ val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip
+ DualPivotQuicksort.sort(doubles, idxs)
+
+ assert(sortedDoubles.sameElements(doubles))
+ assert(sortedIdxs.sameElements(idxs))
+ }
+
+ test("DualPivotQuicksort: sort sub-arrays") {
+ val r = new Random()
+ val len = 1000
+ val partition = r.nextInt(len) // partition point between the two sub-arrays
+ val doubles = Array.fill[Double](len)(r.nextDouble)
+ val idxs = Array.fill[Int](len)(r.nextInt(len))
+
+ val (sortedLeftDoubles, sortedLeftIdxs) = doubles.slice(0, partition).zip(idxs.slice(0, partition)).sorted.unzip
+ val (sortedRightDoubles, sortedRightIdxs) = doubles.slice(partition, len).zip(idxs.slice(partition, len)).sorted.unzip
+
+ val sortedDoubles = sortedLeftDoubles ++ sortedRightDoubles
+ val sortedIdxs = sortedLeftIdxs ++ sortedRightIdxs
+
+ DualPivotQuicksort.sort(doubles, idxs, 0, partition - 1)
+ DualPivotQuicksort.sort(doubles, idxs, partition, len - 1)
+
+ assert(sortedDoubles.sameElements(doubles))
+ assert(sortedIdxs.sameElements(idxs))
+ }
+
test("FeatureVector") {
val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0))
@@ -100,11 +134,12 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
// Create bitVector for splitting the 4 rows: L, R, L, R
// New groups are {0, 2}, {1, 3}
- val bitVector = new BitSubvector(0, numRows)
+ val bitVector = new BitSet(numRows)
bitVector.set(1)
bitVector.set(3)
- val newInfo = info.update(Array(bitVector), newNumNodeOffsets = 3)
+ // for these tests, use the activeNodes for nodeSplitBitVector
+ val newInfo = info.update(bitVector, activeNodes, newNumNodeOffsets = 3)
assert(newInfo.columns.length === 2)
val expectedCol1a =
@@ -117,12 +152,16 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(newInfo.activeNodes.iterator.toSet === Set(0, 1))
// Create 2 bitVectors for splitting into: 0, 2, 1, 3
- val bv2a = new BitSubvector(0, 2)
- bv2a.set(1)
- val bv2b = new BitSubvector(2, 4)
- bv2b.set(3)
+ val bitVector2 = new BitSet(numRows)
+ bitVector2.set(2) // 2 goes to the right
+ bitVector2.set(3) // 3 goes to the right
- val newInfo2 = newInfo.update(Array(bv2a, bv2b), newNumNodeOffsets = 5)
+ // both nodes find a split
+ val nodeSplits = new BitSet(2)
+ nodeSplits.set(0)
+ nodeSplits.set(1)
+
+ val newInfo2 = newInfo.update(bitVector2, nodeSplits, newNumNodeOffsets = 5)
assert(newInfo2.columns.length === 2)
val expectedCol2a =
@@ -286,31 +325,34 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7))
val fromOffset = 0
val toOffset = col.values.length
+ val numRows = col.values.length
val split = new ContinuousSplit(0, threshold = 0.5)
- val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split)
- assert(bitv.from === fromOffset)
- assert(bitv.to === toOffset)
+ val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows)
assert(bitv.iterator.toSet === Set(3, 4))
}
test("bitSubvectorFromSplit: 2 nodes") {
// Initially, 1 split: (0, 2, 4) | (1, 3)
+ // Feature values: (0.4, 0.2, 0.1) | (0.6, 0.7)
+ // Row index: 0, 1, 2, 3, 4
+ // Row value: 0.4, 0.6, 0.2, 0.7, 0.1
val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7),
Array(4, 2, 0, 1, 3))
def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, expectedRight: Set[Int]): Unit = {
val split = new ContinuousSplit(0, threshold)
- val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split)
- assert(bitv.from === fromOffset)
- assert(bitv.to === toOffset)
+ val numRows = col.values.length
+ val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows)
assert(bitv.iterator.toSet === expectedRight)
}
+
// Left child node
- checkSplit(0, 3, 0.15, Set(0, 1))
+ checkSplit(0, 3, 0.05, Set(0, 2, 4))
+ checkSplit(0, 3, 0.15, Set(0, 2))
checkSplit(0, 3, 0.2, Set(0))
checkSplit(0, 3, 0.5, Set())
// Right child node
- checkSplit(3, 5, 0.1, Set(3, 4))
- checkSplit(3, 5, 0.65, Set(4))
+ checkSplit(3, 5, 0.1, Set(1, 3))
+ checkSplit(3, 5, 0.65, Set(3))
checkSplit(3, 5, 0.8, Set())
}
@@ -323,11 +365,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
val partitionInfos = sc.parallelize(Seq(info))
val bestSplit = new ContinuousSplit(0, threshold = 0.5)
- val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
- assert(bitVectors.length === 1)
- val bitv = bitVectors.head
- assert(bitv.numBits === numRows)
- assert(bitv.iterator.toArray === Array(3, 4))
+ val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows)
+ assert(bitVector.iterator.toSet === Set(3, 4))
}
test("collectBitVectors with 1 vector, with tied threshold") {
@@ -339,11 +378,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext {
val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes)
val partitionInfos = sc.parallelize(Seq(info))
val bestSplit = new ContinuousSplit(0, threshold = -2.0)
- val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit)))
- assert(bitVectors.length === 1)
- val bitv = bitVectors.head
- assert(bitv.numBits === numRows)
- assert(bitv.iterator.toArray === Array(0, 1, 4, 5))
+ val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows)
+ assert(bitVector.iterator.toSet === Set(0, 1, 4, 5))
}
//////////////////////////////// Active nodes //////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala
deleted file mode 100644
index 6e7b67623851..000000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.tree.impl
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-
-/**
- * Test suite for [[BitSubvector]].
- */
-class BitSubvectorSuite extends SparkFunSuite with MLlibTestSparkContext {
-
- test("BitSubvector basic ops") {
- val from = 1
- val to = 4
- val bs = new BitSubvector(from, to)
- assert(bs.numBits === to - from)
- Range(from, to).foreach(x => assert(!bs.get(x)))
- val setVals = Array(from, to - 1)
- setVals.foreach { x =>
- bs.set(x)
- assert(bs.get(x))
- }
- assert(bs.iterator.toSet === setVals.toSet)
- }
-
- test("BitSubvector merge") {
- val b1 = new BitSubvector(0, 5)
- b1.set(1)
- val b2 = new BitSubvector(5, 7)
- b2.set(5)
- val b3 = new BitSubvector(9, 12)
- b3.set(11)
- val parts1 = Array(b1)
- val parts2 = Array(b2, b3)
- val newParts = BitSubvector.merge(parts1, parts2)
-
- val r1 = new BitSubvector(0, 7)
- r1.set(1)
- r1.set(5)
- val r2 = new BitSubvector(9, 12)
- r2.set(11)
- val expectedParts = Array(r1, r2)
- newParts.zip(expectedParts).foreach { case (x, y) =>
- assert(x.from === y.from)
- assert(x.to === x.to)
- assert(x.iterator.toSet === y.iterator.toSet)
- }
- }
-
- test("BitSubvector merge with empty BitSubvectors") {
- val parts = BitSubvector.merge(Array.empty[BitSubvector], Array.empty[BitSubvector])
- }
-}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala
index 483fb8568f8b..abacb8b8b7d8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala
@@ -32,7 +32,7 @@ class TreeUtilSuite extends SparkFunSuite with MLlibTestSparkContext {
private def checkDense(rows: Seq[Vector]): Unit = {
val numRowPartitions = 2
val rowStore = sc.parallelize(rows, numRowPartitions)
- val colStore = rowToColumnStoreDense(rowStore)
+ val colStore = rowToColumnStoreDense(rowStore, rowStore.count.toInt)
val numColPartitions = colStore.partitions.length
val cols: Map[Int, Vector] = colStore.collect().toMap
val numRows = rows.size