@@ -486,6 +486,16 @@ def _record_leaf_ary_in_dict(
486486
487487 # }}}
488488
489+ def _to_frozen (key : Tuple [Any , ...], ary ) -> TaggableCLArray :
490+ key_str = "_ary" + _ary_container_key_stringifier (key )
491+ return key_to_frozen_subary [key_str ]
492+
493+ if not key_to_pt_arrays :
494+ # all cl arrays => no need to perform any codegen
495+ return with_array_context (
496+ rec_keyed_map_array_container (_to_frozen , array ),
497+ actx = None )
498+
489499 pt_dict_of_named_arrays = pt .make_dict_of_named_arrays (
490500 key_to_pt_arrays )
491501 normalized_expr , bound_arguments = _normalize_pt_expr (
@@ -544,10 +554,6 @@ def _record_leaf_ary_in_dict(
544554 for k , v in out_dict .items ()}
545555 }
546556
547- def _to_frozen (key : Tuple [Any , ...], ary ) -> TaggableCLArray :
548- key_str = "_ary" + _ary_container_key_stringifier (key )
549- return key_to_frozen_subary [key_str ]
550-
551557 return with_array_context (
552558 rec_keyed_map_array_container (_to_frozen , array ),
553559 actx = None )
@@ -800,6 +806,16 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
800806
801807 # }}}
802808
809+ def _to_frozen (key : Tuple [Any , ...], ary ) -> jnp .ndarray :
810+ key_str = "_ary" + _ary_container_key_stringifier (key )
811+ return key_to_frozen_subary [key_str ]
812+
813+ if not key_to_pt_arrays :
814+ # all cl arrays => no need to perform any codegen
815+ return with_array_context (
816+ rec_keyed_map_array_container (_to_frozen , array ),
817+ actx = None )
818+
803819 pt_dict_of_named_arrays = pt .make_dict_of_named_arrays (key_to_pt_arrays )
804820 transformed_dag = self .transform_dag (pt_dict_of_named_arrays )
805821 pt_prg = pt .generate_jax (transformed_dag , jit = True )
@@ -812,10 +828,6 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...],
812828 for k , v in out_dict .items ()}
813829 }
814830
815- def _to_frozen (key : Tuple [Any , ...], ary ) -> jnp .ndarray :
816- key_str = "_ary" + _ary_container_key_stringifier (key )
817- return key_to_frozen_subary [key_str ]
818-
819831 return with_array_context (
820832 rec_keyed_map_array_container (_to_frozen , array ),
821833 actx = None )
0 commit comments