Skip to content

Commit d67a5ea — "Merge remote-tracking branch 'upstream/master' into fix_decimal"
(merge commit with 2 parents: ab6d8af and a138953)

File tree

32 files changed: +235 additions, −238 deletions

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,6 @@ private[spark] object UnsafeShuffleManager extends Logging {
5656
} else if (dependency.aggregator.isDefined) {
5757
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
5858
false
59-
} else if (dependency.keyOrdering.isDefined) {
60-
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
61-
false
6259
} else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
6360
log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
6461
s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")

core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
7676
mapSideCombine = false
7777
)))
7878

79+
// Shuffles with key orderings are supported as long as no aggregator is specified
80+
assert(canUseUnsafeShuffle(shuffleDep(
81+
partitioner = new HashPartitioner(2),
82+
serializer = kryo,
83+
keyOrdering = Some(mock(classOf[Ordering[Any]])),
84+
aggregator = None,
85+
mapSideCombine = false
86+
)))
87+
7988
}
8089

8190
test("unsupported shuffle dependencies") {
@@ -100,22 +109,14 @@ class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
100109
mapSideCombine = false
101110
)))
102111

103-
// We do not support shuffles that perform any kind of aggregation or sorting of keys
104-
assert(!canUseUnsafeShuffle(shuffleDep(
105-
partitioner = new HashPartitioner(2),
106-
serializer = kryo,
107-
keyOrdering = Some(mock(classOf[Ordering[Any]])),
108-
aggregator = None,
109-
mapSideCombine = false
110-
)))
112+
// We do not support shuffles that perform aggregation
111113
assert(!canUseUnsafeShuffle(shuffleDep(
112114
partitioner = new HashPartitioner(2),
113115
serializer = kryo,
114116
keyOrdering = None,
115117
aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
116118
mapSideCombine = false
117119
)))
118-
// We do not support shuffles that perform any kind of aggregation or sorting of keys
119120
assert(!canUseUnsafeShuffle(shuffleDep(
120121
partitioner = new HashPartitioner(2),
121122
serializer = kryo,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -672,13 +672,13 @@ trait HiveTypeCoercion {
672672
findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType =>
673673
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
674674
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
675-
i.makeCopy(Array(pred, newLeft, newRight))
675+
If(pred, newLeft, newRight)
676676
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
677677

678678
// Convert If(null literal, _, _) into boolean type.
679679
// In the optimizer, we should short-circuit this directly into false value.
680-
case i @ If(pred, left, right) if pred.dataType == NullType =>
681-
i.makeCopy(Array(Literal.create(null, BooleanType), left, right))
680+
case If(pred, left, right) if pred.dataType == NullType =>
681+
If(Literal.create(null, BooleanType), left, right)
682682
}
683683
}
684684

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ case class UnresolvedAttribute(nameParts: Seq[String])
6868
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
6969

7070
// Unresolved attributes are transient at compile time and don't get evaluated during execution.
71-
override def eval(input: catalyst.InternalRow = null): Any =
71+
override def eval(input: InternalRow = null): Any =
7272
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
7373

7474
override def toString: String = s"'$name"
@@ -86,7 +86,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
8686
override lazy val resolved = false
8787

8888
// Unresolved functions are transient at compile time and don't get evaluated during execution.
89-
override def eval(input: catalyst.InternalRow = null): Any =
89+
override def eval(input: InternalRow = null): Any =
9090
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
9191

9292
override def toString: String = s"'$name(${children.mkString(",")})"
@@ -108,7 +108,7 @@ trait Star extends NamedExpression with trees.LeafNode[Expression] {
108108
override lazy val resolved = false
109109

110110
// Star gets expanded at runtime so we never evaluate a Star.
111-
override def eval(input: catalyst.InternalRow = null): Any =
111+
override def eval(input: InternalRow = null): Any =
112112
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
113113

114114
def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression]
@@ -167,7 +167,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
167167

168168
override lazy val resolved = false
169169

170-
override def eval(input: catalyst.InternalRow = null): Any =
170+
override def eval(input: InternalRow = null): Any =
171171
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
172172

173173
override def toString: String = s"$child AS $names"
@@ -201,7 +201,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
201201
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
202202
override lazy val resolved = false
203203

204-
override def eval(input: catalyst.InternalRow = null): Any =
204+
override def eval(input: InternalRow = null): Any =
205205
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
206206

207207
override def toString: String = s"$child[$extraction]"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import java.math.{BigDecimal => JavaBigDecimal}
2021
import java.sql.{Date, Timestamp}
2122
import java.text.{DateFormat, SimpleDateFormat}
2223

@@ -320,7 +321,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
320321
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
321322
case StringType =>
322323
buildCast[UTF8String](_, s => try {
323-
changePrecision(Decimal(s.toString.toDouble), target)
324+
changePrecision(Decimal(new JavaBigDecimal(s.toString)), target)
324325
} catch {
325326
case _: NumberFormatException => null
326327
})
@@ -394,7 +395,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
394395
}
395396
// TODO: Could be faster?
396397
val newRow = new GenericMutableRow(from.fields.size)
397-
buildCast[catalyst.InternalRow](_, row => {
398+
buildCast[InternalRow](_, row => {
398399
var i = 0
399400
while (i < row.length) {
400401
val v = row(i)
@@ -426,7 +427,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
426427

427428
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
428429

429-
override def eval(input: catalyst.InternalRow): Any = {
430+
override def eval(input: InternalRow): Any = {
430431
val evaluated = child.eval(input)
431432
if (evaluated == null) null else cast(evaluated)
432433
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
105105
override def foldable: Boolean = child.foldable
106106
override def toString: String = s"$child.${field.name}"
107107

108-
override def eval(input: catalyst.InternalRow): Any = {
109-
val baseValue = child.eval(input).asInstanceOf[catalyst.InternalRow]
108+
override def eval(input: InternalRow): Any = {
109+
val baseValue = child.eval(input).asInstanceOf[InternalRow]
110110
if (baseValue == null) null else baseValue(ordinal)
111111
}
112112
}
@@ -125,8 +125,8 @@ case class GetArrayStructFields(
125125
override def foldable: Boolean = child.foldable
126126
override def toString: String = s"$child.${field.name}"
127127

128-
override def eval(input: catalyst.InternalRow): Any = {
129-
val baseValue = child.eval(input).asInstanceOf[Seq[catalyst.InternalRow]]
128+
override def eval(input: InternalRow): Any = {
129+
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
130130
if (baseValue == null) null else {
131131
baseValue.map { row =>
132132
if (row == null) null else row(ordinal)
@@ -146,7 +146,7 @@ abstract class ExtractValueWithOrdinal extends ExtractValue {
146146
override def toString: String = s"$child[$ordinal]"
147147
override def children: Seq[Expression] = child :: ordinal :: Nil
148148

149-
override def eval(input: catalyst.InternalRow): Any = {
149+
override def eval(input: InternalRow): Any = {
150150
val value = child.eval(input)
151151
if (value == null) {
152152
null

0 commit comments

Comments
 (0)