@@ -490,29 +490,34 @@ def test_subplots(self):
490490 df = DataFrame (np .random .rand (10 , 3 ),
491491 index = list (string .ascii_letters [:10 ]))
492492
493- axes = df .plot (subplots = True , sharex = True , legend = True )
493+ for kind in ['bar' , 'barh' , 'line' ]:
494+ axes = df .plot (kind = kind , subplots = True , sharex = True , legend = True )
494495
495- for ax in axes :
496- self .assertIsNotNone (ax .get_legend ())
497-
498- axes = df .plot (subplots = True , sharex = True )
499- for ax in axes [:- 2 ]:
500- [self .assert_ (not label .get_visible ())
501- for label in ax .get_xticklabels ()]
502- [self .assert_ (label .get_visible ())
503- for label in ax .get_yticklabels ()]
496+ for ax , column in zip (axes , df .columns ):
497+ self ._check_legend_labels (ax , [column ])
504498
505- [self .assert_ (label .get_visible ())
506- for label in axes [- 1 ].get_xticklabels ()]
507- [self .assert_ (label .get_visible ())
508- for label in axes [- 1 ].get_yticklabels ()]
499+ axes = df .plot (kind = kind , subplots = True , sharex = True )
500+ for ax in axes [:- 2 ]:
501+ [self .assert_ (not label .get_visible ())
502+ for label in ax .get_xticklabels ()]
503+ [self .assert_ (label .get_visible ())
504+ for label in ax .get_yticklabels ()]
509505
510- axes = df .plot (subplots = True , sharex = False )
511- for ax in axes :
512506 [self .assert_ (label .get_visible ())
513- for label in ax .get_xticklabels ()]
507+ for label in axes [ - 1 ] .get_xticklabels ()]
514508 [self .assert_ (label .get_visible ())
515- for label in ax .get_yticklabels ()]
509+ for label in axes [- 1 ].get_yticklabels ()]
510+
511+ axes = df .plot (kind = kind , subplots = True , sharex = False )
512+ for ax in axes :
513+ [self .assert_ (label .get_visible ())
514+ for label in ax .get_xticklabels ()]
515+ [self .assert_ (label .get_visible ())
516+ for label in ax .get_yticklabels ()]
517+
518+ axes = df .plot (kind = kind , subplots = True , legend = False )
519+ for ax in axes :
520+ self .assertTrue (ax .get_legend () is None )
516521
517522 @slow
518523 def test_bar_colors (self ):
@@ -873,7 +878,7 @@ def test_kde(self):
873878 _check_plot_works (df .plot , kind = 'kde' )
874879 _check_plot_works (df .plot , kind = 'kde' , subplots = True )
875880 ax = df .plot (kind = 'kde' )
876- self .assertIsNotNone (ax . get_legend () )
881+ self ._check_legend_labels (ax , df . columns )
877882 axes = df .plot (kind = 'kde' , logy = True , subplots = True )
878883 for ax in axes :
879884 self .assertEqual (ax .get_yscale (), 'log' )
@@ -1046,6 +1051,64 @@ def test_plot_int_columns(self):
10461051 df = DataFrame (randn (100 , 4 )).cumsum ()
10471052 _check_plot_works (df .plot , legend = True )
10481053
1054+ def _check_legend_labels (self , ax , labels ):
1055+ import pandas .core .common as com
1056+ labels = [com .pprint_thing (l ) for l in labels ]
1057+ self .assertTrue (ax .get_legend () is not None )
1058+ legend_labels = [t .get_text () for t in ax .get_legend ().get_texts ()]
1059+ self .assertEqual (labels , legend_labels )
1060+
1061+ @slow
1062+ def test_df_legend_labels (self ):
1063+ kinds = 'line' , 'bar' , 'barh' , 'kde' , 'density'
1064+ df = DataFrame (randn (3 , 3 ), columns = ['a' , 'b' , 'c' ])
1065+ df2 = DataFrame (randn (3 , 3 ), columns = ['d' , 'e' , 'f' ])
1066+ df3 = DataFrame (randn (3 , 3 ), columns = ['g' , 'h' , 'i' ])
1067+ df4 = DataFrame (randn (3 , 3 ), columns = ['j' , 'k' , 'l' ])
1068+
1069+ for kind in kinds :
1070+ ax = df .plot (kind = kind , legend = True )
1071+ self ._check_legend_labels (ax , df .columns )
1072+
1073+ ax = df2 .plot (kind = kind , legend = False , ax = ax )
1074+ self ._check_legend_labels (ax , df .columns )
1075+
1076+ ax = df3 .plot (kind = kind , legend = True , ax = ax )
1077+ self ._check_legend_labels (ax , df .columns + df3 .columns )
1078+
1079+ ax = df4 .plot (kind = kind , legend = 'reverse' , ax = ax )
1080+ expected = list (df .columns + df3 .columns ) + list (reversed (df4 .columns ))
1081+ self ._check_legend_labels (ax , expected )
1082+
1083+ # Secondary Y
1084+ ax = df .plot (legend = True , secondary_y = 'b' )
1085+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1086+ ax = df2 .plot (legend = False , ax = ax )
1087+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1088+ ax = df3 .plot (kind = 'bar' , legend = True , secondary_y = 'h' , ax = ax )
1089+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' , 'g' , 'h (right)' , 'i' ])
1090+
1091+ # Time Series
1092+ ind = date_range ('1/1/2014' , periods = 3 )
1093+ df = DataFrame (randn (3 , 3 ), columns = ['a' , 'b' , 'c' ], index = ind )
1094+ df2 = DataFrame (randn (3 , 3 ), columns = ['d' , 'e' , 'f' ], index = ind )
1095+ df3 = DataFrame (randn (3 , 3 ), columns = ['g' , 'h' , 'i' ], index = ind )
1096+ ax = df .plot (legend = True , secondary_y = 'b' )
1097+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1098+ ax = df2 .plot (legend = False , ax = ax )
1099+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' ])
1100+ ax = df3 .plot (legend = True , ax = ax )
1101+ self ._check_legend_labels (ax , ['a' , 'b (right)' , 'c' , 'g' , 'h' , 'i' ])
1102+
1103+ # scatter
1104+ ax = df .plot (kind = 'scatter' , x = 'a' , y = 'b' , label = 'data1' )
1105+ self ._check_legend_labels (ax , ['data1' ])
1106+ ax = df2 .plot (kind = 'scatter' , x = 'd' , y = 'e' , legend = False ,
1107+ label = 'data2' , ax = ax )
1108+ self ._check_legend_labels (ax , ['data1' ])
1109+ ax = df3 .plot (kind = 'scatter' , x = 'g' , y = 'h' , label = 'data3' , ax = ax )
1110+ self ._check_legend_labels (ax , ['data1' , 'data3' ])
1111+
10491112 def test_legend_name (self ):
10501113 multi = DataFrame (randn (4 , 4 ),
10511114 columns = [np .array (['a' , 'a' , 'b' , 'b' ]),
@@ -1056,6 +1119,20 @@ def test_legend_name(self):
10561119 leg_title = ax .legend_ .get_title ()
10571120 self .assertEqual (leg_title .get_text (), 'group,individual' )
10581121
1122+ df = DataFrame (randn (5 , 5 ))
1123+ ax = df .plot (legend = True , ax = ax )
1124+ leg_title = ax .legend_ .get_title ()
1125+ self .assertEqual (leg_title .get_text (), 'group,individual' )
1126+
1127+ df .columns .name = 'new'
1128+ ax = df .plot (legend = False , ax = ax )
1129+ leg_title = ax .legend_ .get_title ()
1130+ self .assertEqual (leg_title .get_text (), 'group,individual' )
1131+
1132+ ax = df .plot (legend = True , ax = ax )
1133+ leg_title = ax .legend_ .get_title ()
1134+ self .assertEqual (leg_title .get_text (), 'new' )
1135+
10591136 def _check_plot_fails (self , f , * args , ** kwargs ):
10601137 with tm .assertRaises (Exception ):
10611138 f (* args , ** kwargs )
0 commit comments