Skip to content

Commit 89d8410

Browse files
committed
fixed linting
1 parent 05ffb7b commit 89d8410

File tree

3 files changed

+173
-37
lines changed

3 files changed

+173
-37
lines changed

src/array_api_extra/_delegation.py

Lines changed: 151 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def argpartition(
768768
Axis along which to partition. The default is ``-1`` (the last axis).
769769
If ``None``, the flattened array is used.
770770
xp : array_namespace, optional
771-
The standard-compatible namespace for `x`. Default: infer.
771+
The standard-compatible namespace for `a`. Default: infer.
772772
773773
Returns
774774
-------
@@ -908,45 +908,179 @@ def quantile(
908908
xp: ModuleType | None = None,
909909
) -> Array:
910910
"""
911-
TODO
911+
Compute the q-th quantile of the data along the specified axis.
912+
913+
Parameters
914+
----------
915+
a : array_like of real numbers
916+
Input array or object that can be converted to an array.
917+
q : array_like of float
918+
Probability or sequence of probabilities of the quantiles to compute.
919+
Values must be between 0 and 1 inclusive.
920+
axis : {int, tuple of int, None}, optional
921+
Axis or axes along which the quantiles are computed. The default is
922+
to compute the quantile(s) along a flattened version of the array.
923+
method : str, optional
924+
This parameter specifies the method to use for estimating the
925+
quantile. There are many different methods.
926+
The recommended options, numbered as they appear in [1]_, are:
927+
928+
1. 'inverted_cdf'
929+
2. 'averaged_inverted_cdf'
930+
3. 'closest_observation'
931+
4. 'interpolated_inverted_cdf'
932+
5. 'hazen'
933+
6. 'weibull'
934+
7. 'linear' (default)
935+
8. 'median_unbiased'
936+
9. 'normal_unbiased'
937+
938+
The first three methods are discontinuous.
939+
Only 'linear' is implemented for now.
940+
941+
keepdims : bool, optional
942+
If this is set to True, the axes which are reduced are left in
943+
the result as dimensions with size one. With this option, the
944+
result will broadcast correctly against the original array `a`.
945+
946+
xp : array_namespace, optional
947+
The standard-compatible namespace for `a` and `q`. Default: infer.
948+
949+
Returns
950+
-------
951+
scalar or ndarray
952+
If `q` is a single probability and `axis=None`, then the result
953+
is a scalar. If multiple probability levels are given, first axis
954+
of the result corresponds to the quantiles. The other axes are
955+
the axes that remain after the reduction of `a`. If the input
956+
contains integers or floats smaller than ``float64``, the output
957+
data-type is ``float64``. Otherwise, the output data-type is the
958+
same as that of the input. If `out` is specified, that array is
959+
returned instead.
960+
961+
Notes
962+
-----
963+
Given a sample `a` from an underlying distribution, `quantile` provides a
964+
nonparametric estimate of the inverse cumulative distribution function.
965+
966+
By default, this is done by interpolating between adjacent elements in
967+
``y``, a sorted copy of `a`::
968+
969+
(1-g)*y[j] + g*y[j+1]
970+
971+
where the index ``j`` and coefficient ``g`` are the integral and
972+
fractional components of ``q * (n-1)``, and ``n`` is the number of
973+
elements in the sample.
974+
975+
This is a special case of Equation 1 of H&F [1]_. More generally,
976+
977+
- ``j = (q*n + m - 1) // 1``, and
978+
- ``g = (q*n + m - 1) % 1``,
979+
980+
where ``m`` may be defined according to several different conventions.
981+
The preferred convention may be selected using the ``method`` parameter:
982+
983+
=============================== =============== ===============
984+
``method`` number in H&F ``m``
985+
=============================== =============== ===============
986+
``interpolated_inverted_cdf`` 4 ``0``
987+
``hazen`` 5 ``1/2``
988+
``weibull`` 6 ``q``
989+
``linear`` (default) 7 ``1 - q``
990+
``median_unbiased`` 8 ``q/3 + 1/3``
991+
``normal_unbiased`` 9 ``q/4 + 3/8``
992+
=============================== =============== ===============
993+
994+
Note that indices ``j`` and ``j + 1`` are clipped to the range ``0`` to
995+
``n - 1`` when the results of the formula would be outside the allowed
996+
range of non-negative indices. The ``- 1`` in the formulas for ``j`` and
997+
``g`` accounts for Python's 0-based indexing.
998+
999+
The table above includes only the estimators from H&F that are continuous
1000+
functions of probability `q` (estimators 4-9). NumPy also provides the
1001+
three discontinuous estimators from H&F (estimators 1-3), where ``j`` is
1002+
defined as above, ``m`` is defined as follows, and ``g`` is a function
1003+
of the real-valued ``index = q*n + m - 1`` and ``j``.
1004+
1005+
1. ``inverted_cdf``: ``m = 0`` and ``g = int(index - j > 0)``
1006+
2. ``averaged_inverted_cdf``: ``m = 0`` and
1007+
``g = (1 + int(index - j > 0)) / 2``
1008+
3. ``closest_observation``: ``m = -1/2`` and
1009+
``g = 1 - int((index == j) & (j%2 == 1))``
1010+
1011+
**Weighted quantiles:**
1012+
More formally, the quantile at probability level :math:`q` of a cumulative
1013+
distribution function :math:`F(y)=P(Y \\leq y)` with probability measure
1014+
:math:`P` is defined as any number :math:`x` that fulfills the
1015+
*coverage conditions*
1016+
1017+
.. math:: P(Y < x) \\leq q \\quad\\text{and}\\quad P(Y \\leq x) \\geq q
1018+
1019+
with random variable :math:`Y\\sim P`.
1020+
Sample quantiles, the result of `quantile`, provide nonparametric
1021+
estimation of the underlying population counterparts, represented by the
1022+
unknown :math:`F`, given a data vector `a` of length ``n``.
1023+
1024+
Some of the estimators above arise when one considers :math:`F` as the
1025+
empirical distribution function of the data, i.e.
1026+
:math:`F(y) = \\frac{1}{n} \\sum_i 1_{a_i \\leq y}`.
1027+
Then, different methods correspond to different choices of :math:`x` that
1028+
fulfill the above coverage conditions. Methods that follow this approach
1029+
are ``inverted_cdf`` and ``averaged_inverted_cdf``.
1030+
1031+
For weighted quantiles, the coverage conditions still hold. The
1032+
empirical cumulative distribution is simply replaced by its weighted
1033+
version, i.e.
1034+
:math:`P(Y \\leq t) = \\frac{1}{\\sum_i w_i} \\sum_i w_i 1_{x_i \\leq t}`.
1035+
Only ``method="inverted_cdf"`` supports weights.
1036+
1037+
References
1038+
----------
1039+
.. [1] R. J. Hyndman and Y. Fan,
1040+
"Sample quantiles in statistical packages,"
1041+
The American Statistician, 50(4), pp. 361-365, 1996
9121042
"""
9131043
methods = {"linear"}
9141044

9151045
if method not in methods:
916-
message = f"`method` must be one of {methods}"
917-
raise ValueError(message)
1046+
msg = f"`method` must be one of {methods}"
1047+
raise ValueError(msg)
9181048
if keepdims not in {True, False}:
919-
message = "If specified, `keepdims` must be True or False."
920-
raise ValueError(message)
1049+
msg = "If specified, `keepdims` must be True or False."
1050+
raise ValueError(msg)
9211051
if xp is None:
9221052
xp = array_namespace(a)
9231053

9241054
a = xp.asarray(a)
925-
if not xp.isdtype(a.dtype, ('integral', 'real floating')):
926-
raise ValueError("`a` must have real dtype.")
927-
if not xp.isdtype(xp.asarray(q).dtype, 'real floating'):
928-
raise ValueError("`q` must have real floating dtype.")
1055+
if not xp.isdtype(a.dtype, ("integral", "real floating")):
1056+
msg = "`a` must have real dtype."
1057+
raise ValueError(msg)
1058+
if not xp.isdtype(xp.asarray(q).dtype, "real floating"):
1059+
msg = "`q` must have real floating dtype."
1060+
raise ValueError(msg)
9291061
ndim = a.ndim
9301062
if ndim < 1:
9311063
msg = "`a` must be at least 1-dimensional"
9321064
raise TypeError(msg)
9331065
if axis is not None and ((axis >= ndim) or (axis < -ndim)):
934-
message = "`axis` is not compatible with the dimension of `a`."
935-
raise ValueError(message)
1066+
msg = "`axis` is not compatible with the dimension of `a`."
1067+
raise ValueError(msg)
9361068

9371069
# Array API states: Mixed integer and floating-point type promotion rules
9381070
# are not specified because behavior varies between implementations.
939-
# => We choose to do:
940-
dtype = (
941-
xp.float64 if xp.isdtype(a.dtype, 'integral')
942-
else xp.result_type(a, xp.asarray(q)) # both a and q are floats
1071+
# We chose to align with numpy (see docstring):
1072+
dtype = xp.result_type(
1073+
xp.float64 if xp.isdtype(a.dtype, "integral") else a,
1074+
xp.asarray(q),
1075+
xp.float64, # at least float64
9431076
)
9441077
device = get_device(a)
9451078
a = xp.asarray(a, dtype=dtype, device=device)
9461079
q = xp.asarray(q, dtype=dtype, device=device)
9471080

9481081
if xp.any((q > 1) | (q < 0) | xp.isnan(q)):
949-
raise ValueError("`q` values must be in the range [0, 1]")
1082+
msg = "`q` values must be in the range [0, 1]"
1083+
raise ValueError(msg)
9501084

9511085
# Delegate where possible.
9521086
if is_numpy_namespace(xp):

src/array_api_extra/_lib/_quantile.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Implementations of the quantile function."""
2+
13
from types import ModuleType
24

35
from ._utils._compat import device as get_device
@@ -9,22 +11,19 @@ def quantile( # numpydoc ignore=PR01,RT01
911
a: Array,
1012
q: Array | float,
1113
/,
12-
method: str = 'linear', # noqa: ARG001
14+
method: str = "linear", # noqa: ARG001
1315
axis: int | None = None,
1416
keepdims: bool = False,
1517
*,
1618
xp: ModuleType,
1719
) -> Array:
1820
"""See docstring in `array_api_extra._delegation.py`."""
1921
device = get_device(a)
20-
floating_dtype = xp.float64 #xp.result_type(a, xp.asarray(q))
22+
floating_dtype = xp.float64 # xp.result_type(a, xp.asarray(q))
2123
a = xp.asarray(a, dtype=floating_dtype, device=device)
2224
a_shape = list(a.shape)
2325
p: Array = xp.asarray(q, dtype=floating_dtype, device=device)
2426

25-
if xp.any((p > 1) | (p < 0) | xp.isnan(p)):
26-
raise ValueError("`q` values must be in the range [0, 1]")
27-
2827
q_scalar = p.ndim == 0
2928
if q_scalar:
3029
p = xp.reshape(p, (1,))
@@ -37,7 +36,7 @@ def quantile( # numpydoc ignore=PR01,RT01
3736
else:
3837
axis = int(axis)
3938

40-
n, = eager_shape(a, axis)
39+
(n,) = eager_shape(a, axis)
4140
# If data has length zero along `axis`, the result will be an array of NaNs just
4241
# as if the data had length 1 along axis and were filled with NaNs.
4342
if n == 0:
@@ -66,22 +65,23 @@ def quantile( # numpydoc ignore=PR01,RT01
6665
return res[0, ...] if q_scalar else res
6766

6867

69-
def _quantile_hf(a: Array, q: Array, n: float, axis: int, xp: ModuleType) -> Array:
68+
def _quantile_hf( # numpydoc ignore=GL08
69+
a: Array, q: Array, n: float, axis: int, xp: ModuleType
70+
) -> Array:
7071
m = 1 - q
71-
jg = q*n + m - 1
72+
jg = q * n + m - 1
7273

7374
j = jg // 1
74-
j = xp.clip(j, 0., n - 1)
75-
jp1 = xp.clip(j + 1, 0., n - 1)
75+
j = xp.clip(j, 0.0, n - 1)
76+
jp1 = xp.clip(j + 1, 0.0, n - 1)
7677
# `̀j` and `jp1` are 1d arrays
7778

7879
g = jg % 1
79-
g = xp.where(j < 0, 0, g) # equiv to g[j < 0] = 0, but work with strictest
80+
g = xp.where(j < 0, 0, g) # equivalent to g[j < 0] = 0, but works with strictest
8081
new_g_shape = [1] * a.ndim
8182
new_g_shape[axis] = g.shape[0]
8283
g = xp.reshape(g, tuple(new_g_shape))
8384

84-
return (
85-
(1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis)
86-
+ g * xp.take(a, xp.astype(jp1, xp.int64), axis=axis)
85+
return (1 - g) * xp.take(a, xp.astype(j, xp.int64), axis=axis) + g * xp.take(
86+
a, xp.astype(jp1, xp.int64), axis=axis
8787
)

tests/test_funcs.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1547,8 +1547,9 @@ def test_multiple_quantiles(self, xp: ModuleType):
15471547
xp_assert_close(actual, expect)
15481548

15491549
def test_shape(self, xp: ModuleType):
1550-
a = xp.asarray(np.random.rand(3, 4, 5))
1551-
q = xp.asarray(np.random.rand(2))
1550+
rng = np.random.default_rng()
1551+
a = xp.asarray(rng.random((3, 4, 5)))
1552+
q = xp.asarray(rng.random(2))
15521553
assert quantile(a, q, axis=0).shape == (2, 4, 5)
15531554
assert quantile(a, q, axis=1).shape == (2, 3, 5)
15541555
assert quantile(a, q, axis=2).shape == (2, 3, 4)
@@ -1558,8 +1559,9 @@ def test_shape(self, xp: ModuleType):
15581559
assert quantile(a, q, axis=2, keepdims=True).shape == (2, 3, 4, 1)
15591560

15601561
def test_against_numpy(self, xp: ModuleType):
1561-
a_np = np.random.rand(3, 4, 5)
1562-
q_np = np.random.rand(2)
1562+
rng = np.random.default_rng()
1563+
a_np = rng.random((3, 4, 5))
1564+
q_np = rng.random(2)
15631565
a = xp.asarray(a_np)
15641566
q = xp.asarray(q_np)
15651567
for keepdims in [False, True]:
@@ -1583,7 +1585,7 @@ def test_2d_axis_keepdims(self, xp: ModuleType):
15831585

15841586
def test_methods(self, xp: ModuleType):
15851587
x = xp.asarray([1, 2, 3, 4, 5])
1586-
methods = ["linear"] #"hazen", "weibull"]
1588+
methods = ["linear"] # "hazen", "weibull"]
15871589
for method in methods:
15881590
actual = quantile(x, 0.5, method=method)
15891591
# All methods should give reasonable results
@@ -1617,7 +1619,7 @@ def test_invalid_q(self, xp: ModuleType):
16171619
_ = quantile(x, -0.5)
16181620

16191621
def test_device(self, xp: ModuleType, device: Device):
1620-
if hasattr(device, 'type') and getattr(device, 'type') == "meta":
1622+
if hasattr(device, "type") and device.type == "meta": # pyright: ignore[reportAttributeAccessIssue]
16211623
pytest.xfail("No Tensor.item() on meta device")
16221624
x = xp.asarray([1, 2, 3, 4, 5], device=device)
16231625
actual = quantile(x, 0.5)

0 commit comments

Comments
 (0)