@@ -174,6 +174,18 @@ def get_date_name_field(const int64_t[:] dtindex, str field, object locale=None)
174174 return out
175175
176176
177+ cdef inline bint _is_on_month(int month, int compare_month, int modby) nogil:
178+ """
179+ Analogous to DateOffset.is_on_offset checking for the month part of a date.
180+ """
181+ if modby == 1 :
182+ return True
183+ elif modby == 3 :
184+ return (month - compare_month) % 3 == 0
185+ else :
186+ return month == compare_month
187+
188+
177189@ cython.wraparound (False )
178190@ cython.boundscheck (False )
179191def get_start_end_field (const int64_t[:] dtindex , str field ,
@@ -191,6 +203,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
191203 int start_month = 1
192204 ndarray[int8_t] out
193205 npy_datetimestruct dts
206+ int compare_month, modby
194207
195208 out = np.zeros(count, dtype = ' int8' )
196209
@@ -215,102 +228,15 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
215228 end_month = 12
216229 start_month = 1
217230
218- if field == ' is_month_start' :
219- if is_business:
220- for i in range (count):
221- if dtindex[i] == NPY_NAT:
222- out[i] = 0
223- continue
224-
225- dt64_to_dtstruct(dtindex[i], & dts)
226-
227- if dts.day == get_firstbday(dts.year, dts.month):
228- out[i] = 1
229-
230- else :
231- for i in range (count):
232- if dtindex[i] == NPY_NAT:
233- out[i] = 0
234- continue
235-
236- dt64_to_dtstruct(dtindex[i], & dts)
237-
238- if dts.day == 1 :
239- out[i] = 1
240-
241- elif field == ' is_month_end' :
242- if is_business:
243- for i in range (count):
244- if dtindex[i] == NPY_NAT:
245- out[i] = 0
246- continue
247-
248- dt64_to_dtstruct(dtindex[i], & dts)
249-
250- if dts.day == get_lastbday(dts.year, dts.month):
251- out[i] = 1
252-
253- else :
254- for i in range (count):
255- if dtindex[i] == NPY_NAT:
256- out[i] = 0
257- continue
258-
259- dt64_to_dtstruct(dtindex[i], & dts)
260-
261- if dts.day == get_days_in_month(dts.year, dts.month):
262- out[i] = 1
263-
264- elif field == ' is_quarter_start' :
265- if is_business:
266- for i in range (count):
267- if dtindex[i] == NPY_NAT:
268- out[i] = 0
269- continue
270-
271- dt64_to_dtstruct(dtindex[i], & dts)
272-
273- if ((dts.month - start_month) % 3 == 0 ) and (
274- dts.day == get_firstbday(dts.year, dts.month)):
275- out[i] = 1
276-
277- else :
278- for i in range (count):
279- if dtindex[i] == NPY_NAT:
280- out[i] = 0
281- continue
282-
283- dt64_to_dtstruct(dtindex[i], & dts)
284-
285- if ((dts.month - start_month) % 3 == 0 ) and dts.day == 1 :
286- out[i] = 1
287-
288- elif field == ' is_quarter_end' :
289- if is_business:
290- for i in range (count):
291- if dtindex[i] == NPY_NAT:
292- out[i] = 0
293- continue
294-
295- dt64_to_dtstruct(dtindex[i], & dts)
296-
297- if ((dts.month - end_month) % 3 == 0 ) and (
298- dts.day == get_lastbday(dts.year, dts.month)):
299- out[i] = 1
300-
301- else :
302- for i in range (count):
303- if dtindex[i] == NPY_NAT:
304- out[i] = 0
305- continue
306-
307- dt64_to_dtstruct(dtindex[i], & dts)
308-
309- if ((dts.month - end_month) % 3 == 0 ) and (
310- dts.day == get_days_in_month(dts.year, dts.month)):
311- out[i] = 1
231+ compare_month = start_month if " start" in field else end_month
232+ if " month" in field:
233+ modby = 1
234+ elif " quarter" in field:
235+ modby = 3
236+ else :
237+ modby = 12
312238
313- elif field == ' is_year_start' :
239+ if field in [ " is_month_start " , " is_quarter_start " , " is_year_start" ] :
314240 if is_business:
315241 for i in range (count):
316242 if dtindex[i] == NPY_NAT:
@@ -319,7 +245,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
319245
320246 dt64_to_dtstruct(dtindex[i], & dts)
321247
322- if (dts.month == start_month ) and (
248+ if _is_on_month (dts.month, compare_month, modby ) and (
323249 dts.day == get_firstbday(dts.year, dts.month)):
324250 out[i] = 1
325251
@@ -331,10 +257,10 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
331257
332258 dt64_to_dtstruct(dtindex[i], & dts)
333259
334- if (dts.month == start_month ) and dts.day == 1 :
260+ if _is_on_month (dts.month, compare_month, modby ) and dts.day == 1 :
335261 out[i] = 1
336262
337- elif field == ' is_year_end' :
263+ elif field in [ " is_month_end " , " is_quarter_end " , " is_year_end" ] :
338264 if is_business:
339265 for i in range (count):
340266 if dtindex[i] == NPY_NAT:
@@ -343,7 +269,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
343269
344270 dt64_to_dtstruct(dtindex[i], & dts)
345271
346- if (dts.month == end_month ) and (
272+ if _is_on_month (dts.month, compare_month, modby ) and (
347273 dts.day == get_lastbday(dts.year, dts.month)):
348274 out[i] = 1
349275
@@ -355,7 +281,7 @@ def get_start_end_field(const int64_t[:] dtindex, str field,
355281
356282 dt64_to_dtstruct(dtindex[i], & dts)
357283
358- if (dts.month == end_month ) and (
284+ if _is_on_month (dts.month, compare_month, modby ) and (
359285 dts.day == get_days_in_month(dts.year, dts.month)):
360286 out[i] = 1
361287
0 commit comments