@@ -2791,6 +2791,59 @@ def test_insert_error_msmgs(self):
27912791 with assertRaisesRegexp(TypeError, msg):
27922792 df['gr'] = df.groupby(['b', 'c']).count()
27932793
2794+ def test_frame_subclassing_and_slicing(self):
2795+ # Subclass frame and ensure it returns the right class on slicing it
2796+ # In reference to PR 9632
2797+
2798+ class CustomSeries(Series):
2799+ @property
2800+ def _constructor(self):
2801+ return CustomSeries
2802+
2803+ def custom_series_function(self):
2804+ return 'OK'
2805+
2806+ class CustomDataFrame(DataFrame):
2807+ "Subclasses pandas DF, fills DF with simulation results, adds some custom plotting functions."
2808+
2809+ def __init__(self, *args, **kw):
2810+ super(CustomDataFrame, self).__init__(*args, **kw)
2811+
2812+ @property
2813+ def _constructor(self):
2814+ return CustomDataFrame
2815+
2816+ _constructor_sliced = CustomSeries
2817+
2818+ def custom_frame_function(self):
2819+ return 'OK'
2820+
2821+ data = {'col1': range(10),
2822+ 'col2': range(10)}
2823+ cdf = CustomDataFrame(data)
2824+
2825+ # Did we get back our own DF class?
2826+ self.assertTrue(isinstance(cdf, CustomDataFrame))
2827+
2828+ # Do we get back our own Series class after selecting a column?
2829+ cdf_series = cdf.col1
2830+ self.assertTrue(isinstance(cdf_series, CustomSeries))
2831+ self.assertEqual(cdf_series.custom_series_function(), 'OK')
2832+
2833+ # Do we get back our own DF class after slicing row-wise?
2834+ cdf_rows = cdf[1:5]
2835+ self.assertTrue(isinstance(cdf_rows, CustomDataFrame))
2836+ self.assertEqual(cdf_rows.custom_frame_function(), 'OK')
2837+
2838+ # Make sure sliced part of multi-index frame is custom class
2839+ mcol = pd.MultiIndex.from_tuples([('A', 'A'), ('A', 'B')])
2840+ cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol)
2841+ self.assertTrue(isinstance(cdf_multi['A'], CustomDataFrame))
2842+
2843+ mcol = pd.MultiIndex.from_tuples([('A', ''), ('B', '')])
2844+ cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol)
2845+ self.assertTrue(isinstance(cdf_multi2['A'], CustomSeries))
2846+
27942847 def test_constructor_subclass_dict(self):
27952848 # Test for passing dict subclass to constructor
27962849 data = {'col1': tm.TestSubDict((x, 10.0 * x) for x in range(10)),
0 commit comments