@@ -691,14 +691,25 @@ def test_agg_relabel_multiindex_duplicates():
691691 tm .assert_frame_equal (result , expected )
692692
693693
694- def test_multiindex_custom_func ():
694+ @pytest .mark .parametrize (
695+ "func, expected_values" ,
696+ [
697+ (lambda s : s .mean (), [[3 , 2 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
698+ (np .mean , [[3.0 , 2.0 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
699+ (np .nanmean , [[3.0 , 2.0 ], [5.5 , 8.0 ], [1.5 , 3.0 ], [6.0 , 5.5 ]]),
700+ ],
701+ )
702+ def test_multiindex_custom_func (func , expected_values ):
695703 # GH 31777
696- df = pd .DataFrame (
697- np .random .rand (10 , 4 ), columns = pd .MultiIndex .from_product ([[1 , 2 ], [3 , 4 ]])
704+ data = [[1 , 4 , 2 , 8 ], [5 , 7 , 1 , 4 ], [2 , 8 , 1 , 4 ], [2 , 8 , 5 , 7 ]]
705+ df = pd .DataFrame (data , columns = pd .MultiIndex .from_product ([[1 , 2 ], [3 , 4 ]]))
706+ grp = df .groupby (np .r_ [np .zeros (2 ), np .ones (2 )])
707+ result = grp .agg (func )
708+ expected_keys = [(1 , 3 ), (1 , 4 ), (2 , 3 ), (2 , 4 )]
709+ expected = pd .DataFrame (
710+ {key : value for key , value in zip (expected_keys , expected_values )},
711+ index = Index ([0.0 , 1.0 ], dtype = float ),
698712 )
699- grp = df .groupby (np .r_ [np .ones (5 ), np .zeros (5 )])
700- result = grp .agg (lambda s : s .mean ())
701- expected = grp .agg ("mean" )
702713 tm .assert_frame_equal (result , expected )
703714
704715
0 commit comments