11""" define extension dtypes """
22import re
3- from typing import Any , Dict , Optional , Tuple , Type
3+ import typing
4+ from typing import Any , Dict , List , Optional , Tuple , Type
45import warnings
56
67import numpy as np
1516from .base import ExtensionDtype , _DtypeOpsMixin
1617from .inference import is_list_like
1718
18- str_type = str
1919
20-
21- def register_extension_dtype (cls ):
20+ def register_extension_dtype (cls : 'ExtensionDtype' ) -> 'ExtensionDtype' :
2221 """
2322 Register an ExtensionType with pandas as class decorator.
2423
@@ -60,20 +59,20 @@ class Registry:
6059 These are tried in order.
6160 """
def __init__(self) -> None:
    # The registry stores dtype *classes*, not instances: ``register``
    # validates each entry with ``issubclass`` before appending it, so the
    # element type is ``Type[ExtensionDtype]``.
    self.dtypes = []  # type: List[Type[ExtensionDtype]]
6463
def register(self, dtype: 'Type[ExtensionDtype]') -> None:
    """
    Add an extension dtype class to the registry.

    Parameters
    ----------
    dtype : ExtensionDtype class

    Raises
    ------
    ValueError
        If ``dtype`` is not a subclass of ExtensionDtype. This includes
        being handed something that is not a class at all.
    """
    # issubclass() raises TypeError when its first argument is not a class,
    # so check for a class first and report every bad input the same way.
    if not (isinstance(dtype, type)
            and issubclass(dtype, ExtensionDtype)):
        raise ValueError("can only register pandas extension dtypes")

    self.dtypes.append(dtype)
7574
76- def find (self , dtype ) :
75+ def find (self , dtype : 'ExtensionDtype' ) -> Optional [ ExtensionDtype ] :
7776 """
7877 Parameters
7978 ----------
@@ -117,25 +116,25 @@ class PandasExtensionDtype(_DtypeOpsMixin):
117116 # and ExtensionDtype's @properties in the subclasses below. The kind and
118117 # type variables in those subclasses are explicitly typed below.
119118 subdtype = None
120- str = None # type: Optional[str_type ]
119+ str = None # type: Optional[str ]
121120 num = 100
122121 shape = tuple () # type: Tuple[int, ...]
123122 itemsize = 8
124123 base = None
125124 isbuiltin = 0
126125 isnative = 0
127- _cache = {} # type: Dict[str_type , 'PandasExtensionDtype']
126+ _cache = {} # type: Dict[str , 'PandasExtensionDtype']
128127
def __unicode__(self) -> str:
    """
    Return the text representation of this object, which is its ``name``.
    """
    text = self.name
    return text
131130
def __str__(self) -> str:
    """
    Return a string representation of this object.

    Delegates to ``__unicode__`` so both share one implementation.
    """
    rendered = self.__unicode__()
    return rendered
137136
138- def __bytes__ (self ):
137+ def __bytes__ (self ) -> bytes :
139138 """
140139 Return a string representation for a particular object.
141140 """
@@ -144,22 +143,22 @@ def __bytes__(self):
144143 encoding = get_option ("display.encoding" )
145144 return self .__unicode__ ().encode (encoding , 'replace' )
146145
def __repr__(self) -> str:
    """
    Return the repr of this object, which is the same text ``str()`` gives.
    """
    as_text = str(self)
    return as_text
152151
def __hash__(self) -> int:
    """
    Placeholder hash: always raises, so subclasses must define their own.

    Raises
    ------
    NotImplementedError
    """
    message = "sub-classes should implement an __hash__ method"
    raise NotImplementedError(message)
156155
def __getstate__(self) -> Dict[str, Any]:
    """
    Build the state that gets pickled.

    Only the attributes named in ``_metadata`` are included, so the cache
    is never pickled. A missing attribute is stored as None.
    """
    state = {}
    for field in self._metadata:
        state[field] = getattr(self, field, None)
    return state
160159
@classmethod
def reset_cache(cls) -> None:
    """ clear the cache by rebinding ``_cache`` on ``cls`` to a new dict """
    cls._cache = dict()
165164
@@ -217,23 +216,30 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
217216 # TODO: Document public vs. private API
218217 name = 'category'
219218 type = CategoricalDtypeType # type: Type[CategoricalDtypeType]
220- kind = 'O' # type: str_type
219+ kind = 'O' # type: str
221220 str = '|O08'
222221 base = np .dtype ('O' )
223222 _metadata = ('categories' , 'ordered' )
224- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
223+ _cache = {} # type: Dict[str , PandasExtensionDtype]
225224
def __init__(self, categories=None, ordered: Optional[bool] = None):
    # ``ordered`` defaults to None, so it is annotated Optional[bool] to
    # match ``_finalize``, which receives it unchanged. This public
    # constructor always takes the validating (non-fastpath) route.
    self._finalize(categories, ordered, fastpath=False)
228227
@classmethod
def _from_fastpath(cls,
                   categories=None,
                   ordered: Optional[bool] = None
                   ) -> 'CategoricalDtype':
    """
    Alternate constructor that finalizes with ``fastpath=True``.

    The instance is created with ``cls.__new__`` so ``__init__`` (which
    finalizes with ``fastpath=False``) is bypassed.
    """
    self = cls.__new__(cls)
    self._finalize(categories, ordered, fastpath=True)
    return self
234236
235237 @classmethod
236- def _from_categorical_dtype (cls , dtype , categories = None , ordered = None ):
238+ def _from_categorical_dtype (cls ,
239+ dtype : 'CategoricalDtype' ,
240+ categories = None ,
241+ ordered : bool = None ,
242+ ) -> 'CategoricalDtype' :
237243 if categories is ordered is None :
238244 return dtype
239245 if categories is None :
@@ -243,8 +249,12 @@ def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
243249 return cls (categories , ordered )
244250
245251 @classmethod
246- def _from_values_or_dtype (cls , values = None , categories = None , ordered = None ,
247- dtype = None ):
252+ def _from_values_or_dtype (cls ,
253+ values = None ,
254+ categories = None ,
255+ ordered : bool = None ,
256+ dtype : 'CategoricalDtype' = None ,
257+ ) -> 'CategoricalDtype' :
248258 """
249259 Construct dtype from the input parameters used in :class:`Categorical`.
250260
@@ -326,7 +336,11 @@ def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
326336
327337 return dtype
328338
329- def _finalize (self , categories , ordered , fastpath = False ):
339+ def _finalize (self ,
340+ categories ,
341+ ordered : Optional [bool ],
342+ fastpath : bool = False ,
343+ ) -> None :
330344
331345 if ordered is not None :
332346 self .validate_ordered (ordered )
@@ -338,14 +352,14 @@ def _finalize(self, categories, ordered, fastpath=False):
338352 self ._categories = categories
339353 self ._ordered = ordered
340354
def __setstate__(self, state: Dict[str, Any]) -> None:
    # for pickle compat. __getstate__ is defined in the
    # PandasExtensionDtype superclass and uses the public properties to
    # pickle -> need to set the settable private ones here (see GH26067)
    #
    # Read with .get() rather than .pop() so the caller's state dict is
    # left unmodified; the fallback values are unchanged.
    self._categories = state.get('categories', None)
    self._ordered = state.get('ordered', False)
347361
348- def __hash__ (self ):
362+ def __hash__ (self ) -> int :
349363 # _hash_categories returns a uint64, so use the negative
350364 # space for when we have unknown categories to avoid a conflict
351365 if self .categories is None :
@@ -356,7 +370,7 @@ def __hash__(self):
356370 # We *do* want to include the real self.ordered here
357371 return int (self ._hash_categories (self .categories , self .ordered ))
358372
359- def __eq__ (self , other ) :
373+ def __eq__ (self , other : Any ) -> bool :
360374 """
361375 Rules for CDT equality:
362376 1) Any CDT is equal to the string 'category'
@@ -403,7 +417,7 @@ def __repr__(self):
403417 return tpl .format (data , self .ordered )
404418
405419 @staticmethod
406- def _hash_categories (categories , ordered = True ):
420+ def _hash_categories (categories , ordered : bool = True ) -> int :
407421 from pandas .core .util .hashing import (
408422 hash_array , _combine_hash_arrays , hash_tuples
409423 )
@@ -453,7 +467,7 @@ def construct_array_type(cls):
453467 return Categorical
454468
455469 @classmethod
456- def construct_from_string (cls , string ) :
470+ def construct_from_string (cls , string : str ) -> 'CategoricalDtype' :
457471 """
458472 attempt to construct this type from a string, raise a TypeError if
459473 it's not possible """
@@ -466,7 +480,7 @@ def construct_from_string(cls, string):
466480 pass
467481
468482 @staticmethod
469- def validate_ordered (ordered ) :
483+ def validate_ordered (ordered : bool ) -> None :
470484 """
471485 Validates that we have a valid ordered parameter. If
472486 it is not a boolean, a TypeError will be raised.
@@ -486,7 +500,7 @@ def validate_ordered(ordered):
486500 raise TypeError ("'ordered' must either be 'True' or 'False'" )
487501
488502 @staticmethod
489- def validate_categories (categories , fastpath = False ):
503+ def validate_categories (categories , fastpath : bool = False ):
490504 """
491505 Validates that we have good categories
492506
@@ -521,7 +535,7 @@ def validate_categories(categories, fastpath=False):
521535
522536 return categories
523537
524- def update_dtype (self , dtype ) :
538+ def update_dtype (self , dtype : 'CategoricalDtype' ) -> 'CategoricalDtype' :
525539 """
526540 Returns a CategoricalDtype with categories and ordered taken from dtype
527541 if specified, otherwise falling back to self if unspecified
@@ -560,17 +574,18 @@ def categories(self):
560574 """
561575 An ``Index`` containing the unique categories allowed.
562576 """
563- return self ._categories
577+ from pandas import Index
578+ return typing .cast (Index , self ._categories )
564579
@property
def ordered(self) -> Optional[bool]:
    """
    Whether the categories have an ordered relationship.
    """
    # ``_ordered`` is stored as given to ``_finalize`` (an Optional[bool];
    # the constructors default it to None), so None is a possible value.
    return self._ordered
571586
@property
def _is_boolean(self) -> bool:
    """
    Result of ``is_bool_dtype`` applied to ``self.categories``.
    """
    # Imported inside the function, as the original does.
    from pandas.core.dtypes import common as _common

    categories = self.categories
    return _common.is_bool_dtype(categories)
@@ -614,14 +629,14 @@ class DatetimeTZDtype(PandasExtensionDtype, ExtensionDtype):
614629 datetime64[ns, tzfile('/usr/share/zoneinfo/US/Central')]
615630 """
616631 type = Timestamp # type: Type[Timestamp]
617- kind = 'M' # type: str_type
632+ kind = 'M' # type: str
618633 str = '|M8[ns]'
619634 num = 101
620635 base = np .dtype ('M8[ns]' )
621636 na_value = NaT
622637 _metadata = ('unit' , 'tz' )
623638 _match = re .compile (r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]" )
624- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
639+ _cache = {} # type: Dict[str , PandasExtensionDtype]
625640
626641 def __init__ (self , unit = "ns" , tz = None ):
627642 if isinstance (unit , DatetimeTZDtype ):
@@ -765,13 +780,13 @@ class PeriodDtype(ExtensionDtype, PandasExtensionDtype):
765780 period[M]
766781 """
767782 type = Period # type: Type[Period]
768- kind = 'O' # type: str_type
783+ kind = 'O' # type: str
769784 str = '|O08'
770785 base = np .dtype ('O' )
771786 num = 102
772787 _metadata = ('freq' ,)
773788 _match = re .compile (r"(P|p)eriod\[(?P<freq>.+)\]" )
774- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
789+ _cache = {} # type: Dict[str , PandasExtensionDtype]
775790
776791 def __new__ (cls , freq = None ):
777792 """
@@ -919,13 +934,13 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
919934 interval[int64]
920935 """
921936 name = 'interval'
922- kind = None # type: Optional[str_type ]
937+ kind = None # type: Optional[str ]
923938 str = '|O08'
924939 base = np .dtype ('O' )
925940 num = 103
926941 _metadata = ('subtype' ,)
927942 _match = re .compile (r"(I|i)nterval\[(?P<subtype>.+)\]" )
928- _cache = {} # type: Dict[str_type , PandasExtensionDtype]
943+ _cache = {} # type: Dict[str , PandasExtensionDtype]
929944
930945 def __new__ (cls , subtype = None ):
931946 from pandas .core .dtypes .common import (
0 commit comments