|
| 1 | +""" |
| 2 | +.. autoclass:: SplitPytatoPyOpenCLArrayContext |
| 3 | +
|
| 4 | +""" |
| 5 | + |
| 6 | +__copyright__ = """ |
| 7 | +Copyright (C) 2023 Kaushik Kulkarni |
| 8 | +Copyright (C) 2023 Andreas Kloeckner |
| 9 | +Copyright (C) 2022 Matthias Diener |
| 10 | +Copyright (C) 2022 Matt Smith |
| 11 | +""" |
| 12 | + |
| 13 | +__license__ = """ |
| 14 | +Permission is hereby granted, free of charge, to any person obtaining a copy |
| 15 | +of this software and associated documentation files (the "Software"), to deal |
| 16 | +in the Software without restriction, including without limitation the rights |
| 17 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 18 | +copies of the Software, and to permit persons to whom the Software is |
| 19 | +furnished to do so, subject to the following conditions: |
| 20 | +
|
| 21 | +The above copyright notice and this permission notice shall be included in |
| 22 | +all copies or substantial portions of the Software. |
| 23 | +
|
| 24 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 25 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 26 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 27 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 28 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 29 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 30 | +THE SOFTWARE. |
| 31 | +""" |
| 32 | + |
| 33 | +import sys |
| 34 | +from typing import TYPE_CHECKING |
| 35 | + |
| 36 | +import loopy as lp |
| 37 | + |
| 38 | +from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext |
| 39 | + |
| 40 | + |
| 41 | +if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False): |
| 42 | + import pytato |
| 43 | + |
| 44 | + |
class SplitPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
    """
    A :class:`PytatoPyOpenCLArrayContext` that overrides the array-expression
    graph and :mod:`loopy` program transformations.

    .. note::

        Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
        details on the transformation algorithm provided by this array context.

    .. warning::

        For expression graphs with a large number of nodes, high compile times
        are expected.
    """
    def transform_dag(self,
                      dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
        r"""
        Returns a transformed version of *dag*, where the applied transforms
        are, in order:

        #. Deduplicate the data wrappers in *dag*.
        #. Tag every input and output of a :class:`pytato.array.Einsum` with
           :class:`pytato.tags.ImplStored`, except for those that are
           instances of :class:`pytato.InputArgumentBase`, which are left
           untouched.
        #. Materialize as per the MPMS materialization heuristic.
        """
        import pytato as pt

        # Step 1. Collapse equivalent nodes in DAG.
        # -----------------------------------------
        # type-ignore-reason: mypy is right pytato provides imprecise types.
        dag = pt.transform.deduplicate_data_wrappers(dag)  # type: ignore[assignment]

        # Step 2. Materialize einsum inputs/outputs.
        # ------------------------------------------
        from .utils import get_inputs_and_outputs_of_einsum

        # A single set holding both the operands and the results of einsums.
        einsum_operands = frozenset.union(
            *get_inputs_and_outputs_of_einsum(dag))

        def tag_einsum_operand_as_stored(expr: pt.transform.ArrayOrNames
                                         ) -> pt.transform.ArrayOrNames:
            # Nodes unrelated to any einsum, as well as input arguments, are
            # passed through unchanged.
            if (expr not in einsum_operands
                    or isinstance(expr, pt.InputArgumentBase)):
                return expr

            return expr.tagged(pt.tags.ImplStored())

        # type-ignore-reason: mypy is right pytato provides imprecise types.
        dag = pt.transform.map_and_copy(dag,  # type: ignore[assignment]
                                        tag_einsum_operand_as_stored)

        # Step 3. MPMS materialize
        # ------------------------
        return pt.transform.materialize_with_mpms(dag)

    def transform_loopy_program(self,
                                t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        r"""
        Returns a transformed version of *t_unit*, where the applied transforms
        are, in order:

        #. An execution grid size :math:`G` is selected based on *self*'s
           OpenCL-device.
        #. The iteration domain for each statement in the *t_unit* is divided
           equally among the work-items in :math:`G`.
        #. Kernel boundaries are drawn between disjoint loop nests by inserting
           global barriers. One could relax this constraint by letting
           :mod:`loopy` compute where to insert the global barriers, but that
           is not guaranteed to be profitable for performance since we do not
           attempt any further loop-fusion and/or array contraction.
        #. Once the kernel boundaries are inferred,
           :func:`alias_global_temporaries` is invoked to reduce the peak
           memory used by the transformed program.
        """
        # Step 1. Split the iteration across work-items
        # ---------------------------------------------
        from .utils import split_iteration_domain_across_work_items
        t_unit = split_iteration_domain_across_work_items(
            t_unit, self.queue.device)

        # Step 2. Add a global barrier between individual loop nests.
        # -----------------------------------------------------------
        from .utils import add_gbarrier_between_disjoint_loop_nests
        t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)

        # Step 3. Alias global temporaries with disjoint live intervals
        # -------------------------------------------------------------
        from .utils import alias_global_temporaries
        return alias_global_temporaries(t_unit)
| 132 | + |
| 133 | +# vim: fdm=marker |
0 commit comments