1717
1818package org .apache .spark .sql .execution
1919
20- import scala .collection ._
21-
2220import org .apache .spark .annotation .DeveloperApi
2321
22+ import org .apache .spark .util .collection .{OpenHashSet , OpenHashMap }
23+
2424import org .apache .spark .sql .catalyst .errors ._
2525import org .apache .spark .sql .catalyst .expressions ._
2626import org .apache .spark .sql .catalyst .plans .physical ._
@@ -39,7 +39,7 @@ sealed case class AggregateFunctionBind(
3939sealed class InputBufferSeens (
4040 var input : Row , //
4141 var buffer : MutableRow ,
42- var seens : Array [mutable. HashSet [Any ]] = null ) {
42+ var seens : Array [OpenHashSet [Any ]] = null ) {
4343 def this () {
4444 this (new GenericMutableRow (0 ), null )
4545 }
@@ -54,7 +54,7 @@ sealed class InputBufferSeens(
5454 this
5555 }
5656
57- def withSeens (seens : Array [mutable. HashSet [Any ]]): InputBufferSeens = {
57+ def withSeens (seens : Array [OpenHashSet [Any ]]): InputBufferSeens = {
5858 this .seens = seens
5959 this
6060 }
@@ -250,20 +250,13 @@ case class AggregatePreShuffle(
250250
251251 createIterator(aggregates, Iterator (new InputBufferSeens ().withBuffer(buffer)))
252252 } else {
253- val results = new mutable. HashMap [Row , InputBufferSeens ]()
253+ val results = new OpenHashMap [Row , InputBufferSeens ]()
254254 while (iter.hasNext) {
255255 val currentRow = iter.next()
256256
257257 val keys = groupByProjection(currentRow)
258- results.get(keys) match {
259- case Some (inputbuffer) =>
260- var idx = 0
261- while (idx < aggregates.length) {
262- val ae = aggregates(idx)
263- ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
264- idx += 1
265- }
266- case None =>
258+ results(keys) match {
259+ case null =>
267260 val buffer = new GenericMutableRow (bufferSchema.length)
268261 var idx = 0
269262 while (idx < aggregates.length) {
@@ -278,11 +271,19 @@ case class AggregatePreShuffle(
278271 }
279272
280273 val copies = keys.copy()
281- results.put(copies, new InputBufferSeens (copies, buffer))
274+ results(copies) = new InputBufferSeens (copies, buffer)
275+ case inputbuffer =>
276+ var idx = 0
277+ while (idx < aggregates.length) {
278+ val ae = aggregates(idx)
279+ ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
280+ idx += 1
281+ }
282+
282283 }
283284 }
284285
285- createIterator(aggregates, results.valuesIterator )
286+ createIterator(aggregates, results.iterator.map(_._2) )
286287 }
287288 }
288289 }
@@ -328,32 +329,32 @@ case class AggregatePostShuffle(
328329
329330 createIterator(aggregates, Iterator (new InputBufferSeens ().withBuffer(buffer)))
330331 } else {
331- val results = new mutable. HashMap [Row , InputBufferSeens ]()
332+ val results = new OpenHashMap [Row , InputBufferSeens ]()
332333 while (iter.hasNext) {
333334 val currentRow = iter.next()
334335 val keys = groupByProjection(currentRow)
335- results.get(keys) match {
336- case Some (pair) =>
336+ results(keys) match {
337+ case null =>
338+ val buffer = new GenericMutableRow (bufferSchema.length)
337339 var idx = 0
338340 while (idx < aggregates.length) {
339341 val ae = aggregates(idx)
340- ae.merge(currentRow, pair.buffer)
342+ ae.reset(buffer)
343+ ae.merge(currentRow, buffer)
341344 idx += 1
342345 }
343- case None =>
344- val buffer = new GenericMutableRow (bufferSchema.length)
346+ results(keys.copy()) = new InputBufferSeens (currentRow.copy(), buffer)
347+ case pair =>
345348 var idx = 0
346349 while (idx < aggregates.length) {
347350 val ae = aggregates(idx)
348- ae.reset(buffer)
349- ae.merge(currentRow, buffer)
351+ ae.merge(currentRow, pair.buffer)
350352 idx += 1
351353 }
352- results.put(keys.copy(), new InputBufferSeens (currentRow.copy(), buffer))
353354 }
354355 }
355356
356- createIterator(aggregates, results.valuesIterator )
357+ createIterator(aggregates, results.iterator.map(_._2) )
357358 }
358359 }
359360 }
@@ -383,15 +384,15 @@ case class DistinctAggregate(
383384 if (groupingExpressions.isEmpty) {
384385 val buffer = new GenericMutableRow (bufferSchema.length)
385386 // TODO save the memory only for those DISTINCT aggregate expressions
386- val seens = new Array [mutable. HashSet [Any ]](aggregateFunctionBinds.length)
387+ val seens = new Array [OpenHashSet [Any ]](aggregateFunctionBinds.length)
387388
388389 var idx = 0
389390 while (idx < aggregateFunctionBinds.length) {
390391 val ae = aggregates(idx)
391392 ae.reset(buffer)
392393
393394 if (ae.distinct) {
394- seens(idx) = new mutable. HashSet [Any ]()
395+ seens(idx) = new OpenHashSet [Any ]()
395396 }
396397
397398 idx += 1
@@ -420,56 +421,57 @@ case class DistinctAggregate(
420421
421422 createIterator(aggregates, Iterator (ibs))
422423 } else {
423- val results = new mutable. HashMap [Row , InputBufferSeens ]()
424+ val results = new OpenHashMap [Row , InputBufferSeens ]()
424425
425426 while (iter.hasNext) {
426427 val currentRow = iter.next()
427428
428429 val keys = groupByProjection(currentRow)
429- results.get(keys) match {
430- case Some (inputBufferSeens) =>
430+ results(keys) match {
431+ case null =>
432+ val buffer = new GenericMutableRow (bufferSchema.length)
433+ // TODO save the memory only for those DISTINCT aggregate expressions
434+ val seens = new Array [OpenHashSet [Any ]](aggregateFunctionBinds.length)
435+
431436 var idx = 0
432437 while (idx < aggregateFunctionBinds.length) {
433438 val ae = aggregates(idx)
434439 val value = ae.eval(currentRow)
440+ ae.reset(buffer)
441+ ae.iterate(value, buffer)
435442
436443 if (ae.distinct) {
437- if (value != null && ! inputBufferSeens.seens(idx).contains(value)) {
438- ae.iterate (value, inputBufferSeens.buffer)
439- inputBufferSeens.seens(idx) .add(value)
444+ val seen = new OpenHashSet [ Any ]()
445+ if (value != null ) {
446+ seen .add(value)
440447 }
441- } else {
442- ae.iterate(value, inputBufferSeens.buffer)
448+ seens.update(idx, seen)
443449 }
450+
444451 idx += 1
445452 }
446- case None =>
447- val buffer = new GenericMutableRow (bufferSchema.length)
448- // TODO save the memory only for those DISTINCT aggregate expressions
449- val seens = new Array [mutable.HashSet [Any ]](aggregateFunctionBinds.length)
453+ results(keys.copy()) = new InputBufferSeens (currentRow.copy(), buffer, seens)
450454
455+ case inputBufferSeens =>
451456 var idx = 0
452457 while (idx < aggregateFunctionBinds.length) {
453458 val ae = aggregates(idx)
454459 val value = ae.eval(currentRow)
455- ae.reset(buffer)
456- ae.iterate(value, buffer)
457460
458461 if (ae.distinct) {
459- val seen = new mutable. HashSet [ Any ]()
460- if (value != null ) {
461- seen .add(value)
462+ if (value != null && ! inputBufferSeens.seens(idx).contains(value)) {
463+ ae.iterate (value, inputBufferSeens.buffer)
464+ inputBufferSeens.seens(idx) .add(value)
462465 }
463- seens.update(idx, seen)
466+ } else {
467+ ae.iterate(value, inputBufferSeens.buffer)
464468 }
465-
466469 idx += 1
467470 }
468- results.put(keys.copy(), new InputBufferSeens (currentRow.copy(), buffer, seens))
469471 }
470472 }
471473
472- createIterator(aggregates, results.valuesIterator )
474+ createIterator(aggregates, results.iterator.map(_._2) )
473475 }
474476 }
475477 }
0 commit comments