11""" define extension dtypes """
22import re
3- from typing import Any , Dict , Optional , Tuple , Type
3+ import typing
4+ from typing import Any , Dict , List , Optional , Tuple , Type
45import warnings
56
67import numpy as np
1516from .base import ExtensionDtype , _DtypeOpsMixin
1617from .inference import is_list_like
1718
18- str_type = str
1919
20-
21- def register_extension_dtype (cls ):
20+ def register_extension_dtype (cls : 'ExtensionDtype' ) -> 'ExtensionDtype' :
2221 """
2322 Register an ExtensionType with pandas as class decorator.
2423
@@ -60,20 +59,20 @@ class Registry:
6059 These are tried in order.
6160 """
6261 def __init__ (self ):
63- self .dtypes = []
62+ self .dtypes = [] # type: List[ExtensionDtype]
6463
65- def register (self , dtype ) :
64+ def register (self , dtype : 'ExtensionDtype' ) -> None :
6665 """
6766 Parameters
6867 ----------
6968 dtype : ExtensionDtype
7069 """
71- if not issubclass (dtype , ( PandasExtensionDtype , ExtensionDtype ) ):
70+ if not issubclass (dtype , ExtensionDtype ):
7271 raise ValueError ("can only register pandas extension dtypes" )
7372
7473 self .dtypes .append (dtype )
7574
76- def find (self , dtype ) :
75+ def find (self , dtype : 'ExtensionDtype' ) -> Optional [ ExtensionDtype ] :
7776 """
7877 Parameters
7978 ----------
@@ -117,25 +116,25 @@ class PandasExtensionDtype(_DtypeOpsMixin):
117116 # and ExtensionDtype's @properties in the subclasses below. The kind and
118117 # type variables in those subclasses are explicitly typed below.
119118 subdtype = None
120- str = None # type: Optional[str_type ]
119+ str = None # type: Optional[str ]
121120 num = 100
122121 shape = tuple () # type: Tuple[int, ...]
123122 itemsize = 8
124123 base = None
125124 isbuiltin = 0
126125 isnative = 0
127- _cache = {} # type: Dict[str_type , 'PandasExtensionDtype']
126+ _cache = {} # type: Dict[str , 'PandasExtensionDtype']
128127
129- def __unicode__ (self ):
128+ def __unicode__ (self ) -> str :
130129 return self .name
131130
132- def __str__ (self ):
131+ def __str__ (self ) -> str :
133132 """
134133 Return a string representation for a particular Object
135134 """
136135 return self .__unicode__ ()
137136
138- def __bytes__ (self ):
137+ def __bytes__ (self ) -> bytes :
139138 """
140139 Return a string representation for a particular object.
141140 """
@@ -144,22 +143,22 @@ def __bytes__(self):
144143 encoding = get_option ("display.encoding" )
145144 return self .__unicode__ ().encode (encoding , 'replace' )
146145
147- def __repr__ (self ):
146+ def __repr__ (self ) -> str :
148147 """
149148 Return a string representation for a particular object.
150149 """
151150 return str (self )
152151
153- def __hash__ (self ):
152+ def __hash__ (self ) -> int :
154153 raise NotImplementedError ("sub-classes should implement an __hash__ "
155154 "method" )
156155
157- def __getstate__ (self ):
156+ def __getstate__ (self ) -> Dict [ str , Any ] :
158157 # pickle support; we don't want to pickle the cache
159158 return {k : getattr (self , k , None ) for k in self ._metadata }
160159
161160 @classmethod
162- def reset_cache (cls ):
161+ def reset_cache (cls ) -> None :
163162 """ clear the cache """
164163 cls ._cache = {}
165164
@@ -217,23 +216,27 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
217216 # TODO: Document public vs. private API
218217 name = 'category'
219218 type = CategoricalDtypeType # type: Type[CategoricalDtypeType]
220- kind = 'O' # type: str_type
219+ kind = 'O' # type: str
221220 str = '|O08'
222221 base = np .dtype ('O' )
223222 _metadata = ('categories' , 'ordered' )
224- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
223+ _cache = {} # type: Dict[str , PandasExtensionDtype]
225224
226- def __init__ (self , categories = None , ordered = None ):
225+ def __init__ (self , categories = None , ordered : bool = None ):
227226 self ._finalize (categories , ordered , fastpath = False )
228227
229228 @classmethod
230- def _from_fastpath (cls , categories = None , ordered = None ):
229+ def _from_fastpath (cls , categories = None , ordered : bool = None ):
231230 self = cls .__new__ (cls )
232231 self ._finalize (categories , ordered , fastpath = True )
233232 return self
234233
235234 @classmethod
236- def _from_categorical_dtype (cls , dtype , categories = None , ordered = None ):
235+ def _from_categorical_dtype (cls ,
236+ dtype : 'CategoricalDtype' ,
237+ categories = None ,
238+ ordered : bool = None ,
239+ ) -> 'CategoricalDtype' :
237240 if categories is ordered is None :
238241 return dtype
239242 if categories is None :
@@ -243,8 +246,11 @@ def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
243246 return cls (categories , ordered )
244247
245248 @classmethod
246- def _from_values_or_dtype (cls , values = None , categories = None , ordered = None ,
247- dtype = None ):
249+ def _from_values_or_dtype (cls ,
250+ values = None ,
251+ categories = None ,
252+ ordered : bool = None ,
253+ dtype : 'CategoricalDtype' = None ):
248254 """
249255 Construct dtype from the input parameters used in :class:`Categorical`.
250256
@@ -326,7 +332,11 @@ def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
326332
327333 return dtype
328334
329- def _finalize (self , categories , ordered , fastpath = False ):
335+ def _finalize (self ,
336+ categories ,
337+ ordered : Optional [bool ],
338+ fastpath : bool = False ,
339+ ) -> None :
330340
331341 if ordered is not None :
332342 self .validate_ordered (ordered )
@@ -338,14 +348,14 @@ def _finalize(self, categories, ordered, fastpath=False):
338348 self ._categories = categories
339349 self ._ordered = ordered
340350
341- def __setstate__ (self , state ) :
351+ def __setstate__ (self , state : 'Dict[str, Any]' ) -> None :
342352 # for pickle compat. __get_state__ is defined in the
343353 # PandasExtensionDtype superclass and uses the public properties to
344354 # pickle -> need to set the settable private ones here (see GH26067)
345355 self ._categories = state .pop ('categories' , None )
346356 self ._ordered = state .pop ('ordered' , False )
347357
348- def __hash__ (self ):
358+ def __hash__ (self ) -> int :
349359 # _hash_categories returns a uint64, so use the negative
350360 # space for when we have unknown categories to avoid a conflict
351361 if self .categories is None :
@@ -356,7 +366,7 @@ def __hash__(self):
356366 # We *do* want to include the real self.ordered here
357367 return int (self ._hash_categories (self .categories , self .ordered ))
358368
359- def __eq__ (self , other ) :
369+ def __eq__ (self , other : Any ) -> bool :
360370 """
361371 Rules for CDT equality:
362372 1) Any CDT is equal to the string 'category'
@@ -403,7 +413,7 @@ def __repr__(self):
403413 return tpl .format (data , self .ordered )
404414
405415 @staticmethod
406- def _hash_categories (categories , ordered = True ):
416+ def _hash_categories (categories , ordered : bool = True ) -> int :
407417 from pandas .core .util .hashing import (
408418 hash_array , _combine_hash_arrays , hash_tuples
409419 )
@@ -453,7 +463,7 @@ def construct_array_type(cls):
453463 return Categorical
454464
455465 @classmethod
456- def construct_from_string (cls , string ) :
466+ def construct_from_string (cls , string : str ) -> 'CategoricalDtype' :
457467 """
458468 attempt to construct this type from a string, raise a TypeError if
459469 it's not possible """
@@ -466,7 +476,7 @@ def construct_from_string(cls, string):
466476 pass
467477
468478 @staticmethod
469- def validate_ordered (ordered ) :
479+ def validate_ordered (ordered : bool ) -> None :
470480 """
471481 Validates that we have a valid ordered parameter. If
472482 it is not a boolean, a TypeError will be raised.
@@ -486,7 +496,7 @@ def validate_ordered(ordered):
486496 raise TypeError ("'ordered' must either be 'True' or 'False'" )
487497
488498 @staticmethod
489- def validate_categories (categories , fastpath = False ):
499+ def validate_categories (categories , fastpath : bool = False ):
490500 """
491501 Validates that we have good categories
492502
@@ -521,7 +531,7 @@ def validate_categories(categories, fastpath=False):
521531
522532 return categories
523533
524- def update_dtype (self , dtype ) :
534+ def update_dtype (self , dtype : 'CategoricalDtype' ) -> 'CategoricalDtype' :
525535 """
526536 Returns a CategoricalDtype with categories and ordered taken from dtype
527537 if specified, otherwise falling back to self if unspecified
@@ -560,17 +570,18 @@ def categories(self):
560570 """
561571 An ``Index`` containing the unique categories allowed.
562572 """
563- return self ._categories
573+ from pandas import Index
574+ return typing .cast (Index , self ._categories )
564575
565576 @property
566- def ordered (self ):
577+ def ordered (self ) -> bool :
567578 """
568579 Whether the categories have an ordered relationship.
569580 """
570581 return self ._ordered
571582
572583 @property
573- def _is_boolean (self ):
584+ def _is_boolean (self ) -> bool :
574585 from pandas .core .dtypes .common import is_bool_dtype
575586
576587 return is_bool_dtype (self .categories )
@@ -614,14 +625,14 @@ class DatetimeTZDtype(PandasExtensionDtype, ExtensionDtype):
614625 datetime64[ns, tzfile('/usr/share/zoneinfo/US/Central')]
615626 """
616627 type = Timestamp # type: Type[Timestamp]
617- kind = 'M' # type: str_type
628+ kind = 'M' # type: str
618629 str = '|M8[ns]'
619630 num = 101
620631 base = np .dtype ('M8[ns]' )
621632 na_value = NaT
622633 _metadata = ('unit' , 'tz' )
623634 _match = re .compile (r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]" )
624- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
635+ _cache = {} # type: Dict[str , PandasExtensionDtype]
625636
626637 def __init__ (self , unit = "ns" , tz = None ):
627638 if isinstance (unit , DatetimeTZDtype ):
@@ -765,13 +776,13 @@ class PeriodDtype(ExtensionDtype, PandasExtensionDtype):
765776 period[M]
766777 """
767778 type = Period # type: Type[Period]
768- kind = 'O' # type: str_type
779+ kind = 'O' # type: str
769780 str = '|O08'
770781 base = np .dtype ('O' )
771782 num = 102
772783 _metadata = ('freq' ,)
773784 _match = re .compile (r"(P|p)eriod\[(?P<freq>.+)\]" )
774- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
785+ _cache = {} # type: Dict[str , PandasExtensionDtype]
775786
776787 def __new__ (cls , freq = None ):
777788 """
@@ -919,13 +930,13 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
919930 interval[int64]
920931 """
921932 name = 'interval'
922- kind = None # type: Optional[str_type ]
933+ kind = None # type: Optional[str ]
923934 str = '|O08'
924935 base = np .dtype ('O' )
925936 num = 103
926937 _metadata = ('subtype' ,)
927938 _match = re .compile (r"(I|i)nterval\[(?P<subtype>.+)\]" )
928- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
939+ _cache = {} # type: Dict[str , PandasExtensionDtype]
929940
930941 def __new__ (cls , subtype = None ):
931942 from pandas .core .dtypes .common import (
0 commit comments