Commit aa3de46

Author: Davies Liu (committed)
Merge branch 'master' of github.com:apache/spark into semijoin
2 parents: 7590a25 + d5a9af3

43 files changed (+961, -626 lines)

This is a large commit, so only a subset of the 43 changed files is shown below.

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 20 additions & 16 deletions
@@ -191,24 +191,29 @@ public void spill() throws IOException {
       spillWriters.size(),
       spillWriters.size() > 1 ? " times" : " time");
 
-    final UnsafeSorterSpillWriter spillWriter =
-      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
-        inMemSorter.numRecords());
-    spillWriters.add(spillWriter);
-    final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
-    while (sortedRecords.hasNext()) {
-      sortedRecords.loadNext();
-      final Object baseObject = sortedRecords.getBaseObject();
-      final long baseOffset = sortedRecords.getBaseOffset();
-      final int recordLength = sortedRecords.getRecordLength();
-      spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+    // We only write out contents of the inMemSorter if it is not empty.
+    if (inMemSorter.numRecords() > 0) {
+      final UnsafeSorterSpillWriter spillWriter =
+        new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+          inMemSorter.numRecords());
+      spillWriters.add(spillWriter);
+      final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
+      while (sortedRecords.hasNext()) {
+        sortedRecords.loadNext();
+        final Object baseObject = sortedRecords.getBaseObject();
+        final long baseOffset = sortedRecords.getBaseOffset();
+        final int recordLength = sortedRecords.getRecordLength();
+        spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
+      }
+      spillWriter.close();
     }
-    spillWriter.close();
+
     final long spillSize = freeMemory();
     // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
     // pages will currently be counted as memory spilled even though that space isn't actually
     // written to disk. This also counts the space needed to store the sorter's pointer array.
     taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+
     initializeForWriting();
   }
 
@@ -505,12 +510,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
     final UnsafeSorterSpillMerger spillMerger =
       new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
     for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
-      spillMerger.addSpill(spillWriter.getReader(blockManager));
+      spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
     }
     spillWriters.clear();
-    if (inMemoryIterator.hasNext()) {
-      spillMerger.addSpill(inMemoryIterator);
-    }
+    spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+
     return spillMerger.getSortedIterator();
   }
 }
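
The new guard means spill() creates and registers a spill writer only when the in-memory sorter actually holds records, so a zero-record spill file can never reach the merge step. A minimal standalone sketch of the same idea, using made-up Scala types rather than Spark's classes:

import scala.collection.mutable.ArrayBuffer

final class InMemorySorter {
  private val records = ArrayBuffer.empty[Long]
  def insert(r: Long): Unit = records += r
  def numRecords: Int = records.length
  def sortedRecords: Iterator[Long] = records.sorted.iterator
  def reset(): Unit = records.clear()
}

final class Spiller(sorter: InMemorySorter) {
  // Stand-in for the on-disk spill files produced by UnsafeSorterSpillWriter.
  val spills = ArrayBuffer.empty[Vector[Long]]

  def spill(): Unit = {
    // Only write a spill when there is something to write (the guard added by this commit).
    if (sorter.numRecords > 0) {
      spills += sorter.sortedRecords.toVector
    }
    // Memory is released either way, mirroring freeMemory() and initializeForWriting().
    sorter.reset()
  }
}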

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java

Lines changed: 10 additions & 2 deletions
@@ -47,11 +47,19 @@ public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) {
     priorityQueue = new PriorityQueue<UnsafeSorterIterator>(numSpills, comparator);
   }
 
-  public void addSpill(UnsafeSorterIterator spillReader) throws IOException {
+  /**
+   * Add an UnsafeSorterIterator to this merger
+   */
+  public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOException {
     if (spillReader.hasNext()) {
+      // We only add the spillReader to the priorityQueue if it is not empty. We do this to
+      // make sure the hasNext method of UnsafeSorterIterator returned by getSortedIterator
+      // does not return wrong result because hasNext will returns true
+      // at least priorityQueue.size() times. If we allow n spillReaders in the
+      // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
       spillReader.loadNext();
+      priorityQueue.add(spillReader);
     }
-    priorityQueue.add(spillReader);
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
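
The comment added above is the heart of this fix: every iterator placed in the priority queue has already had loadNext() called, so each queued entry is expected to contribute at least one record to the merged output. A self-contained Scala sketch of that invariant, using hypothetical Run and Merger classes rather than Spark's Java types:

import java.util.{Comparator, PriorityQueue}

// A sorted run of records; current is only valid after loadNext() has been called.
final class Run(values: Array[Long]) {
  private var pos = -1
  def hasNext: Boolean = pos + 1 < values.length
  def loadNext(): Unit = { pos += 1 }
  def current: Long = values(pos)
}

final class Merger {
  private val queue = new PriorityQueue[Run](11, new Comparator[Run] {
    override def compare(a: Run, b: Run): Int = java.lang.Long.compare(a.current, b.current)
  })

  // Mirrors addSpillIfNotEmpty: enqueue a run only if it holds a record, and call
  // loadNext() first so that current is valid when the comparator runs.
  def addRunIfNotEmpty(run: Run): Unit = {
    if (run.hasNext) {
      run.loadNext()
      queue.add(run)
    }
  }

  // Every poll yields one real record; an empty run in the queue would instead surface
  // a spurious entry, which is exactly what the guard above prevents.
  def merged(): Iterator[Long] = new Iterator[Long] {
    override def hasNext: Boolean = !queue.isEmpty
    override def next(): Long = {
      val run = queue.poll()
      val value = run.current
      if (run.hasNext) { run.loadNext(); queue.add(run) }
      value
    }
  }
}

val m = new Merger
m.addRunIfNotEmpty(new Run(Array(1L, 4L)))
m.addRunIfNotEmpty(new Run(Array.empty[Long]))  // skipped: contributes nothing to the merge
m.addRunIfNotEmpty(new Run(Array(2L, 3L)))
println(m.merged().toList)  // List(1, 2, 3, 4)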

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 13 additions & 8 deletions
@@ -101,7 +101,7 @@ public OutputStream apply(OutputStream stream) {
   public void setUp() {
     MockitoAnnotations.initMocks(this);
     sparkConf = new SparkConf();
-    tempDir = new File(Utils.createTempDir$default$1());
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
     shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
     spillFilesCreated.clear();
     taskContext = mock(TaskContext.class);
@@ -143,13 +143,18 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th
 
   @After
   public void tearDown() {
-    long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
-    if (shuffleMemoryManager != null) {
-      long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
-      shuffleMemoryManager = null;
-      assertEquals(0L, leakedShuffleMemory);
+    try {
+      long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+      if (shuffleMemoryManager != null) {
+        long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+        shuffleMemoryManager = null;
+        assertEquals(0L, leakedShuffleMemory);
+      }
+      assertEquals(0, leakedUnsafeMemory);
+    } finally {
+      Utils.deleteRecursively(tempDir);
+      tempDir = null;
     }
-    assertEquals(0, leakedUnsafeMemory);
   }
 
   private void assertSpillFilesWereCleanedUp() {
@@ -234,7 +239,7 @@ public void testSortingEmptyArrays() throws Exception {
   public void spillingOccursInResponseToMemoryPressure() throws Exception {
     shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2);
     final UnsafeExternalSorter sorter = newSorter();
-    final int numRecords = 100000;
+    final int numRecords = (int) pageSizeBytes / 4;
     for (int i = 0; i <= numRecords; i++) {
       insertNumber(sorter, numRecords - i);
     }

core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -93,8 +93,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva
       publicAddress = ""
     )
 
-    val (rpcEnv, uiPort, restPort) =
-      Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf)
+    val (rpcEnv, _, _) =
+      Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf)
 
     try {
       rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME)
@@ -343,8 +343,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva
 
   private def makeMaster(conf: SparkConf = new SparkConf): Master = {
     val securityMgr = new SecurityManager(conf)
-    val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 7077, conf, securityMgr)
-    val master = new Master(rpcEnv, rpcEnv.address, 8080, securityMgr, conf)
+    val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr)
+    val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf)
     master
   }
 

mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
+import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * This class takes a feature vector and outputs a new feature vector with a subarray of the
+ * original features.
+ *
+ * The subset of features can be specified with either indices ([[setIndices()]])
+ * or names ([[setNames()]]). At least one feature must be selected. Duplicate features
+ * are not allowed, so there can be no overlap between selected indices and names.
+ *
+ * The output vector will order features with the selected indices first (in the order given),
+ * followed by the selected names (in the order given).
+ */
+@Experimental
+final class VectorSlicer(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol {
+
+  def this() = this(Identifiable.randomUID("vectorSlicer"))
+
+  /**
+   * An array of indices to select features from a vector column.
+   * There can be no overlap with [[names]].
+   * @group param
+   */
+  val indices = new IntArrayParam(this, "indices",
+    "An array of indices to select features from a vector column." +
+      " There can be no overlap with names.", VectorSlicer.validIndices)
+
+  setDefault(indices -> Array.empty[Int])
+
+  /** @group getParam */
+  def getIndices: Array[Int] = $(indices)
+
+  /** @group setParam */
+  def setIndices(value: Array[Int]): this.type = set(indices, value)
+
+  /**
+   * An array of feature names to select features from a vector column.
+   * These names must be specified by ML [[org.apache.spark.ml.attribute.Attribute]]s.
+   * There can be no overlap with [[indices]].
+   * @group param
+   */
+  val names = new StringArrayParam(this, "names",
+    "An array of feature names to select features from a vector column." +
+      " There can be no overlap with indices.", VectorSlicer.validNames)
+
+  setDefault(names -> Array.empty[String])
+
+  /** @group getParam */
+  def getNames: Array[String] = $(names)
+
+  /** @group setParam */
+  def setNames(value: Array[String]): this.type = set(names, value)
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  override def validateParams(): Unit = {
+    require($(indices).length > 0 || $(names).length > 0,
+      s"VectorSlicer requires that at least one feature be selected.")
+  }
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    // Validity checks
+    transformSchema(dataset.schema)
+    val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol)))
+    inputAttr.numAttributes.foreach { numFeatures =>
+      val maxIndex = $(indices).max
+      require(maxIndex < numFeatures,
+        s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
+    }
+
+    // Prepare output attributes
+    val inds = getSelectedFeatureIndices(dataset.schema)
+    val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
+      inds.map(index => attrs(index))
+    }
+    val outputAttr = selectedAttrs match {
+      case Some(attrs) => new AttributeGroup($(outputCol), attrs)
+      case None => new AttributeGroup($(outputCol), inds.length)
+    }
+
+    // Select features
+    val slicer = udf { vec: Vector =>
+      vec match {
+        case features: DenseVector => Vectors.dense(inds.map(features.apply))
+        case features: SparseVector => features.slice(inds)
+      }
+    }
+    dataset.withColumn($(outputCol),
+      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+  }
+
+  /** Get the feature indices in order: indices, names */
+  private def getSelectedFeatureIndices(schema: StructType): Array[Int] = {
+    val nameFeatures = MetadataUtils.getFeatureIndicesFromNames(schema($(inputCol)), $(names))
+    val indFeatures = $(indices)
+    val numDistinctFeatures = (nameFeatures ++ indFeatures).distinct.length
+    lazy val errMsg = "VectorSlicer requires indices and names to be disjoint" +
+      s" sets of features, but they overlap." +
+      s" indices: ${indFeatures.mkString("[", ",", "]")}." +
+      s" names: " +
+      nameFeatures.zip($(names)).map { case (i, n) => s"$i:$n" }.mkString("[", ",", "]")
+    require(nameFeatures.length + indFeatures.length == numDistinctFeatures, errMsg)
+    indFeatures ++ nameFeatures
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+
+    if (schema.fieldNames.contains($(outputCol))) {
+      throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
+    }
+    val numFeaturesSelected = $(indices).length + $(names).length
+    val outputAttr = new AttributeGroup($(outputCol), numFeaturesSelected)
+    val outputFields = schema.fields :+ outputAttr.toStructField()
+    StructType(outputFields)
+  }
+
+  override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
+}
+
+private[feature] object VectorSlicer {
+
+  /** Return true if given feature indices are valid */
+  def validIndices(indices: Array[Int]): Boolean = {
+    if (indices.isEmpty) {
+      true
+    } else {
+      indices.length == indices.distinct.length && indices.forall(_ >= 0)
+    }
+  }
+
+  /** Return true if given feature names are valid */
+  def validNames(names: Array[String]): Boolean = {
+    names.forall(_.nonEmpty) && names.length == names.distinct.length
+  }
+}
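
A usage sketch for the new transformer, in the Spark 1.5-era shell style. The SparkContext sc, SQLContext sqlContext, column names, and sample values below are illustrative assumptions, not part of the commit:

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.feature.VectorSlicer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType

// Name the input features via ML attributes so that name-based slicing works.
val attrs: Array[Attribute] =
  Array("f1", "f2", "f3").map(name => NumericAttribute.defaultAttr.withName(name): Attribute)
val attrGroup = new AttributeGroup("userFeatures", attrs)

val rows = sc.parallelize(Seq(Row(Vectors.dense(-2.0, 2.3, 0.0))))
val dataset = sqlContext.createDataFrame(rows, StructType(Array(attrGroup.toStructField())))

// Select feature index 1 plus the feature named "f3"; indices come first in the output.
val slicer = new VectorSlicer()
  .setInputCol("userFeatures")
  .setOutputCol("features")
  .setIndices(Array(1))
  .setNames(Array("f3"))

slicer.transform(dataset).show()
// Expected "features" value for this row: [2.3, 0.0]

The transformer rejects overlapping index and name selections and copies the selected attributes into the output column's metadata, as shown in getSelectedFeatureIndices and transform above.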

mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala

Lines changed: 17 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.ml.util
 import scala.collection.immutable.HashMap
 
 import org.apache.spark.ml.attribute._
+import org.apache.spark.mllib.linalg.VectorUDT
 import org.apache.spark.sql.types.StructField
 
 
@@ -74,4 +75,20 @@ private[spark] object MetadataUtils {
     }
   }
 
+  /**
+   * Takes a Vector column and a list of feature names, and returns the corresponding list of
+   * feature indices in the column, in order.
+   * @param col Vector column which must have feature names specified via attributes
+   * @param names List of feature names
+   */
+  def getFeatureIndicesFromNames(col: StructField, names: Array[String]): Array[Int] = {
+    require(col.dataType.isInstanceOf[VectorUDT], s"getFeatureIndicesFromNames expected column $col"
+      + s" to be Vector type, but it was type ${col.dataType} instead.")
+    val inputAttr = AttributeGroup.fromStructField(col)
+    names.map { name =>
+      require(inputAttr.hasAttr(name),
+        s"getFeatureIndicesFromNames found no feature with name $name in column $col.")
+      inputAttr.getAttr(name).index.get
+    }
+  }
 }
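
Since MetadataUtils is private[spark], this helper is only reachable from Spark's own packages (VectorSlicer above is its caller). A sketch of what it resolves, with made-up feature and column names, runnable only from within an org.apache.spark package such as a test suite:

import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.util.MetadataUtils

val attrs: Array[Attribute] =
  Array("f1", "f2", "f3").map(name => NumericAttribute.defaultAttr.withName(name): Attribute)
val field = new AttributeGroup("userFeatures", attrs).toStructField()

val idx = MetadataUtils.getFeatureIndicesFromNames(field, Array("f3", "f1"))
// idx should be Array(2, 0): names resolve to attribute positions, in the order the names were given.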

mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ class PrefixSpan private (
   /**
    * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
    */
-  def getMaxPatternLength: Double = maxPatternLength
+  def getMaxPatternLength: Int = maxPatternLength
 
   /**
   * Sets maximal pattern length (default: `10`).
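
A small sketch of the corrected signature in use, assuming the usual PrefixSpan builder setters (setMinSupport, setMaxPatternLength); the values are illustrative:

import org.apache.spark.mllib.fpm.PrefixSpan

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.1)
  .setMaxPatternLength(5)
val maxLen: Int = prefixSpan.getMaxPatternLength  // now typed as Int rather than Double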

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 24 additions & 0 deletions
@@ -766,6 +766,30 @@ class SparseVector(
       maxIdx
     }
   }
+
+  /**
+   * Create a slice of this vector based on the given indices.
+   * @param selectedIndices Unsorted list of indices into the vector.
+   *                        This does NOT do bound checking.
+   * @return New SparseVector with values in the order specified by the given indices.
+   *
+   * NOTE: The API needs to be discussed before making this public.
+   *       Also, if we have a version assuming indices are sorted, we should optimize it.
+   */
+  private[spark] def slice(selectedIndices: Array[Int]): SparseVector = {
+    var currentIdx = 0
+    val (sliceInds, sliceVals) = selectedIndices.flatMap { origIdx =>
+      val iIdx = java.util.Arrays.binarySearch(this.indices, origIdx)
+      val i_v = if (iIdx >= 0) {
+        Iterator((currentIdx, this.values(iIdx)))
+      } else {
+        Iterator()
+      }
+      currentIdx += 1
+      i_v
+    }.unzip
+    new SparseVector(selectedIndices.length, sliceInds.toArray, sliceVals.toArray)
+  }
 }
 
 object SparseVector {
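
The new slice keeps only the selected indices that are non-zero in the input and re-indexes them by their position in selectedIndices. Because the method is private[spark], the following is a standalone re-implementation of the same per-index binary search, as a sketch of the expected output; the function name and values are made up:

import java.util.Arrays

def sliceSparse(
    indices: Array[Int],
    values: Array[Double],
    selected: Array[Int]): (Array[Int], Array[Double]) = {
  var currentIdx = 0
  val (outInds, outVals) = selected.flatMap { origIdx =>
    val i = Arrays.binarySearch(indices, origIdx)
    // Keep the pair (output position, value) only when the selected index is a stored nonzero.
    val hit = if (i >= 0) Iterator((currentIdx, values(i))) else Iterator.empty
    currentIdx += 1
    hit
  }.unzip
  (outInds, outVals)
}

// A 5-element sparse vector with nonzeros {1 -> 10.0, 3 -> 30.0}, sliced by indices (3, 0, 1):
val (inds, vals) = sliceSparse(Array(1, 3), Array(10.0, 30.0), Array(3, 0, 1))
// inds == Array(0, 2) and vals == Array(30.0, 10.0), i.e. the 3-element result
// Vectors.sparse(3, Array(0, 2), Array(30.0, 10.0)): output positions follow the
// selection order, and selected indices that are zero in the input are dropped.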
