44import numpy as np
55import pytest
66
7+ from pandas ._libs .tslibs import tz_compare
8+ from pandas ._libs .tslibs .dtypes import NpyDatetimeUnit
9+
710from pandas .core .dtypes .dtypes import DatetimeTZDtype
811
912import pandas as pd
@@ -20,28 +23,48 @@ def unit(self, request):
2023 @pytest .fixture
2124 def reso (self , unit ):
2225 """Fixture returning datetime resolution for a given time unit"""
23- # TODO: avoid hard-coding
24- return {"s" : 7 , "ms" : 8 , "us" : 9 }[unit ]
26+ return {
27+ "s" : NpyDatetimeUnit .NPY_FR_s .value ,
28+ "ms" : NpyDatetimeUnit .NPY_FR_ms .value ,
29+ "us" : NpyDatetimeUnit .NPY_FR_us .value ,
30+ }[unit ]
31+
32+ @pytest .fixture
33+ def dtype (self , unit , tz_naive_fixture ):
34+ tz = tz_naive_fixture
35+ if tz is None :
36+ return np .dtype (f"datetime64[{ unit } ]" )
37+ else :
38+ return DatetimeTZDtype (unit = unit , tz = tz )
2539
26- @pytest .mark .xfail (reason = "_box_func is not yet patched to get reso right" )
27- def test_non_nano (self , unit , reso ):
40+ def test_non_nano (self , unit , reso , dtype ):
2841 arr = np .arange (5 , dtype = np .int64 ).view (f"M8[{ unit } ]" )
29- dta = DatetimeArray ._simple_new (arr , dtype = arr . dtype )
42+ dta = DatetimeArray ._simple_new (arr , dtype = dtype )
3043
31- assert dta .dtype == arr . dtype
44+ assert dta .dtype == dtype
3245 assert dta [0 ]._reso == reso
46+ assert tz_compare (dta .tz , dta [0 ].tz )
47+ assert (dta [0 ] == dta [:1 ]).all ()
3348
3449 @pytest .mark .filterwarnings (
3550 "ignore:weekofyear and week have been deprecated:FutureWarning"
3651 )
3752 @pytest .mark .parametrize (
3853 "field" , DatetimeArray ._field_ops + DatetimeArray ._bool_ops
3954 )
40- def test_fields (self , unit , reso , field ):
41- dti = pd .date_range ("2016-01-01" , periods = 55 , freq = "D" )
42- arr = np .asarray (dti ).astype (f"M8[{ unit } ]" )
55+ def test_fields (self , unit , reso , field , dtype ):
56+ tz = getattr (dtype , "tz" , None )
57+ dti = pd .date_range ("2016-01-01" , periods = 55 , freq = "D" , tz = tz )
58+ if tz is None :
59+ arr = np .asarray (dti ).astype (f"M8[{ unit } ]" )
60+ else :
61+ arr = np .asarray (dti .tz_convert ("UTC" ).tz_localize (None )).astype (
62+ f"M8[{ unit } ]"
63+ )
4364
44- dta = DatetimeArray ._simple_new (arr , dtype = arr .dtype )
65+ dta = DatetimeArray ._simple_new (arr , dtype = dtype )
66+
67+ # FIXME: assert (dti == dta).all()
4568
4669 res = getattr (dta , field )
4770 expected = getattr (dti ._data , field )
0 commit comments