Skip to content

Commit a13ab67

Browse files
committed
improve
1 parent 62c75f2 commit a13ab67

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ object ScalaReflection extends ScalaReflection {
302302

303303
case t if t <:< localTypeOf[Seq[_]] =>
304304
val TypeRef(_, _, Seq(elementType)) = t
305-
val Schema(dataType, elementNullable) = schemaFor(elementType)
305+
val Schema(_, elementNullable) = schemaFor(elementType)
306306
val className = getClassNameFromType(elementType)
307307
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
308308

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,11 @@ case class UnresolvedMapObjects(
452452
function: Expression => Expression,
453453
child: Expression,
454454
customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable {
455-
override def dataType: DataType = throw new UnsupportedOperationException("not resolved")
455+
override lazy val resolved = false
456+
457+
override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse {
458+
throw new UnsupportedOperationException("not resolved")
459+
}
456460
}
457461

458462
/**
@@ -588,17 +592,24 @@ case class MapObjects private(
588592
// collection
589593
val collObjectName = s"${cls.getName}$$.MODULE$$"
590594
val getBuilderVar = s"$collObjectName.newBuilder()"
591-
592-
(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
593-
$builderValue.sizeHint($dataLength);""",
595+
(
596+
s"""
597+
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
598+
$builderValue.sizeHint($dataLength);
599+
""",
594600
genValue => s"$builderValue.$$plus$$eq($genValue);",
595-
s"(${cls.getName}) $builderValue.result();")
601+
s"(${cls.getName}) $builderValue.result();"
602+
)
596603
case None =>
597604
// array
598-
(s"""$convertedType[] $convertedArray = null;
599-
$convertedArray = $arrayConstructor;""",
605+
(
606+
s"""
607+
$convertedType[] $convertedArray = null;
608+
$convertedArray = $arrayConstructor;
609+
""",
600610
genValue => s"$convertedArray[$loopIndex] = $genValue;",
601-
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
611+
s"new ${classOf[GenericArrayData].getName}($convertedArray);"
612+
)
602613
}
603614

604615
val code = s"""

sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,19 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2727
*
2828
* This class currently assumes there is at least one input row.
2929
*/
30-
private[sql] class ReduceAggregator[T](func: (T, T) => T)(@transient implicit val enc: Encoder[T])
30+
private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
3131
extends Aggregator[T, (Boolean, T), T] {
3232

33+
@transient private val encoder = implicitly[Encoder[T]]
34+
3335
override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
3436

3537
override def bufferEncoder: Encoder[(Boolean, T)] =
3638
ExpressionEncoder.tuple(
3739
ExpressionEncoder[Boolean](),
38-
enc.asInstanceOf[ExpressionEncoder[T]])
40+
encoder.asInstanceOf[ExpressionEncoder[T]])
3941

40-
override def outputEncoder: Encoder[T] = enc
42+
override def outputEncoder: Encoder[T] = encoder
4143

4244
override def reduce(b: (Boolean, T), a: T): (Boolean, T) = {
4345
if (b._1) {

0 commit comments

Comments
 (0)