22# See https://llvm.org/LICENSE.txt for license information.
33# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44
5- from __future__ import annotations
6- from typing import Callable , Optional , Sequence
5+ from typing import Callable , Optional , Sequence , Union
76
87from .... import ir
9- from .... dialects import transform
10- from .... dialects . transform import structured
8+ from .. import AnyOpType , OperationType , NamedSequenceOp , YieldOp
9+ from .. import structured
1110
1211
1312class Handle (ir .Value ):
@@ -25,16 +24,16 @@ def __init__(
2524 self ,
2625 v : ir .Value ,
2726 * ,
28- parent : Optional [Handle ] = None ,
29- children : Optional [Sequence [Handle ]] = None ,
27+ parent : Optional [" Handle" ] = None ,
28+ children : Optional [Sequence [" Handle" ]] = None ,
3029 ):
3130 super ().__init__ (v )
3231 self .parent = parent
3332 self .children = children if children is not None else []
3433
3534
36- @ir .register_value_caster (transform . AnyOpType .get_static_typeid ())
37- @ir .register_value_caster (transform . OperationType .get_static_typeid ())
35+ @ir .register_value_caster (AnyOpType .get_static_typeid ())
36+ @ir .register_value_caster (OperationType .get_static_typeid ())
3837class OpHandle (Handle ):
3938 """
4039 Wrapper around a transform operation handle with methods to chain further
@@ -52,11 +51,13 @@ def __init__(
5251
5352 def match_ops (
5453 self ,
55- ops : str
56- | ir .OpView
57- | structured .MatchInterfaceEnum
58- | Sequence [str | ir .OpView ],
59- ) -> OpHandle :
54+ ops : Union [
55+ str ,
56+ ir .OpView ,
57+ structured .MatchInterfaceEnum ,
58+ Sequence [Union [str , ir .OpView ]],
59+ ],
60+ ) -> "OpHandle" :
6061 """
6162 Emits a `transform.structured.MatchOp`.
6263 Returns a handle to payload ops that match the given names, types, or
@@ -70,23 +71,23 @@ def match_ops(
7071 if isinstance (ops , str ):
7172 ops = structured .MatchInterfaceEnum [ops ]
7273 match_op = structured .MatchOp (
73- transform . AnyOpType .get (),
74+ AnyOpType .get (),
7475 self ,
7576 interface = ops ,
7677 )
7778
7879 # Handle op name(s), either given directly as string or given as op.
7980 else :
8081 if isinstance (ops , str ):
81- op_type = transform . OperationType .get (ops )
82+ op_type = OperationType .get (ops )
8283 op_names = [ops ]
8384 elif isinstance (ops , Sequence ):
84- op_type = transform . AnyOpType .get ()
85+ op_type = AnyOpType .get ()
8586 op_names = [
8687 op if isinstance (op , str ) else op .OPERATION_NAME for op in ops
8788 ]
8889 else :
89- op_type = transform . OperationType .get (ops .OPERATION_NAME )
90+ op_type = OperationType .get (ops .OPERATION_NAME )
9091 op_names = [ops .OPERATION_NAME ]
9192 match_op = structured .MatchOp .match_op_names (
9293 op_type ,
@@ -100,7 +101,7 @@ def match_ops(
100101
101102
102103def insert_transform_script (
103- block_or_insertion_point : ir .Block | ir .InsertionPoint ,
104+ block_or_insertion_point : Union [ ir .Block , ir .InsertionPoint ] ,
104105 script : Callable [[OpHandle ], None ],
105106 dump_script : bool = False ,
106107) -> None :
@@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle):
137138
138139 with context , ir .Location .unknown (context ):
139140 with insertion_point :
140- named_sequence_op = transform . NamedSequenceOp (
141- "__transform_main" , [transform . AnyOpType .get ()], []
141+ named_sequence_op = NamedSequenceOp (
142+ "__transform_main" , [AnyOpType .get ()], []
142143 )
143144 with ir .InsertionPoint (named_sequence_op .body ):
144145 script (named_sequence_op .bodyTarget )
145- transform . YieldOp ([])
146+ YieldOp ([])
146147
147148 if dump_script :
148149 print (named_sequence_op )
0 commit comments