Skip to content

Commit 85e4285

Browse files
kaushikcfdinducer
authored andcommitted
pytato: avoid codegen when possible
1 parent 829f892 commit 85e4285

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,16 @@ def _record_leaf_ary_in_dict(
486486

487487
# }}}
488488

489+
def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
490+
key_str = "_ary" + _ary_container_key_stringifier(key)
491+
return key_to_frozen_subary[key_str]
492+
493+
if not key_to_pt_arrays:
494+
# all cl arrays => no need to perform any codegen
495+
return with_array_context(
496+
rec_keyed_map_array_container(_to_frozen, array),
497+
actx=None)
498+
489499
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
490500
key_to_pt_arrays)
491501
normalized_expr, bound_arguments = _normalize_pt_expr(
@@ -544,10 +554,6 @@ def _record_leaf_ary_in_dict(
544554
for k, v in out_dict.items()}
545555
}
546556

547-
def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
548-
key_str = "_ary" + _ary_container_key_stringifier(key)
549-
return key_to_frozen_subary[key_str]
550-
551557
return with_array_context(
552558
rec_keyed_map_array_container(_to_frozen, array),
553559
actx=None)
@@ -800,6 +806,16 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
800806

801807
# }}}
802808

809+
def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
810+
key_str = "_ary" + _ary_container_key_stringifier(key)
811+
return key_to_frozen_subary[key_str]
812+
813+
if not key_to_pt_arrays:
814+
# all cl arrays => no need to perform any codegen
815+
return with_array_context(
816+
rec_keyed_map_array_container(_to_frozen, array),
817+
actx=None)
818+
803819
pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays)
804820
transformed_dag = self.transform_dag(pt_dict_of_named_arrays)
805821
pt_prg = pt.generate_jax(transformed_dag, jit=True)
@@ -812,10 +828,6 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
812828
for k, v in out_dict.items()}
813829
}
814830

815-
def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray:
816-
key_str = "_ary" + _ary_container_key_stringifier(key)
817-
return key_to_frozen_subary[key_str]
818-
819831
return with_array_context(
820832
rec_keyed_map_array_container(_to_frozen, array),
821833
actx=None)

0 commit comments

Comments
 (0)