Skip to content

Commit b4d1cf9

Browse files
committed
Implement PytatoSplitArrayContext
1 parent 50511fe commit b4d1cf9

File tree

3 files changed

+500
-0
lines changed

3 files changed

+500
-0
lines changed

arraycontext/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .impl.jax import EagerJAXArrayContext
5454
from .impl.pyopencl import PyOpenCLArrayContext
5555
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
56+
from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext
5657
from .loopy import make_loopy_program
5758
# deprecated, remove in 2022.
5859
from .metadata import _FirstAxisIsElementsTag
@@ -98,6 +99,8 @@
9899
"outer",
99100

100101
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
102+
"SplitPytatoPyOpenCLArrayContext",
103+
101104
"PytatoJAXArrayContext",
102105
"EagerJAXArrayContext",
103106

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
.. autoclass:: SplitPytatoPyOpenCLArrayContext
3+
4+
"""
5+
6+
__copyright__ = """
7+
Copyright (C) 2023 Kaushik Kulkarni
8+
Copyright (C) 2023 Andreas Kloeckner
9+
Copyright (C) 2022 Matthias Diener
10+
Copyright (C) 2022 Matt Smith
11+
"""
12+
13+
__license__ = """
14+
Permission is hereby granted, free of charge, to any person obtaining a copy
15+
of this software and associated documentation files (the "Software"), to deal
16+
in the Software without restriction, including without limitation the rights
17+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18+
copies of the Software, and to permit persons to whom the Software is
19+
furnished to do so, subject to the following conditions:
20+
21+
The above copyright notice and this permission notice shall be included in
22+
all copies or substantial portions of the Software.
23+
24+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
30+
THE SOFTWARE.
31+
"""
32+
33+
import sys
34+
from typing import TYPE_CHECKING
35+
36+
import loopy as lp
37+
38+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
39+
40+
41+
if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
42+
import pytato
43+
44+
45+
class SplitPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
46+
"""
47+
.. note::
48+
49+
Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
50+
details on the transformation algorithm provided by this array context.
51+
52+
.. warning::
53+
54+
For expression graphs with large number of nodes high compile times are
55+
expected.
56+
"""
57+
def transform_dag(self,
58+
dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
59+
r"""
60+
Returns a transformed version of *dag*, where the applied transform is:
61+
62+
#. Materialize as per MPMS materialization heuristic.
63+
#. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs.
64+
"""
65+
import pytato as pt
66+
67+
# Step 1. Collapse equivalent nodes in DAG.
68+
# -----------------------------------------
69+
# type-ignore-reason: mypy is right pytato provides imprecise types.
70+
dag = pt.transform.deduplicate_data_wrappers(dag) # type: ignore[assignment]
71+
72+
# Step 2. Materialize einsum inputs/outputs.
73+
# ------------------------------------------
74+
from .utils import get_inputs_and_outputs_of_einsum
75+
einsum_inputs_outputs = frozenset.union(
76+
*get_inputs_and_outputs_of_einsum(dag))
77+
78+
def materialize_einsum(expr: pt.transform.ArrayOrNames
79+
) -> pt.transform.ArrayOrNames:
80+
if expr in einsum_inputs_outputs:
81+
if isinstance(expr, pt.InputArgumentBase):
82+
return expr
83+
else:
84+
return expr.tagged(pt.tags.ImplStored())
85+
else:
86+
return expr
87+
88+
# type-ignore-reason: mypy is right pytato provides imprecise types.
89+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
90+
materialize_einsum)
91+
92+
# Step 3. MPMS materialize
93+
# ------------------------
94+
dag = pt.transform.materialize_with_mpms(dag)
95+
96+
return dag
97+
98+
def transform_loopy_program(self,
99+
t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
100+
r"""
101+
Returns a transformed version of *t_unit*, where the applied transform is:
102+
103+
#. An execution grid size :math:`G` is selected based on *self*'s
104+
OpenCL-device.
105+
#. The iteration domain for each statement in the *t_unit* is divided to
106+
equally among the work-items in :math:`G`.
107+
#. Kernel boundaries are drawn between every statement in the instruction.
108+
Although one can relax this constraint by letting :mod:`loopy` compute
109+
where to insert the global barriers, but it is not guaranteed to be
110+
performance profitable since we do not attempt any further loop-fusion
111+
and/or array contraction.
112+
#. Once the kernel boundaries are inferred, :func:`alias_global_temporaries`
113+
is invoked to reduce the memory peak memory used by the transformed
114+
program.
115+
"""
116+
# Step 1. Split the iteration across work-items
117+
# ---------------------------------------------
118+
from .utils import split_iteration_domain_across_work_items
119+
t_unit = split_iteration_domain_across_work_items(t_unit, self.queue.device)
120+
121+
# Step 2. Add a global barrier between individual loop nests.
122+
# ------------------------------------------------------
123+
from .utils import add_gbarrier_between_disjoint_loop_nests
124+
t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)
125+
126+
# Step 3. Alias global temporaries with disjoint live intervals
127+
# -------------------------------------------------------------
128+
from .utils import alias_global_temporaries
129+
t_unit = alias_global_temporaries(t_unit)
130+
131+
return t_unit
132+
133+
# vim: fdm=marker

0 commit comments

Comments
 (0)