@@ -454,8 +454,9 @@ def test_tz_dtype_matches(self):
454454
455455
456456class TestReductions :
457- @pytest .mark .parametrize ("tz" , [None , "US/Central" ])
458- def test_min_max (self , tz ):
457+ @pytest .fixture
458+ def arr1d (self , tz_naive_fixture ):
459+ tz = tz_naive_fixture
459460 dtype = DatetimeTZDtype (tz = tz ) if tz is not None else np .dtype ("M8[ns]" )
460461 arr = DatetimeArray ._from_sequence (
461462 [
@@ -468,6 +469,11 @@ def test_min_max(self, tz):
468469 ],
469470 dtype = dtype ,
470471 )
472+ return arr
473+
474+ def test_min_max (self , arr1d ):
475+ arr = arr1d
476+ tz = arr .tz
471477
472478 result = arr .min ()
473479 expected = pd .Timestamp ("2000-01-02" , tz = tz )
@@ -493,3 +499,70 @@ def test_min_max_empty(self, skipna, tz):
493499
494500 result = arr .max (skipna = skipna )
495501 assert result is pd .NaT
502+
503+ @pytest .mark .parametrize ("tz" , [None , "US/Central" ])
504+ @pytest .mark .parametrize ("skipna" , [True , False ])
505+ def test_median_empty (self , skipna , tz ):
506+ dtype = DatetimeTZDtype (tz = tz ) if tz is not None else np .dtype ("M8[ns]" )
507+ arr = DatetimeArray ._from_sequence ([], dtype = dtype )
508+ result = arr .median (skipna = skipna )
509+ assert result is pd .NaT
510+
511+ arr = arr .reshape (0 , 3 )
512+ result = arr .median (axis = 0 , skipna = skipna )
513+ expected = type (arr )._from_sequence ([pd .NaT , pd .NaT , pd .NaT ], dtype = arr .dtype )
514+ tm .assert_equal (result , expected )
515+
516+ result = arr .median (axis = 1 , skipna = skipna )
517+ expected = type (arr )._from_sequence ([pd .NaT ], dtype = arr .dtype )
518+ tm .assert_equal (result , expected )
519+
520+ def test_median (self , arr1d ):
521+ arr = arr1d
522+
523+ result = arr .median ()
524+ assert result == arr [0 ]
525+ result = arr .median (skipna = False )
526+ assert result is pd .NaT
527+
528+ result = arr .dropna ().median (skipna = False )
529+ assert result == arr [0 ]
530+
531+ result = arr .median (axis = 0 )
532+ assert result == arr [0 ]
533+
534+ def test_median_axis (self , arr1d ):
535+ arr = arr1d
536+ assert arr .median (axis = 0 ) == arr .median ()
537+ assert arr .median (axis = 0 , skipna = False ) is pd .NaT
538+
539+ msg = r"abs\(axis\) must be less than ndim"
540+ with pytest .raises (ValueError , match = msg ):
541+ arr .median (axis = 1 )
542+
543+ @pytest .mark .filterwarnings ("ignore:All-NaN slice encountered:RuntimeWarning" )
544+ def test_median_2d (self , arr1d ):
545+ arr = arr1d .reshape (1 , - 1 )
546+
547+ # axis = None
548+ assert arr .median () == arr1d .median ()
549+ assert arr .median (skipna = False ) is pd .NaT
550+
551+ # axis = 0
552+ result = arr .median (axis = 0 )
553+ expected = arr1d
554+ tm .assert_equal (result , expected )
555+
556+ # Since column 3 is all-NaT, we get NaT there with or without skipna
557+ result = arr .median (axis = 0 , skipna = False )
558+ expected = arr1d
559+ tm .assert_equal (result , expected )
560+
561+ # axis = 1
562+ result = arr .median (axis = 1 )
563+ expected = type (arr )._from_sequence ([arr1d .median ()])
564+ tm .assert_equal (result , expected )
565+
566+ result = arr .median (axis = 1 , skipna = False )
567+ expected = type (arr )._from_sequence ([pd .NaT ], dtype = arr .dtype )
568+ tm .assert_equal (result , expected )
0 commit comments