1- """ 
1+ from  __future__ import  annotations 
2+ 
3+ 
4+ __doc__  =  """ 
25.. currentmodule:: arraycontext 
36
47A :mod:`pytato`-based array context defers the evaluation of an array until its 
6265from  pytools .tag  import  Tag , ToTagSetConvertible , normalize_tags 
6366
6467from  arraycontext .container .traversal  import  rec_map_array_container , with_array_context 
65- from  arraycontext .context  import  Array , ArrayContext , ArrayOrContainer , ScalarLike 
68+ from  arraycontext .context  import  (
69+     Array ,
70+     ArrayContext ,
71+     ArrayOrContainer ,
72+     ScalarLike ,
73+     UntransformedCodeWarning ,
74+ )
6675from  arraycontext .metadata  import  NameHint 
6776
6877
6978if  TYPE_CHECKING :
79+     import  loopy  as  lp 
7080    import  pyopencl  as  cl 
7181    import  pytato 
7282
@@ -137,7 +147,6 @@ def __init__(
137147        """ 
138148        super ().__init__ ()
139149
140-         import  loopy  as  lp 
141150        import  pytato  as  pt 
142151        self ._freeze_prg_cache : Dict [pt .DictOfNamedArrays , lp .TranslationUnit ] =  {}
143152        self ._dag_transform_cache : Dict [
@@ -180,8 +189,8 @@ def empty_like(self, ary):
180189
181190    # {{{ compilation 
182191
183-     def  transform_dag (self , dag : " pytato.DictOfNamedArrays" 
184-                       ) ->  " pytato.DictOfNamedArrays" 
192+     def  transform_dag (self , dag : pytato .DictOfNamedArrays 
193+                       ) ->  pytato .DictOfNamedArrays :
185194        """ 
186195        Returns a transformed version of *dag*. Sub-classes are supposed to 
187196        override this method to implement context-specific transformations on 
@@ -194,10 +203,22 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays"
194203        """ 
195204        return  dag 
196205
197-     def  transform_loopy_program (self , t_unit ):
198-         raise  ValueError (
199-             f"{ type (self ).__name__ }  
200-             "Sub-classes are supposed to implement it." )
206+     def  transform_loopy_program (self , t_unit : lp .TranslationUnit ) ->  lp .TranslationUnit :
207+         from  warnings  import  warn 
208+         warn ("Using the base " 
209+                 f"{ type (self ).__name__ }  
210+                 "to transform a translation unit. " 
211+                 "This is a no-op and will result in unoptimized C code for" 
212+                 "the requested optimization, all in a single statement." 
213+                 "This will work, but is unlikely to be performatn." 
214+                 f"Instead, subclass { type (self ).__name__ }  
215+                 "the specific transform logic required to transform the program " 
216+                 "for your package or application. Check higher-level packages " 
217+                 "(e.g. meshmode), which may already have subclasses you may want " 
218+                 "to build on." ,
219+                 UntransformedCodeWarning , stacklevel = 2 )
220+ 
221+         return  t_unit 
201222
202223    @abc .abstractmethod  
203224    def  einsum (self , spec , * args , arg_names = None , tagged = ()):
@@ -250,7 +271,7 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
250271    .. automethod:: compile 
251272    """ 
252273    def  __init__ (
253-             self , queue : " cl.CommandQueue" allocator = None , * ,
274+             self , queue : cl .CommandQueue , allocator = None , * ,
254275            use_memory_pool : Optional [bool ] =  None ,
255276            compile_trace_callback : Optional [Callable [[Any , str , Any ], None ]] =  None ,
256277
@@ -642,8 +663,8 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
642663        from  .compile  import  LazilyPyOpenCLCompilingFunctionCaller 
643664        return  LazilyPyOpenCLCompilingFunctionCaller (self , f )
644665
645-     def  transform_dag (self , dag : " pytato.DictOfNamedArrays" 
646-                       ) ->  " pytato.DictOfNamedArrays" 
666+     def  transform_dag (self , dag : pytato .DictOfNamedArrays 
667+                       ) ->  pytato .DictOfNamedArrays :
647668        import  pytato  as  pt 
648669        dag  =  pt .transform .materialize_with_mpms (dag )
649670        return  dag 
0 commit comments