From dd85fa0fd12d1337cfb62b89c563864b338875d7 Mon Sep 17 00:00:00 2001 From: jreback Date: Tue, 10 Jun 2014 12:53:05 -0400 Subject: [PATCH] PERF: Series.transform speedups (GH6496) --- doc/source/v0.14.1.txt | 2 +- pandas/core/groupby.py | 34 +++++++++++++++++++--------------- pandas/tests/test_groupby.py | 2 ++ vb_suite/groupby.py | 18 ++++++++++++++++++ 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/doc/source/v0.14.1.txt b/doc/source/v0.14.1.txt index 2b76da1434ba3..04231f08787f4 100644 --- a/doc/source/v0.14.1.txt +++ b/doc/source/v0.14.1.txt @@ -137,7 +137,7 @@ Performance - Improvements in dtype inference for numeric operations involving yielding performance gains for dtypes: ``int64``, ``timedelta64``, ``datetime64`` (:issue:`7223`) - +- Improvements in ``Series.transform`` for significant performance gains (:issue:`6496`) diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index e6af3c20bea00..c50df6f9bb08f 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -14,7 +14,7 @@ from pandas.core.categorical import Categorical from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame -from pandas.core.index import Index, MultiIndex, _ensure_index +from pandas.core.index import Index, MultiIndex, _ensure_index, _union_indexes from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.panel import Panel @@ -425,7 +425,7 @@ def convert(key, s): return Timestamp(key).asm8 return key - sample = list(self.indices)[0] + sample = next(iter(self.indices)) if isinstance(sample, tuple): if not isinstance(name, tuple): raise ValueError("must supply a tuple to get_group with multiple grouping keys") @@ -2193,33 +2193,37 @@ def transform(self, func, *args, **kwargs): ------- transformed : Series """ - result = self._selected_obj.copy() - if hasattr(result, 'values'): - result = result.values - dtype = result.dtype + dtype = self._selected_obj.dtype if 
isinstance(func, compat.string_types): wrapper = lambda x: getattr(x, func)(*args, **kwargs) else: wrapper = lambda x: func(x, *args, **kwargs) - for name, group in self: + result = self._selected_obj.values.copy() + for i, (name, group) in enumerate(self): object.__setattr__(group, 'name', name) res = wrapper(group) + if hasattr(res, 'values'): res = res.values - # need to do a safe put here, as the dtype may be different - # this needs to be an ndarray - result = Series(result) - result.iloc[self._get_index(name)] = res - result = result.values + # may need to astype + try: + common_type = np.common_type(np.array(res), result) + if common_type != result.dtype: + result = result.astype(common_type) + except: + pass + + indexer = self._get_index(name) + result[indexer] = res - # downcast if we can (and need) result = _possibly_downcast_to_dtype(result, dtype) - return self._selected_obj.__class__(result, index=self._selected_obj.index, - name=self._selected_obj.name) + return self._selected_obj.__class__(result, + index=self._selected_obj.index, + name=self._selected_obj.name) def filter(self, func, dropna=True, *args, **kwargs): """ diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index 1da51ce824120..14380c83de79e 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -126,8 +126,10 @@ def checkit(dtype): assert_series_equal(agged, grouped.mean()) assert_series_equal(grouped.agg(np.sum), grouped.sum()) + expected = grouped.apply(lambda x: x * x.sum()) transformed = grouped.transform(lambda x: x * x.sum()) self.assertEqual(transformed[7], 12) + assert_series_equal(transformed, expected) value_grouped = data.groupby(data) assert_series_equal(value_grouped.aggregate(np.mean), agged) diff --git a/vb_suite/groupby.py b/vb_suite/groupby.py index 6f2132ff9b154..f61c60d939907 100644 --- a/vb_suite/groupby.py +++ b/vb_suite/groupby.py @@ -376,3 +376,21 @@ def f(g): """ groupby_transform = 
Benchmark("data.groupby(level='security_id').transform(f_fillna)", setup) + +setup = common_setup + """ +np.random.seed(0) + +N = 120000 +N_TRANSITIONS = 1400 + +# generate groups +transition_points = np.random.permutation(np.arange(N))[:N_TRANSITIONS] +transition_points.sort() +transitions = np.zeros((N,), dtype=np.bool) +transitions[transition_points] = True +g = transitions.cumsum() + +df = DataFrame({ 'signal' : np.random.rand(N)}) +""" + +groupby_transform2 = Benchmark("df['signal'].groupby(g).transform(np.mean)", setup)