diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala index a9094310da8c..d1d19fe3635e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/AltDT.scala @@ -17,8 +17,6 @@ package org.apache.spark.ml.tree.impl -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.tree._ @@ -26,7 +24,7 @@ import org.apache.spark.ml.tree.impl.TreeUtil._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.{Variance, Gini, Entropy, Impurity} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -89,7 +87,8 @@ private[ml] object AltDT extends Logging { } private[impl] object AltDTMetadata { - def fromStrategy(strategy: Strategy) = new AltDTMetadata(strategy.numClasses, strategy.maxBins, + def fromStrategy(strategy: Strategy): AltDTMetadata = + new AltDTMetadata(strategy.numClasses, strategy.maxBins, strategy.minInfoGain, strategy.impurity) } @@ -122,16 +121,17 @@ private[ml] object AltDT extends Logging { impurityCalculator) } + val numRows = { + val longNumRows: Long = input.count() + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } // Prepare column store. - // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. // TODO: Is this mapping from arrays to iterators to arrays (when constructing learningData)? // Or is the mapping implicit (i.e., not costly)? 
- val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features)) - val numRows: Int = colStoreInit.first()._2.size - val labels = new Array[Double](numRows) - input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) => - labels(rowIndex.toInt) = label - } + val colStoreInit: RDD[(Int, Vector)] = rowToColumnStoreDense(input.map(_.features), numRows) + val labels = input.map(_.label).collect() val labelsBc = input.sparkContext.broadcast(labels) // NOTE: Labels are not sorted with features since that would require 1 copy per feature, // rather than 1 copy per worker. This means a lot of random accesses. @@ -144,15 +144,14 @@ private[ml] object AltDT extends Logging { } // Group columns together into one array of columns per partition. // TODO: Test avoiding this grouping, and see if it matters. - val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { iterator => - val groupedCols = new ArrayBuffer[FeatureVector] - iterator.foreach(groupedCols += _) - if (groupedCols.nonEmpty) Iterator(groupedCols.toArray) else Iterator() + val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { + iterator: Iterator[FeatureVector] => + if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() } groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) // Initialize partitions with 1 node (each instance at the root node). 
- var partitionInfosA: RDD[PartitionInfo] = groupedColStore.map { groupedCols => + var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols => val initActive = new BitSet(1) initActive.set(0) new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive) @@ -165,16 +164,10 @@ private[ml] object AltDT extends Logging { var activeNodePeriphery: Array[LearningNode] = Array(rootNode) var numNodeOffsets: Int = 2 - val partitionInfosDebug = new scala.collection.mutable.ArrayBuffer[RDD[PartitionInfo]]() - partitionInfosDebug.append(partitionInfosA) - // Iteratively learn, one level of the tree at a time. var currentLevel = 0 var doneLearning = false while (currentLevel < strategy.maxDepth && !doneLearning) { - - val partitionInfos = partitionInfosDebug.last - // Compute best split for each active node. val bestSplitsAndGains: Array[(Option[Split], ImpurityStats)] = computeBestSplits(partitionInfos, labelsBc, metadata) @@ -202,20 +195,25 @@ private[ml] object AltDT extends Logging { doneLearning = currentLevel + 1 >= strategy.maxDepth || estimatedRemainingActive == 0 if (!doneLearning) { - // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. - val aggBitVectors: Array[BitSubvector] = - collectBitVectors(partitionInfos, bestSplitsAndGains.map(_._1)) - - // Broadcast aggregated bit vectors. On each partition, update instance--node map. 
- val aggBitVectorsBc = input.sparkContext.broadcast(aggBitVectors) - // partitionInfos = partitionInfos.map { partitionInfo => - val partitionInfosB = partitionInfos.map { partitionInfo => - partitionInfo.update(aggBitVectorsBc.value, numNodeOffsets) + val splits: Array[Option[Split]] = bestSplitsAndGains.map(_._1) + // construct bit vector encoding which active nodes found a split + val nodeSplitBitVector: BitSet = splits.zipWithIndex.foldLeft(new BitSet(splits.length)) { + (acc: BitSet, splitAndIdx: (Option[Split], Int)) => + if (splitAndIdx._1.isDefined) { + acc.set(splitAndIdx._2) + } + acc } - partitionInfosB.cache().count() // TODO: remove. For some reason, this is needed to make things work. Probably messing up somewhere above... - partitionInfosDebug.append(partitionInfosB) - // TODO: unpersist aggBitVectorsBc after action. + // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. + val aggBitVector: BitSet = aggregateBitVector(partitionInfos, splits, numRows) + val newPartitionInfos = partitionInfos.map { partitionInfo => + partitionInfo.update(aggBitVector, nodeSplitBitVector, numNodeOffsets) + } + // TODO: remove. For some reason, this is needed to make things work. + // Probably messing up somewhere above... + newPartitionInfos.cache().count() + partitionInfos = newPartitionInfos } currentLevel += 1 @@ -260,18 +258,25 @@ private[ml] object AltDT extends Logging { // for each active node, best split + info gain, // where the best split is None if no useful split exists val partBestSplitsAndGains: RDD[Array[(Option[Split], ImpurityStats)]] = partitionInfos.map { - case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => + case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], + activeNodes: BitSet) => val localLabels = labelsBc.value // Iterate over the active nodes in the current level. 
- activeNodes.iterator.map { nodeIndexInLevel: Int => + val toReturn = new Array[(Option[Split], ImpurityStats)](activeNodes.cardinality()) + val iter: Iterator[Int] = activeNodes.iterator + var i = 0 + while (iter.hasNext) { + val nodeIndexInLevel = iter.next val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val splitsAndStats = columns.map { col => chooseSplit(col, localLabels, fromOffset, toOffset, metadata) } - splitsAndStats.maxBy(_._2.gain) - }.toArray + toReturn(i) = splitsAndStats.maxBy(_._2.gain) + i += 1 + } + toReturn } // TODO: treeReduce @@ -332,25 +337,28 @@ private[ml] object AltDT extends Logging { * @param bestSplits Split for each active node, or None if that node will not be split * @return Array of bit vectors, ordered by offset ranges */ - private[impl] def collectBitVectors( + private[impl] def aggregateBitVector( partitionInfos: RDD[PartitionInfo], - bestSplits: Array[Option[Split]]): Array[BitSubvector] = { + bestSplits: Array[Option[Split]], + numRows: Int): BitSet = { + val bestSplitsBc: Broadcast[Array[Option[Split]]] = partitionInfos.sparkContext.broadcast(bestSplits) - val workerBitSubvectors: RDD[Array[BitSubvector]] = partitionInfos.map { + val workerBitSubvectors: RDD[BitSet] = partitionInfos.map { case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], activeNodes: BitSet) => val localBestSplits: Array[Option[Split]] = bestSplitsBc.value // localFeatureIndex[feature index] = index into PartitionInfo.columns val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap - activeNodes.iterator.zip(localBestSplits.iterator).flatMap { + val bitSetForNodes: Iterator[BitSet] = activeNodes.iterator.zip(localBestSplits.iterator). + flatMap { case (nodeIndexInLevel: Int, Some(split: Split)) => if (localFeatureIndex.contains(split.featureIndex)) { // This partition has the column (feature) used for this split. 
val fromOffset = nodeOffsets(nodeIndexInLevel) val toOffset = nodeOffsets(nodeIndexInLevel + 1) val colIndex: Int = localFeatureIndex(split.featureIndex) - Iterator(bitSubvectorFromSplit(columns(colIndex), fromOffset, toOffset, split)) + Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows)) } else { Iterator() } @@ -358,16 +366,22 @@ private[ml] object AltDT extends Logging { // Do not create a BitSubvector when there is no split. // This requires PartitionInfo.update to handle missing BitSubvectors. Iterator() - }.toArray + } + if (bitSetForNodes.isEmpty) { + new BitSet(0) + } else { + bitSetForNodes.reduce[BitSet]((acc: BitSet, bitv: BitSet) => acc | bitv) + } + } + val aggBitVector: BitSet = workerBitSubvectors.reduce { (acc: BitSet, bitv: BitSet) => + acc | bitv } - val aggBitVectors: Array[BitSubvector] = workerBitSubvectors.reduce(BitSubvector.merge) bestSplitsBc.unpersist() - aggBitVectors + aggBitVector } /** * Choose the best split for a feature at a node. - * * TODO: Return null or None when the split is invalid, such as putting all instances on one * child node. * @@ -425,10 +439,15 @@ private[ml] object AltDT extends Logging { // TODO: Support high-arity features by using a single array to hold the stats. // aggStats(category) = label statistics for category - val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( - _ => metadata.createImpurityAggregator()) - values.zip(labels).foreach { case (cat, label) => + val aggStats: Array[ImpurityAggregatorSingle] = Array.tabulate[ImpurityAggregatorSingle]( + featureArity)(_ => metadata.createImpurityAggregator()) + var i = 0 + val len = values.length + while (i < len) { + val cat = values(i) + val label = labels(i) aggStats(cat.toInt).update(label) + i += 1 } // Compute centroids. 
centroidsForCategories is a list: (category, centroid) @@ -476,9 +495,14 @@ private[ml] object AltDT extends Logging { val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1) // Cumulative sums of bin statistics for left, right parts of split. - val leftImpurityAgg = metadata.createImpurityAggregator() - val rightImpurityAgg = metadata.createImpurityAggregator() - aggStats.foreach(rightImpurityAgg.add) + val leftImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + val rightImpurityAgg: ImpurityAggregatorSingle = metadata.createImpurityAggregator() + var j = 0 + val length = aggStats.length + while (j < length) { + rightImpurityAgg.add(aggStats(j)) + j += 1 + } var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -550,7 +574,12 @@ private[ml] object AltDT extends Logging { val leftImpurityAgg = metadata.createImpurityAggregator() val rightImpurityAgg = metadata.createImpurityAggregator() - labels.foreach(rightImpurityAgg.update(_, 1.0)) + var i = 0 + val len = labels.length + while (i < len) { + rightImpurityAgg.update(labels(i), 1.0) + i += 1 + } var bestThreshold: Double = Double.NegativeInfinity val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() @@ -560,7 +589,11 @@ private[ml] object AltDT extends Logging { var rightCount: Double = rightImpurityAgg.getCount val fullCount: Double = rightCount var currentThreshold = values.headOption.getOrElse(bestThreshold) - values.zip(labels).foreach { case (value, label) => + var j = 0 + val length = values.length + while (j < length) { + val value = values(j) + val label = labels(j) if (value != currentThreshold) { // Check gain val leftWeight = leftCount / fullCount @@ -580,6 +613,7 @@ private[ml] object AltDT extends Logging { rightImpurityAgg.update(label, -1.0) leftCount += 1.0 rightCount -= 1.0 + j += 1 } val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) @@ 
-644,8 +678,10 @@ private[ml] object AltDT extends Logging { featureIndex: Int, featureArity: Int, featureVector: Vector): FeatureVector = { - val (values, indices) = featureVector.toArray.zipWithIndex.sorted.unzip - new FeatureVector(featureIndex, featureArity, values.toArray, indices.toArray) + val values = featureVector.toArray + val indices = Range(0, values.length).toArray + DualPivotQuicksort.sort(values, indices) + new FeatureVector(featureIndex, featureArity, values, indices) } } @@ -664,19 +700,23 @@ private[ml] object AltDT extends Logging { * second by sorted row indices within the node's rows. * bit[index in sorted array of row indices] = false for left, true for right */ - private[impl] def bitSubvectorFromSplit( + private[impl] def bitVectorFromSplit( col: FeatureVector, fromOffset: Int, toOffset: Int, - split: Split): BitSubvector = { - val nodeRowIndices = col.indices.view.slice(fromOffset, toOffset).toArray - val nodeRowValues = col.values.view.slice(fromOffset, toOffset).toArray - val nodeRowValuesSortedByIndices = nodeRowIndices.zip(nodeRowValues).sortBy(_._1).map(_._2) - val bitv = new BitSubvector(fromOffset, toOffset) - nodeRowValuesSortedByIndices.zipWithIndex.foreach { case (value, i) => + split: Split, + numRows: Int): BitSet = { + val nodeRowIndices = col.indices.slice(fromOffset, toOffset) + val nodeRowValues = col.values.slice(fromOffset, toOffset) + val bitv = new BitSet(numRows) + var i = 0 + while (i < nodeRowValues.length) { + val value = nodeRowValues(i) + val idx = nodeRowIndices(i) if (!split.shouldGoLeft(value)) { - bitv.set(fromOffset + i) + bitv.set(idx) } + i += 1 } bitv } @@ -728,75 +768,106 @@ private[ml] object AltDT extends Logging { * Update nodeOffsets, activeNodes: * Split offsets for nodes which split (which can be identified using the bit vector). * - * @param bitVectors Bit vectors encoding splits for the next level of the tree. + * @param instanceBitVector Bit vector encoding splits for the next level of the tree. 
* These must follow a 2-level ordering, where the first level is by node * and the second level is by row index. * bitVector(i) = false iff instance i goes to the left child. * For instances at inactive (leaf) nodes, the value can be arbitrary. - * When an active node is not split (e.g., because no good split was found), - * then the corresponding BitSubvector can be missing. + * @param nodeSplitBitVector Bit vector encoding whether an active node was split or not * @return Updated partition info */ - def update(bitVectors: Array[BitSubvector], newNumNodeOffsets: Int): PartitionInfo = { - val newColumns = columns.map { oldCol => - val col = oldCol.deepCopy() - var curBitVecIdx = 0 + def update(instanceBitVector: BitSet, nodeSplitBitVector: BitSet, newNumNodeOffsets: Int): + PartitionInfo = { + // Create a 2-level representation of the new nodeOffsets (to be flattened). + // These 2 levels correspond to original nodes and their children (if split). + val newNodeOffsets = nodeOffsets.map(Array(_)) + + val newColumns = columns.map { col => activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - // TODO: Allow missing vectors when no split is chosen. - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) // If the current BitVector does not cover this node, then this node was not split, // so we do not need to update its part of the column. Otherwise, we update it. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Sort range [from, to) based on indices. This is required to match the bit vector - // across all workers. See [[bitSubvectorFromSplit]] for details. - val rangeIndices = col.indices.view.slice(from, to).toArray - val rangeValues = col.values.view.slice(from, to).toArray - val sortedRange = rangeIndices.zip(rangeValues).sortBy(_._1) - // Sort range [from, to) based on bit vector. 
- sortedRange.zipWithIndex.map { case ((idx, value), i) => - val bit = curBitVector.get(from + i) - // TODO: In-place merge, rather than general sort. - // TODO: We don't actually need to sort the categorical features using our approach. - (bit, value, idx) - }.sorted.zipWithIndex.foreach { case ((bit, value, idx), i) => - col.values(from + i) = value - col.indices(from + i) = idx - } - } - } - col - } + if (nodeSplitBitVector.get(nodeIdx)) { + val from = nodeOffsets(nodeIdx) + val to = nodeOffsets(nodeIdx + 1) + // Sort range [from, to) based on split, then value. This is required to match + // the bit vector across all workers. See [[bitVectorFromSplit]] for details. + // Within [from, to), we will have all "left child" instances (those that are false), + // then all "right child" instances. Then, within each child, we sort by value, so + // we can compute the best split for the next iteration. The corresponding index for + // an instance is used to look up the split value ("left" or "right") in the + // instanceBitVector, which is ordered by index. + val rangeIndices = col.indices.slice(from, to) + val rangeValues = col.values.slice(from, to) + + // BEGIN SORTING + var start = 0 + var end = rangeValues.length - 1 + // if this is the very first time we split + // we don't have to use the indices to figure + // out which bits are turned on + val numBitsSet = if (nodeOffsets.length == 2) instanceBitVector.cardinality + else rangeIndices.count(instanceBitVector.get) + val numBitsNotSet = to - from - numBitsSet - // Create a 2-level representation of the new nodeOffsets (to be flattened). - // These 2 levels correspond to original nodes and their children (if split). 
- val newNodeOffsets = nodeOffsets.map(Array(_)) - var curBitVecIdx = 0 - activeNodes.iterator.foreach { nodeIdx => - val from = nodeOffsets(nodeIdx) - val to = nodeOffsets(nodeIdx + 1) - if (bitVectors(curBitVecIdx).to <= from) curBitVecIdx += 1 - val curBitVector = bitVectors(curBitVecIdx) - // If the current BitVector does not cover this node, then this node was not split, - // so we do not need to create a new node offset. Otherwise, we create an offset. - if (curBitVector.from <= from && to <= curBitVector.to) { - // Count number of values splitting to left vs. right - val numRight = Range(from, to).count(curBitVector.get) - val numLeft = to - from - numRight - if (numLeft != 0 && numRight != 0) { - // node is split val oldOffset = newNodeOffsets(nodeIdx).head - newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numLeft) + // numBitsNotSet == number of instances going to the left + // which is how big the offset should be + if (numBitsNotSet == 0) { + newNodeOffsets(nodeIdx) = Array(oldOffset) + } else { + newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) + } + + // first we move all of the values and indices that have + // zero-bits to the front + while (start < numBitsNotSet && start <= end) { + // "start <= end" isn't really necessary, but we + // include it anyways + val startBit = instanceBitVector.get(rangeIndices(start)) + val endBit = instanceBitVector.get(rangeIndices(end)) + // if the startBit is false, we increment and move on + if (!startBit) { + start += 1 + } + // if endBit is true, we decrement and move on + if (endBit) { + end -= 1 + // if startBit is true and endBit is false, we swap + } else if (startBit && !endBit) { + // swap both in rangeValues and in rangeIndices + // (this should be a separate helper function, + // but we want to avoid function calls) + val tempVal = rangeValues(start) + rangeValues(start) = rangeValues(end) + rangeValues(end) = tempVal + val tempIdx = rangeIndices(start) + rangeIndices(start) = 
rangeIndices(end) + rangeIndices(end) = tempIdx + // update both start and end + start += 1 + end -= 1 + } + } + // Now, we sort the sub-arrays from [0, numBitsNotSet) and + // [numBitsNotSet, rangeValues.length) + DualPivotQuicksort.sort(rangeValues, rangeIndices, 0, numBitsNotSet - 1) + DualPivotQuicksort.sort(rangeValues, rangeIndices, numBitsNotSet, + rangeValues.length - 1) + // END SORTING + + // update the column values and indices + // with the corresponding indices + var i = 0 + while (i < rangeValues.length) { + col.values(from + i) = rangeValues(i) + col.indices(from + i) = rangeIndices(i) + i += 1 + } } } + col } - assert(newNodeOffsets.map(_.length).sum == newNumNodeOffsets, - s"(W) newNodeOffsets total size: ${newNodeOffsets.map(_.length).sum}," + - s" newNumNodeOffsets: $newNumNodeOffsets") - // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. val newActiveNodes = new BitSet(newNumNodeOffsets - 1) var newNodeOffsetsIdx = 0 @@ -809,9 +880,7 @@ private[ml] object AltDT extends Logging { newNodeOffsetsIdx += 1 } } - PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes) } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala deleted file mode 100644 index fe9b171cfd1b..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BitSubvector.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.util.collection.BitSet - - -private[impl] class BitSubvector(val from: Int, val to: Int) extends Serializable { - - val numBits: Int = to - from - - /** Element i will be put at location i + offset in the BitSet */ - private val offset: Int = 64 - (numBits % 64) - - private val bits: BitSet = new BitSet(numBits + offset) - - def set(bit: Int): Unit = bits.set(bit + offset - from) - - def get(bit: Int): Boolean = bits.get(bit + offset - from) - - /** Get an iterator over the set bits. 
*/ - def iterator: Iterator[Int] = new Iterator[Int] { - val iter = bits.iterator - override def hasNext: Boolean = iter.hasNext - override def next(): Int = iter.next() - offset + from - } -} - -private[impl] object BitSubvector { - - def merge(parts1: Array[BitSubvector], parts2: Array[BitSubvector]): Array[BitSubvector] = { - // Merge sorted parts1, parts2 - val sortedSubvectors = (parts1 ++ parts2).sortBy(_.from) - if (sortedSubvectors.nonEmpty) { - // Merge adjacent PartialBitVectors (for adjacent node ranges) - val newSubvectorRanges: Array[(Int, Int)] = { - val newSubvRanges = ArrayBuffer.empty[(Int, Int)] - var i = 1 - var currentFrom = sortedSubvectors.head.from - while (i < sortedSubvectors.length) { - if (sortedSubvectors(i - 1).to != sortedSubvectors(i).from) { - newSubvRanges.append((currentFrom, sortedSubvectors(i - 1).to)) - currentFrom = sortedSubvectors(i).from - } - i += 1 - } - newSubvRanges.append((currentFrom, sortedSubvectors.last.to)) - newSubvRanges.toArray - } - val newSubvectors = newSubvectorRanges.map { case (from, to) => new BitSubvector(from, to) } - var curNewSubvIdx = 0 - sortedSubvectors.foreach { subv => - if (subv.to > newSubvectors(curNewSubvIdx).to) curNewSubvIdx += 1 - val newSubv = newSubvectors(curNewSubvIdx) - // TODO: More efficient (word-level) copy. - subv.iterator.foreach(idx => newSubv.set(idx)) - } - assert(curNewSubvIdx + 1 == newSubvectors.length) // sanity check - newSubvectors - } else { - Array.empty[BitSubvector] - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java new file mode 100644 index 000000000000..a781f50b4cd1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DualPivotQuicksort.java @@ -0,0 +1,696 @@ +package org.apache.spark.ml.tree.impl; + +/** + * Created by fabuzaid21 on 11/2/15. 
+ * Helper utility class for sorting + * Double arrays with corresponding Int indices in + * {@link org.apache.spark.ml.tree.impl.AltDT}. + * Provides a more efficient alternative to: + * <pre>
+ * {@code
+ * val doubles = Array.fill[Double](len)(Random.nextDouble)
+ * val idxs = Array.fill[Int](len)(Random.nextInt(len))
+ * val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip
+ * }
+ * </pre>
+ * + * NOTE: This implementation is directly borrowed from + * Yaroslavskiy et al.'s implementation of the Dual-Pivot + * Quicksort Algorithm in OpenJDK 7, which can be found + * in {@code java.util.DualPivotQuicksort}. + */ +public class DualPivotQuicksort { + + /** + * Prevents instantiation. + */ + private DualPivotQuicksort() {} + + /** + * The maximum number of runs in merge sort. + */ + private static final int MAX_RUN_COUNT = 67; + + /** + * The maximum length of run in merge sort. + */ + private static final int MAX_RUN_LENGTH = 33; + + /** + * If the length of an array to be sorted is less than this + * constant, Quicksort is used in preference to merge sort. + */ + private static final int QUICKSORT_THRESHOLD = 286; + + /** + * If the length of an array to be sorted is less than this + * constant, insertion sort is used in preference to Quicksort. + */ + private static final int INSERTION_SORT_THRESHOLD = 47; + + /** + * Sorts the specified array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + */ + public static void sort(double[] arr, int[] idxs) { + sort(arr, idxs, 0, arr.length - 1); + } + + /** + * Sorts the specified range of the array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + */ + public static void sort(double[] arr, int[] idxs, int left, int right) { + /* + * Phase 1: Move NaNs to the end of the array. + */ + while (left <= right && Double.isNaN(arr[right])) { + --right; + } + for (int k = right; --k >= left; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak != ak) { // arr[k] is NaN + arr[k] = arr[right]; + arr[right] = ak; + idxs[k] = idxs[right]; + idxs[right] = bk; + --right; + } + } + + /* + * Phase 2: Sort everything except NaNs (which are already in place).
+ */ + doSort(arr, idxs, left, right); + + /* + * Phase 3: Place negative zeros before positive zeros. + */ + int hi = right; + + /* + * Find the first zero, or first positive, or last negative element. + */ + while (left < hi) { + int middle = (left + hi) >>> 1; + double middleValue = arr[middle]; + + if (middleValue < 0.0d) { + left = middle + 1; + } else { + hi = middle; + } + } + + /* + * Skip the last negative value (if any) or all leading negative zeros. + */ + while (left <= right && Double.doubleToRawLongBits(arr[left]) < 0) { + ++left; + } + + /* + * Move negative zeros to the beginning of the sub-range. + * + * Partitioning: + * + * +----------------------------------------------------+ + * | < 0.0 | -0.0 | 0.0 | ? ( >= 0.0 ) | + * +----------------------------------------------------+ + * ^ ^ ^ + * | | | + * left p k + * + * Invariants: + * + * all in (*, left) < 0.0 + * all in [left, p) == -0.0 + * all in [p, k) == 0.0 + * all in [k, right] >= 0.0 + * + * Pointer k is the first index of ?-part. + */ + for (int k = left, p = left - 1; ++k <= right; ) { + double ak = arr[k]; + if (ak != 0.0d) { + break; + } + if (Double.doubleToRawLongBits(ak) < 0) { // ak is -0.0d + arr[k] = 0.0d; + arr[++p] = -0.0d; + } + } + } + + /** + * Sorts the specified range of the array. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + */ + private static void doSort(double[] arr, int[] idxs, int left, int right) { + // Use Quicksort on small arrays + if (right - left < QUICKSORT_THRESHOLD) { + sort(arr, idxs, left, right, true); + return; + } + + /* + * Index run[i] is the start of i-th run + * (ascending or descending sequence). 
+ */ + int[] run = new int[MAX_RUN_COUNT + 1]; + int count = 0; run[0] = left; + + // Check if the array is nearly sorted + for (int k = left; k < right; run[count] = k) { + if (arr[k] < arr[k + 1]) { // ascending + while (++k <= right && arr[k - 1] <= arr[k]); + } else if (arr[k] > arr[k + 1]) { // descending + while (++k <= right && arr[k - 1] >= arr[k]); + for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) { + double t = arr[lo]; arr[lo] = arr[hi]; arr[hi] = t; + int v = idxs[lo]; idxs[lo] = idxs[hi]; idxs[hi] = v; + } + } else { // equal + for (int m = MAX_RUN_LENGTH; ++k <= right && arr[k - 1] == arr[k]; ) { + if (--m == 0) { + sort(arr, idxs, left, right, true); + return; + } + } + } + + /* + * The array is not highly structured, + * use Quicksort instead of merge sort. + */ + if (++count == MAX_RUN_COUNT) { + sort(arr, idxs, left, right, true); + return; + } + } + + // Check special cases + if (run[count] == right++) { // The last run contains one element + run[++count] = right; + } else if (count == 1) { // The array is already sorted + return; + } + + /* + * Create temporary array, which is used for merging. + * Implementation note: variable "right" is increased by 1. 
+ */ + double[] tempArr; byte odd = 0; int[] tempIdxs; + for (int n = 1; (n <<= 1) < count; odd ^= 1); + + if (odd == 0) { + tempArr = arr; arr = new double[tempArr.length]; + tempIdxs = idxs; idxs = new int[tempIdxs.length]; + for (int i = left - 1; ++i < right; arr[i] = tempArr[i], idxs[i] = tempIdxs[i]); + } else { + tempArr = new double[arr.length]; + tempIdxs = new int[idxs.length]; + } + + // Merging + for (int last; count > 1; count = last) { + for (int k = (last = 0) + 2; k <= count; k += 2) { + int hi = run[k], mi = run[k - 1]; + for (int i = run[k - 2], p = i, q = mi; i < hi; ++i) { + if (q >= hi || p < mi && arr[p] <= arr[q]) { + tempIdxs[i] = idxs[p]; + tempArr[i] = arr[p++]; + } else { + tempIdxs[i] = idxs[q]; + tempArr[i] = arr[q++]; + } + } + run[++last] = hi; + } + if ((count & 1) != 0) { + for (int i = right, lo = run[count - 1]; --i >= lo; + tempArr[i] = arr[i], tempIdxs[i] = idxs[i] + ); + run[++last] = right; + } + double[] t = arr; arr = tempArr; tempArr = t; + int[] v = idxs; idxs = tempIdxs; tempIdxs = v; + } + } + + /** + * Sorts the specified range of the array by Dual-Pivot Quicksort. + * + * @param arr the array to be sorted + * @param idxs the corresponding indices array, updated to match arr's sorted order + * @param left the index of the first element, inclusive, to be sorted + * @param right the index of the last element, inclusive, to be sorted + * @param leftmost indicates if this part is the leftmost in the range + */ + private static void sort(double[] arr, int[] idxs, int left, int right, boolean leftmost) { + int length = right - left + 1; + + // Use insertion sort on tiny arrays + if (length < INSERTION_SORT_THRESHOLD) { + if (leftmost) { + /* + * Traditional (without sentinel) insertion sort, + * optimized for server VM, is used in case of + * the leftmost part. 
+ */ + for (int i = left, j = i; i < right; j = ++i) { + double ai = arr[i + 1]; + int bi = idxs[i + 1]; + while (ai < arr[j]) { + arr[j + 1] = arr[j]; + idxs[j + 1] = idxs[j]; + if (j-- == left) { + break; + } + } + arr[j + 1] = ai; + idxs[j + 1] = bi; + } + } else { + /* + * Skip the longest ascending sequence. + */ + do { + if (left >= right) { + return; + } + } while (arr[++left] >= arr[left - 1]); + + /* + * Every element from adjoining part plays the role + * of sentinel, therefore this allows us to avoid the + * left range check on each iteration. Moreover, we use + * the more optimized algorithm, so called pair insertion + * sort, which is faster (in the context of Quicksort) + * than traditional implementation of insertion sort. + */ + for (int k = left; ++left <= right; k = ++left) { + double a1 = arr[k], a2 = arr[left]; + int b1 = idxs[k], b2 = idxs[left]; + + if (a1 < a2) { + a2 = a1; a1 = arr[left]; + b2 = b1; b1 = idxs[left]; + } + while (a1 < arr[--k]) { + arr[k + 2] = arr[k]; + idxs[k + 2] = idxs[k]; + } + arr[++k + 1] = a1; + idxs[k + 1] = b1; + + while (a2 < arr[--k]) { + arr[k + 1] = arr[k]; + idxs[k + 1] = idxs[k]; + } + arr[k + 1] = a2; + idxs[k + 1] = b2; + } + double last = arr[right]; + int bLast = idxs[right]; + + while (last < arr[--right]) { + arr[right + 1] = arr[right]; + idxs[right + 1] = idxs[right]; + } + arr[right + 1] = last; + idxs[right + 1] = bLast; + } + return; + } + + // Inexpensive approximation of length / 7 + int seventh = (length >> 3) + (length >> 6) + 1; + + /* + * Sort five evenly spaced elements around (and including) the + * center element in the range. These elements will be used for + * pivot selection as described below. The choice for spacing + * these elements was empirically determined to work well on + * a wide variety of inputs.
+ */ + int e3 = (left + right) >>> 1; // The midpoint + int e2 = e3 - seventh; + int e1 = e2 - seventh; + int e4 = e3 + seventh; + int e5 = e4 + seventh; + + // Sort these elements using insertion sort + if (arr[e2] < arr[e1]) { + double t = arr[e2]; + int v = idxs[e2]; + + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + + if (arr[e3] < arr[e2]) { + double t = arr[e3]; + int v = idxs[e3]; + + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + + if (arr[e4] < arr[e3]) { + double t = arr[e4]; + int v = idxs[e4]; + + arr[e4] = arr[e3]; + arr[e3] = t; + idxs[e4] = idxs[e3]; + idxs[e3] = v; + + if (t < arr[e2]) { + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + } + + if (arr[e5] < arr[e4]) { + double t = arr[e5]; + int v = idxs[e5]; + + arr[e5] = arr[e4]; + arr[e4] = t; + idxs[e5] = idxs[e4]; + idxs[e4] = v; + + if (t < arr[e3]) { + arr[e4] = arr[e3]; + arr[e3] = t; + idxs[e4] = idxs[e3]; + idxs[e3] = v; + + if (t < arr[e2]) { + arr[e3] = arr[e2]; + arr[e2] = t; + idxs[e3] = idxs[e2]; + idxs[e2] = v; + + if (t < arr[e1]) { + arr[e2] = arr[e1]; + arr[e1] = t; + idxs[e2] = idxs[e1]; + idxs[e1] = v; + } + } + } + } + + // Pointers + int less = left; // The index of the first element of center part + int great = right; // The index before the first element of right part + + if (arr[e1] != arr[e2] && arr[e2] != arr[e3] && arr[e3] != arr[e4] && arr[e4] != arr[e5]) { + /* + * Use the second and fourth of the five sorted elements as pivots. + * These values are inexpensive approximations of the first and + * second terciles of the array. Note that pivot1 <= pivot2. 
+ */ + double pivot1 = arr[e2]; + double pivot2 = arr[e4]; + int idxPivot1 = idxs[e2]; + int idxPivot2 = idxs[e4]; + + /* + * The first and the last elements to be sorted are moved to the + * locations formerly occupied by the pivots. When partitioning + * is complete, the pivots are swapped back into their final + * positions, and excluded from subsequent sorting. + */ + arr[e2] = arr[left]; + arr[e4] = arr[right]; + idxs[e2] = idxs[left]; + idxs[e4] = idxs[right]; + + /* + * Skip elements, which are less or greater than pivot values. + */ + while (arr[++less] < pivot1); + while (arr[--great] > pivot2); + + /* + * Partitioning: + * + * left part center part right part + * +--------------------------------------------------------------+ + * | < pivot1 | pivot1 <= && <= pivot2 | ? | > pivot2 | + * +--------------------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (left, less) < pivot1 + * pivot1 <= all in [less, k) <= pivot2 + * all in (great, right) > pivot2 + * + * Pointer k is the first index of ?-part. + */ + outer: + for (int k = less - 1; ++k <= great; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak < pivot1) { // Move arr[k] to left part + arr[k] = arr[less]; + idxs[k] = idxs[less]; + /* + * Here and below we use "arr[i] = b; i++;" instead + * of "arr[i++] = b;" due to performance issue. + */ + arr[less] = ak; + idxs[less] = bk; + ++less; + } else if (ak > pivot2) { // Move arr[k] to right part + while (arr[great] > pivot2) { + if (great-- == k) { + break outer; + } + } + if (arr[great] < pivot1) { // arr[great] <= pivot2 + arr[k] = arr[less]; + arr[less] = arr[great]; + idxs[k] = idxs[less]; + idxs[less] = idxs[great]; + ++less; + } else { // pivot1 <= arr[great] <= pivot2 + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + /* + * Here and below we use "arr[i] = b; i--;" instead + * of "arr[i--] = b;" due to performance issue. 
+ */ + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + + // Swap pivots into their final positions + arr[left] = arr[less - 1]; arr[less - 1] = pivot1; + arr[right] = arr[great + 1]; arr[great + 1] = pivot2; + idxs[left] = idxs[less - 1]; idxs[less - 1] = idxPivot1; + idxs[right] = idxs[great + 1]; idxs[great + 1] = idxPivot2; + + // Sort left and right parts recursively, excluding known pivots + sort(arr, idxs, left, less - 2, leftmost); + sort(arr, idxs, great + 2, right, false); + + /* + * If center part is too large (comprises > 4/7 of the array), + * swap internal pivot values to ends. + */ + if (less < e1 && e5 < great) { + /* + * Skip elements, which are equal to pivot values. + */ + while (arr[less] == pivot1) { + ++less; + } + + while (arr[great] == pivot2) { + --great; + } + + /* + * Partitioning: + * + * left part center part right part + * +----------------------------------------------------------+ + * | == pivot1 | pivot1 < && < pivot2 | ? | == pivot2 | + * +----------------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (*, less) == pivot1 + * pivot1 < all in [less, k) < pivot2 + * all in (great, *) == pivot2 + * + * Pointer k is the first index of ?-part. + */ + outer: + for (int k = less - 1; ++k <= great; ) { + double ak = arr[k]; + int bk = idxs[k]; + if (ak == pivot1) { // Move arr[k] to left part + arr[k] = arr[less]; + arr[less] = ak; + idxs[k] = idxs[less]; + idxs[less] = bk; + ++less; + } else if (ak == pivot2) { // Move arr[k] to right part + while (arr[great] == pivot2) { + if (great-- == k) { + break outer; + } + } + if (arr[great] == pivot1) { // arr[great] < pivot2 + arr[k] = arr[less]; + idxs[k] = idxs[less]; + /* + * Even though arr[great] equals to pivot1, the + * assignment arr[less] = pivot1 may be incorrect, + * if arr[great] and pivot1 are floating-point zeros + * of different signs. 
Therefore in float and + * double sorting methods we have to use more + * accurate assignment arr[less] = arr[great]. + */ + arr[less] = arr[great]; + idxs[less] = idxs[great]; + ++less; + } else { // pivot1 < arr[great] < pivot2 + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + } + + // Sort center part recursively + sort(arr, idxs, less, great, false); + + } else { // Partitioning with one pivot + /* + * Use the third of the five sorted elements as pivot. + * This value is inexpensive approximation of the median. + */ + double pivot = arr[e3]; + + /* + * Partitioning degenerates to the traditional 3-way + * (or "Dutch National Flag") schema: + * + * left part center part right part + * +-------------------------------------------------+ + * | < pivot | == pivot | ? | > pivot | + * +-------------------------------------------------+ + * ^ ^ ^ + * | | | + * less k great + * + * Invariants: + * + * all in (left, less) < pivot + * all in [less, k) == pivot + * all in (great, right) > pivot + * + * Pointer k is the first index of ?-part. + */ + for (int k = less; k <= great; ++k) { + if (arr[k] == pivot) { + continue; + } + double ak = arr[k]; + int bk = idxs[k]; + if (ak < pivot) { // Move arr[k] to left part + arr[k] = arr[less]; + arr[less] = ak; + idxs[k] = idxs[less]; + idxs[less] = bk; + ++less; + } else { // arr[k] > pivot - Move arr[k] to right part + while (arr[great] > pivot) { + --great; + } + if (arr[great] < pivot) { // arr[great] <= pivot + arr[k] = arr[less]; + arr[less] = arr[great]; + idxs[k] = idxs[less]; + idxs[less] = idxs[great]; + ++less; + } else { // arr[great] == pivot + /* + * Even though arr[great] equals to pivot, the + * assignment arr[k] = pivot may be incorrect, + * if arr[great] and pivot are floating-point + * zeros of different signs. Therefore in float + * and double sorting methods we have to use + * more accurate assignment arr[k] = arr[great]. 
+ */ + arr[k] = arr[great]; + idxs[k] = idxs[great]; + } + arr[great] = ak; + idxs[great] = bk; + --great; + } + } + + /* + * Sort left and right parts recursively. + * All elements from center part are equal + * and, therefore, already sorted. + */ + sort(arr, idxs, left, less - 1, leftmost); + sort(arr, idxs, great + 1, right, false); + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala index fa6eb63685d2..710ec48f72cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreeUtil.scala @@ -48,14 +48,8 @@ private[tree] object TreeUtil { * (First collect stats to decide how to partition.) * TODO: Move elsewhere in MLlib. */ - def rowToColumnStoreDense(rowStore: RDD[Vector]): RDD[(Int, Vector)] = { + def rowToColumnStoreDense(rowStore: RDD[Vector], numRows: Int): RDD[(Int, Vector)] = { - val numRows = { - val longNumRows: Long = rowStore.count() - require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + - s" but can handle at most ${Int.MaxValue} rows") - longNumRows.toInt - } if (numRows == 0) { return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) } @@ -89,12 +83,16 @@ private[tree] object TreeUtil { // = column values for each instance in sourcePartitionIndex, // where colIdx is a 0-based index for columns for groupIndex val columnSets = new Array[Array[ArrayBuffer[Double]]](numTargetPartitions) - Range(0, numTargetPartitions).foreach { groupIndex => + var groupIndex = 0 + while(groupIndex < numTargetPartitions) { columnSets(groupIndex) = Array.fill[ArrayBuffer[Double]](getNumColsInGroup(groupIndex))(ArrayBuffer[Double]()) + groupIndex += 1 } - iterator.foreach { row => - Range(0, numTargetPartitions).foreach { groupIndex => + while (iterator.hasNext) { + val row: Vector = iterator.next() + var groupIndex = 0 + while (groupIndex < 
numTargetPartitions) { val fromCol = groupIndex * maxColumnsPerPartition val numColsInTargetPartition = getNumColsInGroup(groupIndex) // TODO: match-case here on row as Dense or Sparse Vector (for speed) @@ -103,6 +101,7 @@ private[tree] object TreeUtil { columnSets(groupIndex)(colIdx) += row(fromCol + colIdx) colIdx += 1 } + groupIndex += 1 } } Range(0, numTargetPartitions).map { groupIndex => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index d1e624fb8d6c..862e29bd1093 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -160,7 +160,7 @@ private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) exten def prob(label: Double): Double = -1 /** Get [[Predict]] struct. */ - def getPredict = { + def getPredict: Predict = { val pred = this.predict new Predict(predict = pred, prob = this.prob(pred)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala index 860c8ea78760..1c24dcd73608 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/AltDTSuite.scala @@ -21,14 +21,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.regression.DecisionTreeRegressor import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.AltDT.{AltDTMetadata, FeatureVector, PartitionInfo} -import org.apache.spark.ml.tree.impl.TreeUtil._ -import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model.ImpurityStats import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.collection.BitSet +import scala.util.Random + /** * Test suite for [[AltDT]]. */ @@ -72,6 +73,39 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { //////////////////////////////// Helper classes ////////////////////////////////// + test("DualPivotQuicksort: sort array") { + val r = new Random() + val len = 1000 + val doubles = Array.fill[Double](len)(r.nextDouble) + val idxs = Array.fill[Int](len)(r.nextInt(len)) + + val (sortedDoubles, sortedIdxs) = doubles.zip(idxs).sorted.unzip + DualPivotQuicksort.sort(doubles, idxs) + + assert(sortedDoubles.sameElements(doubles)) + assert(sortedIdxs.sameElements(idxs)) + } + + test("DualPivotQuicksort: sort sub-arrays") { + val r = new Random() + val len = 1000 + val partition = r.nextInt(len) // partition point between the two sub-arrays + val doubles = Array.fill[Double](len)(r.nextDouble) + val idxs = Array.fill[Int](len)(r.nextInt(len)) + + val (sortedLeftDoubles, sortedLeftIdxs) = doubles.slice(0, partition).zip(idxs.slice(0, partition)).sorted.unzip + val (sortedRightDoubles, sortedRightIdxs) = doubles.slice(partition, len).zip(idxs.slice(partition, len)).sorted.unzip + + val sortedDoubles = sortedLeftDoubles ++ sortedRightDoubles + val sortedIdxs = sortedLeftIdxs ++ sortedRightIdxs + + DualPivotQuicksort.sort(doubles, idxs, 0, partition - 1) + DualPivotQuicksort.sort(doubles, idxs, partition, len - 1) + + assert(sortedDoubles.sameElements(doubles)) + assert(sortedIdxs.sameElements(idxs)) + } + test("FeatureVector") { val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0)) @@ -100,11 +134,12 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { // Create bitVector for splitting the 4 rows: L, R, L, R // New groups are {0, 2}, {1, 3} - val bitVector = new BitSubvector(0, numRows) + val bitVector = new BitSet(numRows) bitVector.set(1) bitVector.set(3) - val newInfo = 
info.update(Array(bitVector), newNumNodeOffsets = 3) + // for these tests, use the activeNodes for nodeSplitBitVector + val newInfo = info.update(bitVector, activeNodes, newNumNodeOffsets = 3) assert(newInfo.columns.length === 2) val expectedCol1a = @@ -117,12 +152,16 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { assert(newInfo.activeNodes.iterator.toSet === Set(0, 1)) // Create 2 bitVectors for splitting into: 0, 2, 1, 3 - val bv2a = new BitSubvector(0, 2) - bv2a.set(1) - val bv2b = new BitSubvector(2, 4) - bv2b.set(3) + val bitVector2 = new BitSet(numRows) + bitVector2.set(2) // 2 goes to the right + bitVector2.set(3) // 3 goes to the right - val newInfo2 = newInfo.update(Array(bv2a, bv2b), newNumNodeOffsets = 5) + // both nodes find a split + val nodeSplits = new BitSet(2) + nodeSplits.set(0) + nodeSplits.set(1) + + val newInfo2 = newInfo.update(bitVector2, nodeSplits, newNumNodeOffsets = 5) assert(newInfo2.columns.length === 2) val expectedCol2a = @@ -286,31 +325,34 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { FeatureVector.fromOriginal(0, 0, Vectors.dense(0.1, 0.2, 0.4, 0.6, 0.7)) val fromOffset = 0 val toOffset = col.values.length + val numRows = col.values.length val split = new ContinuousSplit(0, threshold = 0.5) - val bitv = AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) assert(bitv.iterator.toSet === Set(3, 4)) } test("bitSubvectorFromSplit: 2 nodes") { // Initially, 1 split: (0, 2, 4) | (1, 3) + // (0.4, 0.2, 0.1) | (0.6, 0.7) + // 0, 1, 2, 3, 4 + // 0.4, 0.6, 0.2, 0.7, 0.1 val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7), Array(4, 2, 0, 1, 3)) def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, expectedRight: Set[Int]): Unit = { val split = new ContinuousSplit(0, threshold) - val bitv = 
AltDT.bitSubvectorFromSplit(col, fromOffset, toOffset, split) - assert(bitv.from === fromOffset) - assert(bitv.to === toOffset) + val numRows = col.values.length + val bitv = AltDT.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) assert(bitv.iterator.toSet === expectedRight) } + // Left child node - checkSplit(0, 3, 0.15, Set(0, 1)) + checkSplit(0, 3, 0.05, Set(0, 2, 4)) + checkSplit(0, 3, 0.15, Set(0, 2)) checkSplit(0, 3, 0.2, Set(0)) checkSplit(0, 3, 0.5, Set()) // Right child node - checkSplit(3, 5, 0.1, Set(3, 4)) - checkSplit(3, 5, 0.65, Set(4)) + checkSplit(3, 5, 0.1, Set(1, 3)) + checkSplit(3, 5, 0.65, Set(3)) checkSplit(3, 5, 0.8, Set()) } @@ -323,11 +365,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = 0.5) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(3, 4)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.iterator.toSet === Set(3, 4)) } test("collectBitVectors with 1 vector, with tied threshold") { @@ -339,11 +378,8 @@ class AltDTSuite extends SparkFunSuite with MLlibTestSparkContext { val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes) val partitionInfos = sc.parallelize(Seq(info)) val bestSplit = new ContinuousSplit(0, threshold = -2.0) - val bitVectors = AltDT.collectBitVectors(partitionInfos, Array(Some(bestSplit))) - assert(bitVectors.length === 1) - val bitv = bitVectors.head - assert(bitv.numBits === numRows) - assert(bitv.iterator.toArray === Array(0, 1, 4, 5)) + val bitVector = AltDT.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) + assert(bitVector.iterator.toSet === Set(0, 
1, 4, 5)) } //////////////////////////////// Active nodes ////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala deleted file mode 100644 index 6e7b67623851..000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BitSubvectorSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.tree.impl - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext - -/** - * Test suite for [[BitSubvector]]. 
- */ -class BitSubvectorSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("BitSubvector basic ops") { - val from = 1 - val to = 4 - val bs = new BitSubvector(from, to) - assert(bs.numBits === to - from) - Range(from, to).foreach(x => assert(!bs.get(x))) - val setVals = Array(from, to - 1) - setVals.foreach { x => - bs.set(x) - assert(bs.get(x)) - } - assert(bs.iterator.toSet === setVals.toSet) - } - - test("BitSubvector merge") { - val b1 = new BitSubvector(0, 5) - b1.set(1) - val b2 = new BitSubvector(5, 7) - b2.set(5) - val b3 = new BitSubvector(9, 12) - b3.set(11) - val parts1 = Array(b1) - val parts2 = Array(b2, b3) - val newParts = BitSubvector.merge(parts1, parts2) - - val r1 = new BitSubvector(0, 7) - r1.set(1) - r1.set(5) - val r2 = new BitSubvector(9, 12) - r2.set(11) - val expectedParts = Array(r1, r2) - newParts.zip(expectedParts).foreach { case (x, y) => - assert(x.from === y.from) - assert(x.to === x.to) - assert(x.iterator.toSet === y.iterator.toSet) - } - } - - test("BitSubvector merge with empty BitSubvectors") { - val parts = BitSubvector.merge(Array.empty[BitSubvector], Array.empty[BitSubvector]) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala index 483fb8568f8b..abacb8b8b7d8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeUtilSuite.scala @@ -32,7 +32,7 @@ class TreeUtilSuite extends SparkFunSuite with MLlibTestSparkContext { private def checkDense(rows: Seq[Vector]): Unit = { val numRowPartitions = 2 val rowStore = sc.parallelize(rows, numRowPartitions) - val colStore = rowToColumnStoreDense(rowStore) + val colStore = rowToColumnStoreDense(rowStore, rowStore.count.toInt) val numColPartitions = colStore.partitions.length val cols: Map[Int, Vector] = colStore.collect().toMap val numRows = rows.size