2020import  java .io .IOException ;
2121import  java .util .function .Supplier ;
2222
23- import  scala .collection .AbstractIterator ;
2423import  scala .collection .Iterator ;
2524import  scala .math .Ordering ;
2625
@@ -168,39 +167,40 @@ public void cleanupResources() {
168167    sorter .cleanupResources ();
169168  }
170169
171-   public  Iterator <UnsafeRow > sort () throws  IOException  {
170+   public  Iterator <InternalRow > sort () throws  IOException  {
172171    try  {
173172      final  UnsafeSorterIterator  sortedIterator  = sorter .getSortedIterator ();
174173      if  (!sortedIterator .hasNext ()) {
175174        // Since we won't ever call next() on an empty iterator, we need to clean up resources 
176175        // here in order to prevent memory leaks. 
177176        cleanupResources ();
178177      }
179-       return  new  AbstractIterator < UnsafeRow > () {
178+       return  new  RowIterator () {
180179
181180        private  final  int  numFields  = schema .length ();
182181        private  UnsafeRow  row  = new  UnsafeRow (numFields );
183182
184183        @ Override 
185-         public  boolean  hasNext () {
186-           return  !isReleased  && sortedIterator .hasNext ();
187-         }
188- 
189-         @ Override 
190-         public  UnsafeRow  next () {
184+         public  boolean  advanceNext () {
191185          try  {
192-             sortedIterator .loadNext ();
193-             row .pointTo (
194-               sortedIterator .getBaseObject (),
195-               sortedIterator .getBaseOffset (),
196-               sortedIterator .getRecordLength ());
197-             if  (!hasNext ()) {
198-               UnsafeRow  copy  = row .copy (); // so that we don't have dangling pointers to freed page 
199-               row  = null ; // so that we don't keep references to the base object 
200-               cleanupResources ();
201-               return  copy ;
186+             if  (!isReleased  && sortedIterator .hasNext ()) {
187+               sortedIterator .loadNext ();
188+               row .pointTo (
189+                   sortedIterator .getBaseObject (),
190+                   sortedIterator .getBaseOffset (),
191+                   sortedIterator .getRecordLength ());
192+               // Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug 
193+               // when returning the last row from an iterator. For example, in 
194+               // [[GroupedIterator]], we still use the last row after traversing the iterator 
195+               // in `fetchNextGroupIterator` 
196+               if  (!sortedIterator .hasNext ()) {
197+                 row  = row .copy (); // so that we don't have dangling pointers to freed page 
198+                 cleanupResources ();
199+               }
200+               return  true ;
202201            } else  {
203-               return  row ;
202+               row  = null ; // so that we don't keep references to the base object 
203+               return  false ;
204204            }
205205          } catch  (IOException  e ) {
206206            cleanupResources ();
@@ -210,14 +210,18 @@ public UnsafeRow next() {
210210          }
211211          throw  new  RuntimeException ("Exception should have been re-thrown in next()" );
212212        }
213-       };
213+ 
214+         @ Override 
215+         public  UnsafeRow  getRow () { return  row ; }
216+ 
217+       }.toScala ();
214218    } catch  (IOException  e ) {
215219      cleanupResources ();
216220      throw  e ;
217221    }
218222  }
219223
220-   public  Iterator <UnsafeRow > sort (Iterator <UnsafeRow > inputIterator ) throws  IOException  {
224+   public  Iterator <InternalRow > sort (Iterator <UnsafeRow > inputIterator ) throws  IOException  {
221225    while  (inputIterator .hasNext ()) {
222226      insertRow (inputIterator .next ());
223227    }
0 commit comments