@@ -176,15 +176,14 @@ object SummaryBuilderImpl extends Logging {
176176 * metrics that need to de computed internally to get the final result.
177177 */
178178 private val allMetrics : Seq [(String , Metrics , DataType , Seq [ComputeMetrics ])] = Seq (
179- (" mean" , Mean , arrayDType, Seq (ComputeMean , ComputeWeightSum , ComputeTotalWeightSum )),
180- (" variance" , Variance , arrayDType, Seq (ComputeTotalWeightSum , ComputeWeightSum ,
181- ComputeWeightSquareSum , ComputeMean , ComputeM2n )),
182- (" count" , Count , LongType , Seq (ComputeCount )),
179+ (" mean" , Mean , arrayDType, Seq (ComputeMean , ComputeWeightSum )),
180+ (" variance" , Variance , arrayDType, Seq (ComputeWeightSum , ComputeMean , ComputeM2n )),
181+ (" count" , Count , LongType , Seq ()),
183182 (" numNonZeros" , NumNonZeros , arrayLType, Seq (ComputeNNZ )),
184183 (" max" , Max , arrayDType, Seq (ComputeMax )),
185184 (" min" , Min , arrayDType, Seq (ComputeMin )),
186- (" normL2" , NormL2 , arrayDType, Seq (ComputeTotalWeightSum , ComputeM2 )),
187- (" normL1" , NormL1 , arrayDType, Seq (ComputeTotalWeightSum , ComputeL1 ))
185+ (" normL2" , NormL2 , arrayDType, Seq (ComputeM2 )),
186+ (" normL1" , NormL1 , arrayDType, Seq (ComputeL1 ))
188187 )
189188
190189 /**
@@ -210,9 +209,6 @@ object SummaryBuilderImpl extends Logging {
210209 case object ComputeM2n extends ComputeMetrics
211210 case object ComputeM2 extends ComputeMetrics
212211 case object ComputeL1 extends ComputeMetrics
213- case object ComputeCount extends ComputeMetrics // Always computed -> TODO: remove
214- case object ComputeTotalWeightSum extends ComputeMetrics // Always computed -> TODO: remove
215- case object ComputeWeightSquareSum extends ComputeMetrics
216212 case object ComputeWeightSum extends ComputeMetrics
217213 case object ComputeNNZ extends ComputeMetrics
218214 case object ComputeMax extends ComputeMetrics
@@ -275,8 +271,8 @@ object SummaryBuilderImpl extends Logging {
275271 * (testing only). Makes a buffer with all the metrics enabled.
276272 */
277273 def allMetrics (): Buffer = {
278- fromMetrics(Seq (ComputeMean , ComputeM2n , ComputeM2 , ComputeL1 , ComputeCount ,
279- ComputeTotalWeightSum , ComputeWeightSquareSum , ComputeWeightSum , ComputeNNZ , ComputeMax ,
274+ fromMetrics(Seq (ComputeMean , ComputeM2n , ComputeM2 , ComputeL1 ,
275+ ComputeWeightSum , ComputeNNZ , ComputeMax ,
280276 ComputeMin ))
281277 }
282278
@@ -382,12 +378,9 @@ object SummaryBuilderImpl extends Logging {
382378 /**
383379 * Reads a buffer from a serialized form, using the row object as an assistant.
384380 */
385- def read (bytes : Array [Byte ]): Buffer = {
386- // TODO move this row outside to the aggregate
387- assert(numFields == 12 , numFields)
388- val row3 = new UnsafeRow (numFields)
389- row3.pointTo(bytes.clone(), bytes.length)
390- val row = row3.getStruct(0 , numFields)
381+ def read (bytes : Array [Byte ], backingRow : UnsafeRow ): Buffer = {
382+ backingRow.pointTo(bytes.clone(), bytes.length)
383+ val row = backingRow.getStruct(0 , numFields)
391384 new Buffer (
392385 n = row.getInt(0 ),
393386 mean = nullableArrayD(row, 1 ),
@@ -405,7 +398,7 @@ object SummaryBuilderImpl extends Logging {
405398 }
406399
407400
408- def write (buffer : Buffer ): Array [Byte ] = {
401+ def write (buffer : Buffer , project : UnsafeProjection ): Array [Byte ] = {
409402 val ir = InternalRow .apply(
410403 buffer.n,
411404 gadD(buffer.mean),
@@ -420,9 +413,7 @@ object SummaryBuilderImpl extends Logging {
420413 gadD(buffer.max),
421414 gadD(buffer.min)
422415 )
423- // TODO: the projection should be passed as an argument.
424- val projection = UnsafeProjection .create(bufferSchema)
425- projection.apply(ir).getBytes
416+ project.apply(ir).getBytes
426417 }
427418
428419 def mean (buffer : Buffer ): Array [Double ] = {
@@ -685,7 +676,10 @@ object SummaryBuilderImpl extends Logging {
685676 inputAggBufferOffset : Int )
686677 extends TypedImperativeAggregate [Buffer ] {
687678
688- // private lazy val row = new UnsafeRow(Buffer.numFields)
679+ // These objects are not thread-safe, allocate them in the aggregator.
680+ private [this ] lazy val row = new UnsafeRow (Buffer .numFields)
681+ private [this ] lazy val projection = UnsafeProjection .create(Buffer .bufferSchema)
682+
689683
690684 override def eval (buff : Buffer ): InternalRow = {
691685 val metrics = requested.map({
@@ -727,15 +721,11 @@ object SummaryBuilderImpl extends Logging {
727721 override def createAggregationBuffer (): Buffer = startBuffer.copy()
728722
729723 override def serialize (buff : Buffer ): Array [Byte ] = {
730- val x = Buffer .write(buff)
731- val b2 = deserialize(x)
732- x
724+ Buffer .write(buff, projection)
733725 }
734726
735727 override def deserialize (bytes : Array [Byte ]): Buffer = {
736- // Buffer.read(bytes, row)
737- assert(Buffer .numFields == 12 , Buffer .numFields)
738- Buffer .read(bytes)
728+ Buffer .read(bytes, row)
739729 }
740730
741731 override def withNewMutableAggBufferOffset (newMutableAggBufferOffset : Int ): MetricsAggregate = {
0 commit comments