@@ -714,10 +714,12 @@ group_ohlc_float64 = _group_ohlc['double']
714714
715715@ cython.boundscheck (False )
716716@ cython.wraparound (False )
717- def group_quantile (ndarray[float64_t] out ,
718- ndarray[int64_t] labels ,
719- numeric[:] values ,
720- ndarray[uint8_t] mask ,
717+ def group_quantile (floating[:, :] out ,
718+ int64_t[:] counts ,
719+ floating[:, :] values ,
720+ const int64_t[:] labels ,
721+ Py_ssize_t min_count ,
722+ const uint8_t[:, :] mask ,
721723 float64_t q ,
722724 object interpolation ):
723725 """
@@ -740,12 +742,12 @@ def group_quantile(ndarray[float64_t] out,
740742 provided `out` parameter.
741743 """
742744 cdef:
743- Py_ssize_t i, N= len (labels), ngroups, grp_sz, non_na_sz
745+ Py_ssize_t i, N= len (labels), K, ngroups, grp_sz= 0 , non_na_sz
744746 Py_ssize_t grp_start= 0 , idx= 0
745747 int64_t lab
746748 uint8_t interp
747749 float64_t q_idx, frac, val, next_val
748- ndarray[ int64_t] counts, non_na_counts, sort_arr
750+ int64_t[:, :] non_na_counts, sort_arrs
749751
750752 assert values.shape[0 ] == N
751753
@@ -761,59 +763,64 @@ def group_quantile(ndarray[float64_t] out,
761763 }
762764 interp = inter_methods[interpolation]
763765
764- counts = np.zeros_like(out, dtype = np.int64)
765766 non_na_counts = np.zeros_like(out, dtype = np.int64)
767+ sort_arrs = np.empty_like(values, dtype = np.int64)
766768 ngroups = len (counts)
767769
770+ N, K = (< object > values).shape
771+
768772 # First figure out the size of every group
769773 with nogil:
770774 for i in range (N):
771775 lab = labels[i]
772776 if lab == - 1 : # NA group label
773777 continue
774-
775778 counts[lab] += 1
776- if not mask[i]:
777- non_na_counts[lab] += 1
779+ for j in range (K):
780+ if not mask[i, j]:
781+ non_na_counts[lab, j] += 1
778782
779- # Get an index of values sorted by labels and then values
780- order = (values, labels)
781- sort_arr = np.lexsort(order).astype(np.int64, copy = False )
783+ for j in range (K):
784+ order = (values[:, j], labels)
785+ r = np.lexsort(order).astype(np.int64, copy = False )
786+ # TODO: Need better way to assign r to column j
787+ for i in range (N):
788+ sort_arrs[i, j] = r[i]
782789
783790 with nogil:
784791 for i in range (ngroups):
785792 # Figure out how many group elements there are
786793 grp_sz = counts[i]
787- non_na_sz = non_na_counts[i]
788-
789- if non_na_sz == 0 :
790- out[i] = NaN
791- else :
792- # Calculate where to retrieve the desired value
793- # Casting to int will intentionally truncate result
794- idx = grp_start + < int64_t> (q * < float64_t> (non_na_sz - 1 ))
795-
796- val = values[sort_arr[idx]]
797- # If requested quantile falls evenly on a particular index
798- # then write that index's value out. Otherwise interpolate
799- q_idx = q * (non_na_sz - 1 )
800- frac = q_idx % 1
801-
802- if frac == 0.0 or interp == INTERPOLATION_LOWER:
803- out[i] = val
794+ for j in range (K):
795+ non_na_sz = non_na_counts[i, j]
796+ if non_na_sz == 0 :
797+ out[i, j] = NaN
804798 else :
805- next_val = values[sort_arr[idx + 1 ]]
806- if interp == INTERPOLATION_LINEAR:
807- out[i] = val + (next_val - val) * frac
808- elif interp == INTERPOLATION_HIGHER:
809- out[i] = next_val
810- elif interp == INTERPOLATION_MIDPOINT:
811- out[i] = (val + next_val) / 2.0
812- elif interp == INTERPOLATION_NEAREST:
813- if frac > .5 or (frac == .5 and q > .5 ): # Always OK?
814- out[i] = next_val
815- else :
816- out[i] = val
799+ # Calculate where to retrieve the desired value
800+ # Casting to int will intentionally truncate result
801+ idx = grp_start + < int64_t> (q * < float64_t> (non_na_sz - 1 ))
802+
803+ val = values[sort_arrs[idx, j], j]
804+ # If requested quantile falls evenly on a particular index
805+ # then write that index's value out. Otherwise interpolate
806+ q_idx = q * (non_na_sz - 1 )
807+ frac = q_idx % 1
808+
809+ if frac == 0.0 or interp == INTERPOLATION_LOWER:
810+ out[i, j] = val
811+ else :
812+ next_val = values[sort_arrs[idx + 1 , j], j]
813+ if interp == INTERPOLATION_LINEAR:
814+ out[i, j] = val + (next_val - val) * frac
815+ elif interp == INTERPOLATION_HIGHER:
816+ out[i, j] = next_val
817+ elif interp == INTERPOLATION_MIDPOINT:
818+ out[i, j] = (val + next_val) / 2.0
819+ elif interp == INTERPOLATION_NEAREST:
820+ if frac > .5 or (frac == .5 and q > .5 ): # Always OK?
821+ out[i, j] = next_val
822+ else :
823+ out[i, j] = val
817824
818825 # Increment the index reference in sorted_arr for the next group
819826 grp_start += grp_sz
0 commit comments