@@ -175,40 +175,56 @@ def test_where_other(self):
175175 def test_where_invalid_dtypes (self ):
176176 dti = date_range ("20130101" , periods = 3 , tz = "US/Eastern" )
177177
178- i2 = Index ([pd .NaT , pd .NaT ] + dti [2 :].tolist ())
178+ tail = dti [2 :].tolist ()
179+ i2 = Index ([pd .NaT , pd .NaT ] + tail )
179180
180- msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
181- msg2 = "Cannot compare tz-naive and tz-aware datetime-like objects"
182- with pytest .raises (TypeError , match = msg2 ):
183- # passing tz-naive ndarray to tzaware DTI
184- dti .where (notna (i2 ), i2 .values )
181+ mask = notna (i2 )
185182
186- with pytest .raises (TypeError , match = msg2 ):
187- # passing tz-aware DTI to tznaive DTI
188- dti .tz_localize (None ).where (notna (i2 ), i2 )
183+ # passing tz-naive ndarray to tzaware DTI
184+ result = dti .where (mask , i2 .values )
185+ expected = Index ([pd .NaT .asm8 , pd .NaT .asm8 ] + tail , dtype = object )
186+ tm .assert_index_equal (result , expected )
189187
190- with pytest .raises (TypeError , match = msg ):
191- dti .where (notna (i2 ), i2 .tz_localize (None ).to_period ("D" ))
188+ # passing tz-aware DTI to tznaive DTI
189+ naive = dti .tz_localize (None )
190+ result = naive .where (mask , i2 )
191+ expected = Index ([i2 [0 ], i2 [1 ]] + naive [2 :].tolist (), dtype = object )
192+ tm .assert_index_equal (result , expected )
192193
193- with pytest .raises (TypeError , match = msg ):
194- dti .where (notna (i2 ), i2 .asi8 .view ("timedelta64[ns]" ))
194+ pi = i2 .tz_localize (None ).to_period ("D" )
195+ result = dti .where (mask , pi )
196+ expected = Index ([pi [0 ], pi [1 ]] + tail , dtype = object )
197+ tm .assert_index_equal (result , expected )
195198
196- with pytest .raises (TypeError , match = msg ):
197- dti .where (notna (i2 ), i2 .asi8 )
199+ tda = i2 .asi8 .view ("timedelta64[ns]" )
200+ result = dti .where (mask , tda )
201+ expected = Index ([tda [0 ], tda [1 ]] + tail , dtype = object )
202+ assert isinstance (expected [0 ], np .timedelta64 )
203+ tm .assert_index_equal (result , expected )
198204
199- with pytest .raises (TypeError , match = msg ):
200- # non-matching scalar
201- dti .where (notna (i2 ), pd .Timedelta (days = 4 ))
205+ result = dti .where (mask , i2 .asi8 )
206+ expected = Index ([pd .NaT .value , pd .NaT .value ] + tail , dtype = object )
207+ assert isinstance (expected [0 ], int )
208+ tm .assert_index_equal (result , expected )
209+
210+ # non-matching scalar
211+ td = pd .Timedelta (days = 4 )
212+ result = dti .where (mask , td )
213+ expected = Index ([td , td ] + tail , dtype = object )
214+ assert expected [0 ] is td
215+ tm .assert_index_equal (result , expected )
202216
203217 def test_where_mismatched_nat (self , tz_aware_fixture ):
204218 tz = tz_aware_fixture
205219 dti = date_range ("2013-01-01" , periods = 3 , tz = tz )
206220 cond = np .array ([True , False , True ])
207221
208- msg = "value should be a 'Timestamp', 'NaT', or array of those. Got"
209- with pytest .raises (TypeError , match = msg ):
210- # wrong-dtyped NaT
211- dti .where (cond , np .timedelta64 ("NaT" , "ns" ))
222+ tdnat = np .timedelta64 ("NaT" , "ns" )
223+ expected = Index ([dti [0 ], tdnat , dti [2 ]], dtype = object )
224+ assert expected [1 ] is tdnat
225+
226+ result = dti .where (cond , tdnat )
227+ tm .assert_index_equal (result , expected )
212228
213229 def test_where_tz (self ):
214230 i = date_range ("20130101" , periods = 3 , tz = "US/Eastern" )
0 commit comments