Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,4 +583,8 @@ def tag_axes(

# }}}


class UntransformedCodeWarning(UserWarning):
pass

# vim: foldmethod=marker
34 changes: 23 additions & 11 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""
from __future__ import annotations


__doc__ = """
.. currentmodule:: arraycontext
.. autoclass:: PyOpenCLArrayContext
.. automodule:: arraycontext.impl.pyopencl.taggable_cl_array
Expand Down Expand Up @@ -36,7 +39,13 @@
from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -72,8 +81,8 @@ class PyOpenCLArrayContext(ArrayContext):
"""

def __init__(self,
queue: "pyopencl.CommandQueue",
allocator: Optional["pyopencl.tools.AllocatorBase"] = None,
queue: pyopencl.CommandQueue,
allocator: Optional[pyopencl.tools.AllocatorBase] = None,
wait_event_queue_length: Optional[int] = None,
force_device_scalars: bool = False) -> None:
r"""
Expand Down Expand Up @@ -301,16 +310,19 @@ def clone(self):

# {{{ transform_loopy_program

def transform_loopy_program(self, t_unit):
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
from warnings import warn
warn("Using arraycontext.PyOpenCLArrayContext.transform_loopy_program "
"to transform a program. This is deprecated and will stop working "
"in 2022. Instead, subclass PyOpenCLArrayContext and implement "
"the specific logic required to transform the program for your "
"package or application. Check higher-level packages "
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is largely a no-op and unlikely to result in fast generated "
"code."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
DeprecationWarning, stacklevel=2)
UntransformedCodeWarning, stacklevel=2)

# accommodate loopy with and without kernel callables

Expand Down
45 changes: 33 additions & 12 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""
from __future__ import annotations


__doc__ = """
.. currentmodule:: arraycontext

A :mod:`pytato`-based array context defers the evaluation of an array until its
Expand Down Expand Up @@ -62,11 +65,18 @@
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ScalarLike,
UntransformedCodeWarning,
)
from arraycontext.metadata import NameHint


if TYPE_CHECKING:
import loopy as lp
import pyopencl as cl
import pytato

Expand Down Expand Up @@ -137,7 +147,6 @@ def __init__(
"""
super().__init__()

import loopy as lp
import pytato as pt
self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
self._dag_transform_cache: Dict[
Expand Down Expand Up @@ -180,8 +189,8 @@ def empty_like(self, ary):

# {{{ compilation

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
Expand All @@ -194,10 +203,22 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays"
"""
return dag

def transform_loopy_program(self, t_unit):
raise ValueError(
f"{type(self).__name__} does not implement transform_loopy_program. "
"Sub-classes are supposed to implement it.")
def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performatn."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)

return t_unit

@abc.abstractmethod
def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down Expand Up @@ -250,7 +271,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
.. automethod:: compile
"""
def __init__(
self, queue: "cl.CommandQueue", allocator=None, *,
self, queue: cl.CommandQueue, allocator=None, *,
use_memory_pool: Optional[bool] = None,
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,

Expand Down Expand Up @@ -642,8 +663,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
import pytato as pt
dag = pt.transform.materialize_with_mpms(dag)
return dag
Expand Down
1 change: 0 additions & 1 deletion arraycontext/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def is_available(cls) -> bool:
def actx_class(self):
from arraycontext import PytatoPyOpenCLArrayContext
actx_cls = PytatoPyOpenCLArrayContext
actx_cls.transform_loopy_program = lambda s, t_unit: t_unit
return actx_cls

def __call__(self):
Expand Down