diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cb42e9e4560cf..5dbe7b2427ace 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -242,16 +242,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( def mapGroupsWithState[S: Encoder, U: Encoder]( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) - Dataset[U]( - sparkSession, - FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], - groupingAttributes, - dataAttributes, - OutputMode.Update, - isMapGroupsWithState = true, - GroupStateTimeout.NoTimeout, - child = logicalPlan)) + flatMapGroupsWithState(OutputMode.Update, GroupStateTimeout.NoTimeout)(flatMapFunc) } /** @@ -278,16 +269,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( timeoutConf: GroupStateTimeout)( func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] = { val flatMapFunc = (key: K, it: Iterator[V], s: GroupState[S]) => Iterator(func(key, it, s)) - Dataset[U]( - sparkSession, - FlatMapGroupsWithState[K, V, S, U]( - flatMapFunc.asInstanceOf[(Any, Iterator[Any], LogicalGroupState[Any]) => Iterator[Any]], - groupingAttributes, - dataAttributes, - OutputMode.Update, - isMapGroupsWithState = true, - timeoutConf, - child = logicalPlan)) + flatMapGroupsWithState(OutputMode.Update, timeoutConf)(flatMapFunc) } /**