Skip to content

Commit ffe5cfe

Browse files
committed
Cleanups
1 parent 18078c1 commit ffe5cfe

File tree

2 files changed

+21
-31
lines changed

2 files changed

+21
-31
lines changed

mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,15 +176,14 @@ object SummaryBuilderImpl extends Logging {
176176
* metrics that need to be computed internally to get the final result.
177177
*/
178178
private val allMetrics: Seq[(String, Metrics, DataType, Seq[ComputeMetrics])] = Seq(
179-
("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum, ComputeTotalWeightSum)),
180-
("variance", Variance, arrayDType, Seq(ComputeTotalWeightSum, ComputeWeightSum,
181-
ComputeWeightSquareSum, ComputeMean, ComputeM2n)),
182-
("count", Count, LongType, Seq(ComputeCount)),
179+
("mean", Mean, arrayDType, Seq(ComputeMean, ComputeWeightSum)),
180+
("variance", Variance, arrayDType, Seq(ComputeWeightSum, ComputeMean, ComputeM2n)),
181+
("count", Count, LongType, Seq()),
183182
("numNonZeros", NumNonZeros, arrayLType, Seq(ComputeNNZ)),
184183
("max", Max, arrayDType, Seq(ComputeMax)),
185184
("min", Min, arrayDType, Seq(ComputeMin)),
186-
("normL2", NormL2, arrayDType, Seq(ComputeTotalWeightSum, ComputeM2)),
187-
("normL1", NormL1, arrayDType, Seq(ComputeTotalWeightSum, ComputeL1))
185+
("normL2", NormL2, arrayDType, Seq(ComputeM2)),
186+
("normL1", NormL1, arrayDType, Seq(ComputeL1))
188187
)
189188

190189
/**
@@ -210,9 +209,6 @@ object SummaryBuilderImpl extends Logging {
210209
case object ComputeM2n extends ComputeMetrics
211210
case object ComputeM2 extends ComputeMetrics
212211
case object ComputeL1 extends ComputeMetrics
213-
case object ComputeCount extends ComputeMetrics // Always computed -> TODO: remove
214-
case object ComputeTotalWeightSum extends ComputeMetrics // Always computed -> TODO: remove
215-
case object ComputeWeightSquareSum extends ComputeMetrics
216212
case object ComputeWeightSum extends ComputeMetrics
217213
case object ComputeNNZ extends ComputeMetrics
218214
case object ComputeMax extends ComputeMetrics
@@ -275,8 +271,8 @@ object SummaryBuilderImpl extends Logging {
275271
* (testing only). Makes a buffer with all the metrics enabled.
276272
*/
277273
def allMetrics(): Buffer = {
278-
fromMetrics(Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1, ComputeCount,
279-
ComputeTotalWeightSum, ComputeWeightSquareSum, ComputeWeightSum, ComputeNNZ, ComputeMax,
274+
fromMetrics(Seq(ComputeMean, ComputeM2n, ComputeM2, ComputeL1,
275+
ComputeWeightSum, ComputeNNZ, ComputeMax,
280276
ComputeMin))
281277
}
282278

@@ -382,12 +378,9 @@ object SummaryBuilderImpl extends Logging {
382378
/**
383379
* Reads a buffer from a serialized form, using the row object as an assistant.
384380
*/
385-
def read(bytes: Array[Byte]): Buffer = {
386-
// TODO move this row outside to the aggregate
387-
assert(numFields == 12, numFields)
388-
val row3 = new UnsafeRow(numFields)
389-
row3.pointTo(bytes.clone(), bytes.length)
390-
val row = row3.getStruct(0, numFields)
381+
def read(bytes: Array[Byte], backingRow: UnsafeRow): Buffer = {
382+
backingRow.pointTo(bytes.clone(), bytes.length)
383+
val row = backingRow.getStruct(0, numFields)
391384
new Buffer(
392385
n = row.getInt(0),
393386
mean = nullableArrayD(row, 1),
@@ -405,7 +398,7 @@ object SummaryBuilderImpl extends Logging {
405398
}
406399

407400

408-
def write(buffer: Buffer): Array[Byte] = {
401+
def write(buffer: Buffer, project: UnsafeProjection): Array[Byte] = {
409402
val ir = InternalRow.apply(
410403
buffer.n,
411404
gadD(buffer.mean),
@@ -420,9 +413,7 @@ object SummaryBuilderImpl extends Logging {
420413
gadD(buffer.max),
421414
gadD(buffer.min)
422415
)
423-
// TODO: the projection should be passed as an argument.
424-
val projection = UnsafeProjection.create(bufferSchema)
425-
projection.apply(ir).getBytes
416+
project.apply(ir).getBytes
426417
}
427418

428419
def mean(buffer: Buffer): Array[Double] = {
@@ -685,7 +676,10 @@ object SummaryBuilderImpl extends Logging {
685676
inputAggBufferOffset: Int)
686677
extends TypedImperativeAggregate[Buffer] {
687678

688-
// private lazy val row = new UnsafeRow(Buffer.numFields)
679+
// These objects are not thread-safe, so allocate them in the aggregator.
680+
private[this] lazy val row = new UnsafeRow(Buffer.numFields)
681+
private[this] lazy val projection = UnsafeProjection.create(Buffer.bufferSchema)
682+
689683

690684
override def eval(buff: Buffer): InternalRow = {
691685
val metrics = requested.map({
@@ -727,15 +721,11 @@ object SummaryBuilderImpl extends Logging {
727721
override def createAggregationBuffer(): Buffer = startBuffer.copy()
728722

729723
override def serialize(buff: Buffer): Array[Byte] = {
730-
val x = Buffer.write(buff)
731-
val b2 = deserialize(x)
732-
x
724+
Buffer.write(buff, projection)
733725
}
734726

735727
override def deserialize(bytes: Array[Byte]): Buffer = {
736-
// Buffer.read(bytes, row)
737-
assert(Buffer.numFields == 12, Buffer.numFields)
738-
Buffer.read(bytes)
728+
Buffer.read(bytes, row)
739729
}
740730

741731
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): MetricsAggregate = {

mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
5454

5555
private def testExample(name: String, input: Seq[Any], exp: ExpectedMetrics): Unit = {
5656
def inputVec: Seq[Vector] = input.map {
57-
case x: Array[Double] => Vectors.dense(x)
58-
case x: Seq[Double] => Vectors.dense(x.toArray)
57+
case x: Array[Double @unchecked] => Vectors.dense(x)
58+
case x: Seq[Double @unchecked] => Vectors.dense(x.toArray)
5959
case x: Vector => x
6060
case x => throw new Exception(x.toString)
6161
}
@@ -170,7 +170,7 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
170170

171171
// Compares structured content.
172172
private def compareStructures(x1: Any, x2: Any, name: String): Unit = (x1, x2) match {
173-
case (y1: Seq[Double], v1: OldVector) =>
173+
case (y1: Seq[Double @unchecked], v1: OldVector) =>
174174
compareStructures(y1, v1.toArray.toSeq, name)
175175
case (d1: Double, d2: Double) =>
176176
assert2(Vectors.dense(d1) ~== Vectors.dense(d2) absTol 1e-4, name)

0 commit comments

Comments
 (0)