@@ -4458,7 +4458,8 @@ def shift_months(int64_t[:] dtindex, int months, object day=None):
44584458 Py_ssize_t i
44594459 pandas_datetimestruct dts
44604460 int count = len (dtindex)
4461- int days_in_current_month
4461+ int months_to_roll
4462+ bint roll_check
44624463 int64_t[:] out = np.empty(count, dtype = ' int64' )
44634464
44644465 if day is None :
@@ -4472,36 +4473,44 @@ def shift_months(int64_t[:] dtindex, int months, object day=None):
44724473 dts.day = min (dts.day, days_in_month(dts))
44734474 out[i] = pandas_datetimestruct_to_datetime(PANDAS_FR_ns, & dts)
44744475 elif day == ' start' :
4476+ roll_check = False
4477+ if months <= 0 :
4478+ months += 1
4479+ roll_check = True
44754480 with nogil:
44764481 for i in range (count):
44774482 if dtindex[i] == NPY_NAT: out[i] = NPY_NAT; continue
44784483 pandas_datetime_to_datetimestruct(dtindex[i], PANDAS_FR_ns, & dts)
4479- dts.year = _year_add_months(dts, months)
4480- dts.month = _month_add_months(dts, months)
4484+ months_to_roll = months
4485+
4486+ # offset semantics - if on the anchor point and going backwards
4487+ # shift to next
4488+ if roll_check and dts.day == 1 :
4489+ months_to_roll -= 1
4490+
4491+ dts.year = _year_add_months(dts, months_to_roll)
4492+ dts.month = _month_add_months(dts, months_to_roll)
4493+ dts.day = 1
44814494
4482- # offset semantics - when subtracting if at the start anchor
4483- # point, shift back by one more month
4484- if months <= 0 and dts.day == 1 :
4485- dts.year = _year_add_months(dts, - 1 )
4486- dts.month = _month_add_months(dts, - 1 )
4487- else :
4488- dts.day = 1
44894495 out[i] = pandas_datetimestruct_to_datetime(PANDAS_FR_ns, & dts)
44904496 elif day == ' end' :
4497+ roll_check = False
4498+ if months > 0 :
4499+ months -= 1
4500+ roll_check = True
44914501 with nogil:
44924502 for i in range (count):
44934503 if dtindex[i] == NPY_NAT: out[i] = NPY_NAT; continue
44944504 pandas_datetime_to_datetimestruct(dtindex[i], PANDAS_FR_ns, & dts)
4495- days_in_current_month = days_in_month(dts)
4496-
4497- dts.year = _year_add_months(dts, months)
4498- dts.month = _month_add_months(dts, months)
4505+ months_to_roll = months
44994506
45004507 # similar semantics - when adding shift forward by one
45014508 # month if already at an end of month
4502- if months >= 0 and dts.day == days_in_current_month:
4503- dts.year = _year_add_months(dts, 1 )
4504- dts.month = _month_add_months(dts, 1 )
4509+ if roll_check and dts.day == days_in_month(dts):
4510+ months_to_roll += 1
4511+
4512+ dts.year = _year_add_months(dts, months_to_roll)
4513+ dts.month = _month_add_months(dts, months_to_roll)
45054514
45064515 dts.day = days_in_month(dts)
45074516 out[i] = pandas_datetimestruct_to_datetime(PANDAS_FR_ns, & dts)
0 commit comments