| 
4 | 4 | from __future__ import annotations  | 
5 | 5 | 
 
  | 
6 | 6 | import math  | 
 | 7 | +import warnings  | 
7 | 8 | from collections.abc import Callable, Sequence  | 
8 | 9 | from functools import wraps  | 
9 | 10 | from types import ModuleType  | 
10 | 11 | from typing import TYPE_CHECKING, Any, cast, overload  | 
11 | 12 | 
 
  | 
12 |  | -from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace  | 
 | 13 | +from ._utils._compat import (  | 
 | 14 | +    array_namespace,  | 
 | 15 | +    is_dask_namespace,  | 
 | 16 | +    is_jax_namespace,  | 
 | 17 | +    is_lazy_array,  | 
 | 18 | +)  | 
13 | 19 | from ._utils._typing import Array, DType  | 
14 | 20 | 
 
  | 
15 | 21 | if TYPE_CHECKING:  | 
@@ -319,3 +325,297 @@ def wrapper(  # type: ignore[no-any-decorated,no-any-explicit]  | 
319 | 325 |         return (xp.asarray(out),)  | 
320 | 326 | 
 
  | 
321 | 327 |     return wrapper  | 
 | 328 | + | 
 | 329 | + | 
 | 330 | +def lazy_raise(  # numpydoc ignore=SA04  | 
 | 331 | +    x: Array,  | 
 | 332 | +    cond: bool | Array,  | 
 | 333 | +    exc: Exception,  | 
 | 334 | +    *,  | 
 | 335 | +    xp: ModuleType | None = None,  | 
 | 336 | +) -> Array:  | 
 | 337 | +    """  | 
 | 338 | +    Raise if an eager check fails on a lazy array.  | 
 | 339 | +
  | 
 | 340 | +    Consider this snippet::  | 
 | 341 | +
  | 
 | 342 | +        >>> def f(x, xp):  | 
 | 343 | +        ...     if xp.any(x < 0):  | 
 | 344 | +        ...         raise ValueError("Some points are negative")  | 
 | 345 | +        ...     return x + 1  | 
 | 346 | +
  | 
 | 347 | +    The above code fails to compile when x is a JAX array and the function is wrapped  | 
 | 348 | +    by `jax.jit`; it is also extremely slow on Dask. Other lazy backends, e.g. ndonnx,  | 
 | 349 | +    are also expected to misbehave.  | 
 | 350 | +
  | 
 | 351 | +    `xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement calls  | 
 | 352 | +    `bool()` on the Array to convert it to a Python bool.  | 
 | 353 | +
  | 
 | 354 | +    On eager backends such as NumPy, this is not a problem. On Dask, `bool()` implicitly  | 
 | 355 | +    triggers a computation of the whole graph so far; what's worse is that the  | 
 | 356 | +    intermediate results are discarded to optimize memory usage, so when later on user  | 
 | 357 | +    explicitly calls `compute()` on their final output, `x` is recalculated from  | 
 | 358 | +    scratch. On JAX, `bool()` raises if its called code is wrapped by `jax.jit` for the  | 
 | 359 | +    same reason.  | 
 | 360 | +
  | 
 | 361 | +    You should rewrite the above code as follows::  | 
 | 362 | +
  | 
 | 363 | +        >>> def f(x, xp):  | 
 | 364 | +        ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))  | 
 | 365 | +        ...     return x + 1  | 
 | 366 | +
  | 
 | 367 | +    When `xp` is eager, this is equivalent to the original code; if the error condition  | 
 | 368 | +    resolves to True, the function raises immediately and the next line `return x + 1`  | 
 | 369 | +    is never executed.  | 
 | 370 | +    When `xp` is lazy, the function always returns a lazy array. When eventually the  | 
 | 371 | +    user actually computes it, e.g. in Dask by calling `compute()` and in JAX by having  | 
 | 372 | +    their outermost function decorated with `@jax.jit` return, only then the error  | 
 | 373 | +    condition is evaluated. If True, the exception is raised and propagated as normal,  | 
 | 374 | +    and the following nodes of the graph are never executed (so if the health check was  | 
 | 375 | +    in place to prevent not only incorrect results but e.g. a segmentation fault, it's  | 
 | 376 | +    still going to achieve its purpose).  | 
 | 377 | +
  | 
 | 378 | +    Parameters  | 
 | 379 | +    ----------  | 
 | 380 | +    x : Array  | 
 | 381 | +        Any one Array, potentially lazy, that is used later on to produce the value  | 
 | 382 | +        returned by your function.  | 
 | 383 | +    cond : bool | Array  | 
 | 384 | +        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.  | 
 | 385 | +        If True, raise the exception. If False, return x.  | 
 | 386 | +    exc : Exception  | 
 | 387 | +        The exception instance to be raised.  | 
 | 388 | +    xp : array_namespace, optional  | 
 | 389 | +        The standard-compatible namespace for `x`. Default: infer.  | 
 | 390 | +
  | 
 | 391 | +    Returns  | 
 | 392 | +    -------  | 
 | 393 | +    Array  | 
 | 394 | +        `x`. If both `x` and `cond` are lazy array, the graph underlying `x` is altered  | 
 | 395 | +        to raise `exc` if `cond` is True.  | 
 | 396 | +
  | 
 | 397 | +    Raises  | 
 | 398 | +    ------  | 
 | 399 | +    type(x)  | 
 | 400 | +        If `cond` evaluates to True.  | 
 | 401 | +
  | 
 | 402 | +    Warnings  | 
 | 403 | +    --------  | 
 | 404 | +    This function raises when x is eager, and quietly skips the check  | 
 | 405 | +    when x is lazy::  | 
 | 406 | +
  | 
 | 407 | +        >>> def f(x, xp):  | 
 | 408 | +        ...     lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))  | 
 | 409 | +        ...     return x + 1  | 
 | 410 | +
  | 
 | 411 | +    And so does this one, as lazy_raise replaces `x` but it does so too late to  | 
 | 412 | +    contribute to the return value::  | 
 | 413 | +
  | 
 | 414 | +        >>> def f(x, xp):  | 
 | 415 | +        ...     y = x + 1  | 
 | 416 | +        ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))  | 
 | 417 | +        ...     return y  | 
 | 418 | +
  | 
 | 419 | +    See Also  | 
 | 420 | +    --------  | 
 | 421 | +    lazy_apply  | 
 | 422 | +    lazy_warn  | 
 | 423 | +    lazy_wait_on  | 
 | 424 | +    dask.graph_manipulation.wait_on  | 
 | 425 | +
  | 
 | 426 | +    Notes  | 
 | 427 | +    -----  | 
 | 428 | +    This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is  | 
 | 429 | +    a JAX array on a non-CPU device.  | 
 | 430 | +    """  | 
 | 431 | + | 
 | 432 | +    def _lazy_raise(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01  | 
 | 433 | +        """Eager helper of `lazy_raise` running inside the lazy graph."""  | 
 | 434 | +        if cond:  | 
 | 435 | +            raise exc  | 
 | 436 | +        return x  | 
 | 437 | + | 
 | 438 | +    return _lazy_wait_on_impl(x, cond, _lazy_raise, xp=xp)  | 
 | 439 | + | 
 | 440 | + | 
 | 441 | +# Signature of warnings.warn copied from python/typeshed  | 
 | 442 | +@overload  | 
 | 443 | +def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08  | 
 | 444 | +    x: Array,  | 
 | 445 | +    cond: bool | Array,  | 
 | 446 | +    message: str,  | 
 | 447 | +    category: type[Warning] | None = None,  | 
 | 448 | +    stacklevel: int = 1,  | 
 | 449 | +    source: Any | None = None,  | 
 | 450 | +    *,  | 
 | 451 | +    xp: ModuleType | None = None,  | 
 | 452 | +) -> None: ...  | 
 | 453 | +@overload  | 
 | 454 | +def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08  | 
 | 455 | +    x: Array,  | 
 | 456 | +    cond: bool | Array,  | 
 | 457 | +    message: Warning,  | 
 | 458 | +    category: Any = None,  | 
 | 459 | +    stacklevel: int = 1,  | 
 | 460 | +    source: Any | None = None,  | 
 | 461 | +    *,  | 
 | 462 | +    xp: ModuleType | None = None,  | 
 | 463 | +) -> None: ...  | 
 | 464 | + | 
 | 465 | + | 
 | 466 | +def lazy_warn(  # type: ignore[no-any-explicit]  # numpydoc ignore=SA04,PR04  | 
 | 467 | +    x: Array,  | 
 | 468 | +    cond: bool | Array,  | 
 | 469 | +    message: str | Warning,  | 
 | 470 | +    category: Any = None,  | 
 | 471 | +    stacklevel: int = 1,  | 
 | 472 | +    source: Any | None = None,  | 
 | 473 | +    *,  | 
 | 474 | +    xp: ModuleType | None = None,  | 
 | 475 | +) -> Array:  | 
 | 476 | +    """  | 
 | 477 | +    Call `warnings.warn` if an eager check fails on a lazy array.  | 
 | 478 | +
  | 
 | 479 | +    This functions works in the same way as `lazy_raise`; refer to it  | 
 | 480 | +    for the detailed explanation.  | 
 | 481 | +
  | 
 | 482 | +    You should replace::  | 
 | 483 | +
  | 
 | 484 | +        >>> def f(x, xp):  | 
 | 485 | +        ...     if xp.any(x < 0):  | 
 | 486 | +        ...         warnings.warn("Some points are negative", UserWarning, stacklevel=2)  | 
 | 487 | +        ...     return x + 1  | 
 | 488 | +
  | 
 | 489 | +    with::  | 
 | 490 | +
  | 
 | 491 | +        >>> def f(x, xp):  | 
 | 492 | +        ...     x = lazy_raise(x, xp.any(x < 0),  | 
 | 493 | +        ...                    "Some points are negative", UserWarning, stacklevel=2)  | 
 | 494 | +        ...     return x + 1  | 
 | 495 | +
  | 
 | 496 | +    Parameters  | 
 | 497 | +    ----------  | 
 | 498 | +    x : Array  | 
 | 499 | +        Any one Array, potentially lazy, that is used later on to produce the value  | 
 | 500 | +        returned by your function.  | 
 | 501 | +    cond : bool | Array  | 
 | 502 | +        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.  | 
 | 503 | +        If True, raise the exception. If False, return x.  | 
 | 504 | +    message, category, stacklevel, source :  | 
 | 505 | +        Parameters to `warnings.warn`. `stacklevel` is automatically increased to  | 
 | 506 | +        compensate for the extra wrapper function.  | 
 | 507 | +    xp : array_namespace, optional  | 
 | 508 | +        The standard-compatible namespace for `x`. Default: infer.  | 
 | 509 | +
  | 
 | 510 | +    Returns  | 
 | 511 | +    -------  | 
 | 512 | +    Array  | 
 | 513 | +        `x`. If both `x` and `cond` are lazy array, the graph underlying `x` is altered  | 
 | 514 | +        to issue the warning if `cond` is True.  | 
 | 515 | +
  | 
 | 516 | +    See Also  | 
 | 517 | +    --------  | 
 | 518 | +    warnings.warn  | 
 | 519 | +    lazy_apply  | 
 | 520 | +    lazy_raise  | 
 | 521 | +    lazy_wait_on  | 
 | 522 | +    dask.graph_manipulation.wait_on  | 
 | 523 | +
  | 
 | 524 | +    Notes  | 
 | 525 | +    -----  | 
 | 526 | +    On Dask, the warning is typically going to appear on the log of the  | 
 | 527 | +    worker executing the function instead of on the client.  | 
 | 528 | +    """  | 
 | 529 | + | 
 | 530 | +    def _lazy_warn(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01  | 
 | 531 | +        """Eager helper of `lazy_raise` running inside the lazy graph."""  | 
 | 532 | +        if cond:  | 
 | 533 | +            warnings.warn(message, category, stacklevel=stacklevel + 2, source=source)  | 
 | 534 | +        return x  | 
 | 535 | + | 
 | 536 | +    return _lazy_wait_on_impl(x, cond, _lazy_warn, xp=xp)  | 
 | 537 | + | 
 | 538 | + | 
 | 539 | +def lazy_wait_on(  | 
 | 540 | +    x: Array, wait_on: object, *, xp: ModuleType | None = None  | 
 | 541 | +) -> Array:  # numpydoc ignore=SA04  | 
 | 542 | +    """  | 
 | 543 | +    Pause materialization of `x` until `wait_on` has been materialized.  | 
 | 544 | +
  | 
 | 545 | +    This is typically used to collect multiple calls to `lazy_raise` and/or  | 
 | 546 | +    `lazy_warn` from validation functions that would otherwise return None.  | 
 | 547 | +    If `wait_on` is not a lazy array, just return `x`.  | 
 | 548 | +
  | 
 | 549 | +    Read `lazy_raise` for detailed explanation.  | 
 | 550 | +
  | 
 | 551 | +    Parameters  | 
 | 552 | +    ----------  | 
 | 553 | +    x : Array  | 
 | 554 | +        Any one Array, potentially lazy, that is used later on to produce the value  | 
 | 555 | +        returned by your function.  | 
 | 556 | +    wait_on : object  | 
 | 557 | +        Any object. If it's a lazy array, block the materialization of `x` until  | 
 | 558 | +        `wait_on` has been fully materialized.  | 
 | 559 | +    xp : array_namespace, optional  | 
 | 560 | +        The standard-compatible namespace for `x`. Default: infer.  | 
 | 561 | +
  | 
 | 562 | +    Returns  | 
 | 563 | +    -------  | 
 | 564 | +    Array  | 
 | 565 | +        `x`. If both `x` and `wait_on` are lazy arrays, the graph  | 
 | 566 | +        underlying `x` is altered to wait until `wait_on` has been materialized.  | 
 | 567 | +        If `wait_on` raises, the exception is propagated to `x`.  | 
 | 568 | +
  | 
 | 569 | +    See Also  | 
 | 570 | +    --------  | 
 | 571 | +    lazy_apply  | 
 | 572 | +    lazy_raise  | 
 | 573 | +    lazy_warn  | 
 | 574 | +    dask.graph_manipulation.wait_on  | 
 | 575 | +
  | 
 | 576 | +    Examples  | 
 | 577 | +    --------  | 
 | 578 | +    ::  | 
 | 579 | +
  | 
 | 580 | +        def validate(x, xp):  | 
 | 581 | +            # Future that evaluates the checks. Contents are inconsequential.  | 
 | 582 | +            # Avoid zero-sized arrays, as they may be elided by the graph optimizer.  | 
 | 583 | +            future = xp.empty(1)  | 
 | 584 | +            future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10"))  | 
 | 585 | +            future = lazy_warn(future, xp.any(x > 20), UserWarning, "More than 20"))  | 
 | 586 | +            return future  | 
 | 587 | +
  | 
 | 588 | +        def f(x, xp):  | 
 | 589 | +            x = lazy_wait_on(x, validate(x, xp), xp=xp)  | 
 | 590 | +            return x + 1  | 
 | 591 | +    """  | 
 | 592 | + | 
 | 593 | +    def _lazy_wait_on(x: Array, _: Array) -> Array:  # numpydoc ignore=PR01,RT01  | 
 | 594 | +        """Eager helper of `lazy_wait_on` running inside the lazy graph."""  | 
 | 595 | +        return x  | 
 | 596 | + | 
 | 597 | +    return _lazy_wait_on_impl(x, wait_on, _lazy_wait_on, xp=xp)  | 
 | 598 | + | 
 | 599 | + | 
 | 600 | +def _lazy_wait_on_impl(  # numpydoc ignore=PR01,RT01  | 
 | 601 | +    x: Array,  | 
 | 602 | +    wait_on: object,  | 
 | 603 | +    eager_func: Callable[[Array, Array], Array],  | 
 | 604 | +    xp: ModuleType | None,  | 
 | 605 | +) -> Array:  | 
 | 606 | +    """Implementation of lazy_raise, lazy_warn, and lazy_wait_on."""  | 
 | 607 | +    if not is_lazy_array(wait_on):  | 
 | 608 | +        return eager_func(x, wait_on)  | 
 | 609 | + | 
 | 610 | +    if cast(Array, wait_on).shape != ():  | 
 | 611 | +        msg = "cond/wait_on must be 0-dimensional"  | 
 | 612 | +        raise ValueError(msg)  | 
 | 613 | + | 
 | 614 | +    if xp is None:  | 
 | 615 | +        xp = array_namespace(x, wait_on)  | 
 | 616 | + | 
 | 617 | +    if is_dask_namespace(xp):  | 
 | 618 | +        # lazy_apply would rechunk x  | 
 | 619 | +        return xp.map_blocks(eager_func, x, wait_on, dtype=x.dtype, meta=x._meta)  # pylint: disable=protected-access  | 
 | 620 | + | 
 | 621 | +    return lazy_apply(eager_func, x, wait_on, shape=x.shape, dtype=x.dtype, xp=xp)  | 
0 commit comments