1818package org .apache .spark .sql .execution
1919
2020import org .apache .spark .{SparkEnv , TaskContext }
21+ import org .apache .spark .executor .TaskMetrics
2122import org .apache .spark .rdd .RDD
2223import org .apache .spark .sql .catalyst .InternalRow
2324import org .apache .spark .sql .catalyst .expressions ._
@@ -115,6 +116,9 @@ case class Sort(
115116 sorterVariable = ctx.freshName(" sorter" )
116117 ctx.addMutableState(classOf [UnsafeExternalRowSorter ].getName, sorterVariable,
117118 s " $sorterVariable = $thisPlan.createSorter(); " )
119+ val metrics = ctx.freshName(" metrics" )
120+ ctx.addMutableState(classOf [TaskMetrics ].getName, metrics,
121+ s " $metrics = org.apache.spark.TaskContext.get().taskMetrics(); " )
118122 val sortedIterator = ctx.freshName(" sortedIter" )
119123 ctx.addMutableState(" scala.collection.Iterator<UnsafeRow>" , sortedIterator, " " )
120124
@@ -127,10 +131,20 @@ case class Sort(
127131 """ .stripMargin.trim)
128132
129133 val outputRow = ctx.freshName(" outputRow" )
134+ val dataSize = ctx.freshName(" dataSize" )
135+ ctx.addMutableState(classOf [Long ].getName, dataSize, " " )
136+ val spillSize = ctx.freshName(" spillSize" )
137+ ctx.addMutableState(classOf [Long ].getName, spillSize, " " )
138+ val spillSizeBefore = ctx.freshName(" spillSizeBefore" )
139+ ctx.addMutableState(classOf [Long ].getName, spillSizeBefore,
140+ s " $spillSizeBefore = $metrics.memoryBytesSpilled(); " )
130141 s """
131142 | if ( $needToSort) {
132143 | $addToSorter();
133144 | $sortedIterator = $sorterVariable.sort();
145+ | $dataSize += $sorterVariable.getPeakMemoryUsage();
146+ | $spillSize += $metrics.memoryBytesSpilled() - $spillSizeBefore;
147+ | $metrics.incPeakExecutionMemory( $sorterVariable.getPeakMemoryUsage());
134148 | $needToSort = false;
135149 | }
136150 |
0 commit comments