Commit 3c00cc6

Author: Andrew Or (committed)

Merge branch 'master' of github.com:apache/spark into concurrent-sql-executions

Conflicts:
	core/src/test/scala/org/apache/spark/ThreadingSuite.scala

Parents: 5297f79 + c34fc19

File tree: 21 files changed (+626, -62 lines)


core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
Lines changed: 2 additions & 0 deletions

@@ -297,6 +297,8 @@ private[spark] class ExternalSorter[K, V, C](
     val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
     while (it.hasNext) {
       val partitionId = it.nextPartition()
+      require(partitionId >= 0 && partitionId < numPartitions,
+        s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
       it.writeNext(writer)
       elementsPerPartition(partitionId) += 1
       objectsWritten += 1
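The added require makes spilling fail fast with a descriptive message when a partitioner emits an id outside [0, numPartitions), instead of a less informative downstream error. A minimal sketch of the failure mode it guards against; BadPartitioner is a hypothetical example, not part of this commit:

// Hypothetical partitioner that returns an out-of-range id. With this
// commit, ExternalSorter now rejects it immediately via require, with an
// IllegalArgumentException naming the bad partition id.
class BadPartitioner(override val numPartitions: Int) extends org.apache.spark.Partitioner {
  def getPartition(key: Any): Int = numPartitions  // invalid: valid ids are 0 to numPartitions - 1
}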

core/src/test/scala/org/apache/spark/ThreadingSuite.scala
Lines changed: 45 additions & 26 deletions

@@ -119,23 +119,30 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
     val nums = sc.parallelize(1 to 2, 2)
     val sem = new Semaphore(0)
     ThreadingSuiteState.clear()
+    var throwable: Option[Throwable] = None
     for (i <- 0 until 2) {
       new Thread {
         override def run() {
-          val ans = nums.map(number => {
-            val running = ThreadingSuiteState.runningThreads
-            running.getAndIncrement()
-            val time = System.currentTimeMillis()
-            while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
-              Thread.sleep(100)
-            }
-            if (running.get() != 4) {
-              ThreadingSuiteState.failed.set(true)
-            }
-            number
-          }).collect()
-          assert(ans.toList === List(1, 2))
-          sem.release()
+          try {
+            val ans = nums.map(number => {
+              val running = ThreadingSuiteState.runningThreads
+              running.getAndIncrement()
+              val time = System.currentTimeMillis()
+              while (running.get() != 4 && System.currentTimeMillis() < time + 1000) {
+                Thread.sleep(100)
+              }
+              if (running.get() != 4) {
+                ThreadingSuiteState.failed.set(true)
+              }
+              number
+            }).collect()
+            assert(ans.toList === List(1, 2))
+          } catch {
+            case t: Throwable =>
+              throwable = Some(t)
+          } finally {
+            sem.release()
+          }
         }
       }.start()
     }

@@ -145,19 +152,25 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
         ThreadingSuiteState.runningThreads.get() + "); failing test")
       fail("One or more threads didn't see runningThreads = 4")
     }
+    throwable.foreach { t => throw t }
   }

   test("set local properties in different thread") {
     sc = new SparkContext("local", "test")
     val sem = new Semaphore(0)
-
+    var throwable: Option[Throwable] = None
     val threads = (1 to 5).map { i =>
       new Thread() {
         override def run() {
-          // TODO: these assertion failures don't actually fail the test...
-          sc.setLocalProperty("test", i.toString)
-          assert(sc.getLocalProperty("test") === i.toString)
-          sem.release()
+          try {
+            sc.setLocalProperty("test", i.toString)
+            assert(sc.getLocalProperty("test") === i.toString)
+          } catch {
+            case t: Throwable =>
+              throwable = Some(t)
+          } finally {
+            sem.release()
+          }
         }
       }
     }

@@ -166,21 +179,27 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {

     sem.acquire(5)
     assert(sc.getLocalProperty("test") === null)
+    throwable.foreach { t => throw t }
   }

   test("set and get local properties in parent-children thread") {
     sc = new SparkContext("local", "test")
     sc.setLocalProperty("test", "parent")
     val sem = new Semaphore(0)
-
+    var throwable: Option[Throwable] = None
     val threads = (1 to 5).map { i =>
       new Thread() {
         override def run() {
-          // TODO: these assertion failures don't actually fail the test...
-          assert(sc.getLocalProperty("test") === "parent")
-          sc.setLocalProperty("test", i.toString)
-          assert(sc.getLocalProperty("test") === i.toString)
-          sem.release()
+          try {
+            assert(sc.getLocalProperty("test") === "parent")
+            sc.setLocalProperty("test", i.toString)
+            assert(sc.getLocalProperty("test") === i.toString)
+          } catch {
+            case t: Throwable =>
+              throwable = Some(t)
+          } finally {
+            sem.release()
+          }
         }
       }
     }

@@ -190,6 +209,7 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
     sem.acquire(5)
     assert(sc.getLocalProperty("test") === "parent")
     assert(sc.getLocalProperty("Foo") === null)
+    throwable.foreach { t => throw t }
   }

   test("inheritance exclusions (SPARK-10548)") {

@@ -236,7 +256,6 @@ class ThreadingSuite extends SparkFunSuite with LocalSparkContext with Logging {
     // Create a new thread which will inherit the current thread's properties
     val thread = new Thread() {
       override def run(): Unit = {
-        // TODO: these assertion failures don't actually fail the test...
        assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "originalJobGroupId")
        // Sleeps for a total of 10 seconds, but allows cancellation to interrupt the task
        try {
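The pattern applied throughout this suite: an assertion that throws inside a spawned thread does not fail the enclosing test, so each thread captures its Throwable and the test thread rethrows it after the semaphore unblocks. The same pattern in isolation, as a sketch (not part of the diff):

import java.util.concurrent.Semaphore

val sem = new Semaphore(0)
var throwable: Option[Throwable] = None
new Thread {
  override def run(): Unit = {
    try {
      assert(2 + 2 == 4)  // a failure here would be swallowed without the capture below
    } catch {
      case t: Throwable => throwable = Some(t)
    } finally {
      sem.release()  // always unblock the waiting test thread
    }
  }
}.start()
sem.acquire()
throwable.foreach { t => throw t }  // surface the child thread's failure on the test thread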

python/pyspark/context.py
Lines changed: 1 addition & 1 deletion

@@ -255,7 +255,7 @@ def __getnewargs__(self):
         # This method is called when attempting to pickle SparkContext, which is always an error:
         raise Exception(
             "It appears that you are attempting to reference SparkContext from a broadcast "
-            "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+            "variable, action, or transformation. SparkContext can only be used on the driver, "
             "not in code that it run on workers. For more information, see SPARK-5063."
         )

python/pyspark/sql/column.py
Lines changed: 13 additions & 0 deletions

@@ -91,6 +91,17 @@ def _(self):
     return _


+def _bin_func_op(name, reverse=False, doc="binary function"):
+    def _(self, other):
+        sc = SparkContext._active_spark_context
+        fn = getattr(sc._jvm.functions, name)
+        jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
+        njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
+        return Column(njc)
+    _.__doc__ = doc
+    return _
+
+
 def _bin_op(name, doc="binary operator"):
     """ Create a method for given binary operator
     """

@@ -151,6 +162,8 @@ def __init__(self, jc):
     __rdiv__ = _reverse_op("divide")
     __rtruediv__ = _reverse_op("divide")
     __rmod__ = _reverse_op("mod")
+    __pow__ = _bin_func_op("pow")
+    __rpow__ = _bin_func_op("pow", reverse=True)

     # logistic operators
     __eq__ = _bin_op("equalTo")

python/pyspark/sql/tests.py
Lines changed: 1 addition & 1 deletion

@@ -568,7 +568,7 @@ def test_column_operators(self):
         cs = self.df.value
         c = ci == cs
         self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
-        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+        rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci), (1 ** ci), (ci ** 1)
         self.assertTrue(all(isinstance(c, Column) for c in rcc))
         cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7]
         self.assertTrue(all(isinstance(c, Column) for c in cb))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
Lines changed: 21 additions & 3 deletions

@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
   // UDFToBoolean
   private[this] def castToBoolean(from: DataType): Any => Any = from match {
     case StringType =>
-      buildCast[UTF8String](_, _.numBytes() != 0)
+      buildCast[UTF8String](_, s => {
+        if (StringUtils.isTrueString(s)) {
+          true
+        } else if (StringUtils.isFalseString(s)) {
+          false
+        } else {
+          null
+        }
+      })
     case TimestampType =>
       buildCast[Long](_, t => t != 0)
     case DateType =>

@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)

   private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
     case StringType =>
-      (c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
+      val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
+      (c, evPrim, evNull) =>
+        s"""
+          if ($stringUtils.isTrueString($c)) {
+            $evPrim = true;
+          } else if ($stringUtils.isFalseString($c)) {
+            $evPrim = false;
+          } else {
+            $evNull = true;
+          }
+        """
     case TimestampType =>
       (c, evPrim, evNull) => s"$evPrim = $c != 0;"
     case DateType =>
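This changes string-to-boolean cast semantics: previously any non-empty string cast to true; now only the strings recognized by StringUtils produce a value, and everything else becomes null. An illustration (a sketch, assuming a DataFrame df with a string column "flag"):

import org.apache.spark.sql.functions.col

df.select(col("flag").cast("boolean"))
// "t", "true", "y", "yes", "1"  (any case) => true
// "f", "false", "n", "no", "0"  (any case) => false
// any other string                         => null  (previously: true whenever non-empty)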

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
Lines changed: 4 additions & 4 deletions

@@ -435,10 +435,10 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
       // a && a => a
       case (l, r) if l fastEquals r => l
       // a && (not(a) || b) => a && b
-      case (l, Or(l1, r)) if (Not(l) fastEquals l1) => And(l, r)
-      case (l, Or(r, l1)) if (Not(l) fastEquals l1) => And(l, r)
-      case (Or(l, l1), r) if (l1 fastEquals Not(r)) => And(l, r)
-      case (Or(l1, l), r) if (l1 fastEquals Not(r)) => And(l, r)
+      case (l, Or(l1, r)) if (Not(l) == l1) => And(l, r)
+      case (l, Or(r, l1)) if (Not(l) == l1) => And(l, r)
+      case (Or(l, l1), r) if (l1 == Not(r)) => And(l, r)
+      case (Or(l1, l), r) if (l1 == Not(r)) => And(l, r)
       // (a || b) && (a || c) => a || (b && c)
       case _ =>
         // 1. Split left and right to get the disjunctive predicates,
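These four cases implement the a && (!a || b) => a && b rewrite; the guards now compare with plain case-class equality instead of fastEquals. The rule on concrete expressions, sketched with the Catalyst DSL (illustrative only, not part of the commit):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, Not, Or}

val a = 'a.boolean
val b = 'b.boolean
// And(a, Or(Not(a), b)) matches the first case above and simplifies to And(a, b)
val simplified = And(a, Or(Not(a), b)) match {
  case And(l, Or(l1, r)) if Not(l) == l1 => And(l, r)
}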

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
Lines changed: 8 additions & 0 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util

 import java.util.regex.Pattern

+import org.apache.spark.unsafe.types.UTF8String
+
 object StringUtils {

   // replace the _ with .{1} exactly match 1 time of any character

@@ -44,4 +46,10 @@ object StringUtils {
       v
     }
   }
+
+  private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
+  private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
+
+  def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
+  def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
 }
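These two helpers back the new cast behavior; matching is case-insensitive because the input is lower-cased before the set lookup. Expected behavior, as a sketch:

import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.unsafe.types.UTF8String

StringUtils.isTrueString(UTF8String.fromString("YES"))     // true
StringUtils.isFalseString(UTF8String.fromString("0"))      // true
StringUtils.isTrueString(UTF8String.fromString("maybe"))   // false (neither a true nor a false string)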
