11import random
2+ from typing import TYPE_CHECKING , Dict , List , Optional , Set
23
34import matplotlib .lines as mlines
45import matplotlib .patches as patches
56import numpy as np
67
8+ from pandas ._typing import Label
9+
710from pandas .core .dtypes .missing import notna
811
912from pandas .io .formats .printing import pprint_thing
1013from pandas .plotting ._matplotlib .style import _get_standard_colors
1114from pandas .plotting ._matplotlib .tools import _set_ticks_props , _subplots
1215
16+ if TYPE_CHECKING :
17+ from matplotlib .axes import Axes
18+ from matplotlib .figure import Figure
19+
20+ from pandas import DataFrame , Series
21+
1322
1423def scatter_matrix (
15- frame ,
24+ frame : "DataFrame" ,
1625 alpha = 0.5 ,
1726 figsize = None ,
1827 ax = None ,
@@ -114,7 +123,14 @@ def _get_marker_compat(marker):
114123 return marker
115124
116125
117- def radviz (frame , class_column , ax = None , color = None , colormap = None , ** kwds ):
126+ def radviz (
127+ frame : "DataFrame" ,
128+ class_column ,
129+ ax : Optional ["Axes" ] = None ,
130+ color = None ,
131+ colormap = None ,
132+ ** kwds ,
133+ ) -> "Axes" :
118134 import matplotlib .pyplot as plt
119135
120136 def normalize (series ):
@@ -130,7 +146,7 @@ def normalize(series):
130146 if ax is None :
131147 ax = plt .gca (xlim = [- 1 , 1 ], ylim = [- 1 , 1 ])
132148
133- to_plot = {}
149+ to_plot : Dict [ Label , List [ List ]] = {}
134150 colors = _get_standard_colors (
135151 num_colors = len (classes ), colormap = colormap , color_type = "random" , color = color
136152 )
@@ -197,8 +213,14 @@ def normalize(series):
197213
198214
199215def andrews_curves (
200- frame , class_column , ax = None , samples = 200 , color = None , colormap = None , ** kwds
201- ):
216+ frame : "DataFrame" ,
217+ class_column ,
218+ ax : Optional ["Axes" ] = None ,
219+ samples : int = 200 ,
220+ color = None ,
221+ colormap = None ,
222+ ** kwds ,
223+ ) -> "Axes" :
202224 import matplotlib .pyplot as plt
203225
204226 def function (amplitudes ):
@@ -231,7 +253,7 @@ def f(t):
231253 classes = frame [class_column ].drop_duplicates ()
232254 df = frame .drop (class_column , axis = 1 )
233255 t = np .linspace (- np .pi , np .pi , samples )
234- used_legends = set ()
256+ used_legends : Set [ str ] = set ()
235257
236258 color_values = _get_standard_colors (
237259 num_colors = len (classes ), colormap = colormap , color_type = "random" , color = color
@@ -256,7 +278,13 @@ def f(t):
256278 return ax
257279
258280
259- def bootstrap_plot (series , fig = None , size = 50 , samples = 500 , ** kwds ):
281+ def bootstrap_plot (
282+ series : "Series" ,
283+ fig : Optional ["Figure" ] = None ,
284+ size : int = 50 ,
285+ samples : int = 500 ,
286+ ** kwds ,
287+ ) -> "Figure" :
260288
261289 import matplotlib .pyplot as plt
262290
@@ -306,19 +334,19 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
306334
307335
308336def parallel_coordinates (
309- frame ,
337+ frame : "DataFrame" ,
310338 class_column ,
311339 cols = None ,
312- ax = None ,
340+ ax : Optional [ "Axes" ] = None ,
313341 color = None ,
314342 use_columns = False ,
315343 xticks = None ,
316344 colormap = None ,
317- axvlines = True ,
345+ axvlines : bool = True ,
318346 axvlines_kwds = None ,
319- sort_labels = False ,
347+ sort_labels : bool = False ,
320348 ** kwds ,
321- ):
349+ ) -> "Axes" :
322350 import matplotlib .pyplot as plt
323351
324352 if axvlines_kwds is None :
@@ -333,7 +361,7 @@ def parallel_coordinates(
333361 else :
334362 df = frame [cols ]
335363
336- used_legends = set ()
364+ used_legends : Set [ str ] = set ()
337365
338366 ncols = len (df .columns )
339367
@@ -385,7 +413,9 @@ def parallel_coordinates(
385413 return ax
386414
387415
388- def lag_plot (series , lag = 1 , ax = None , ** kwds ):
416+ def lag_plot (
417+ series : "Series" , lag : int = 1 , ax : Optional ["Axes" ] = None , ** kwds
418+ ) -> "Axes" :
389419 # workaround because `c='b'` is hardcoded in matplotlib's scatter method
390420 import matplotlib .pyplot as plt
391421
@@ -402,7 +432,9 @@ def lag_plot(series, lag=1, ax=None, **kwds):
402432 return ax
403433
404434
405- def autocorrelation_plot (series , ax = None , ** kwds ):
435+ def autocorrelation_plot (
436+ series : "Series" , ax : Optional ["Axes" ] = None , ** kwds
437+ ) -> "Axes" :
406438 import matplotlib .pyplot as plt
407439
408440 n = len (series )
0 commit comments