From b774189888db0ab1ee1a18698400b8eca110d469 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 20 Dec 2023 14:58:07 -0600 Subject: [PATCH 1/3] [mlir][python] move transform extras to dialects --- mlir/python/CMakeLists.txt | 2 +- .../mlir/dialects/transform/__init__.py | 1 + .../transform/extras}/__init__.py | 22 +++++++++---------- 3 files changed, 13 insertions(+), 12 deletions(-) rename mlir/python/mlir/{extras/dialects/transform => dialects/transform/extras}/__init__.py (87%) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 41d91cf677833..55c5973e40e52 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -172,7 +172,7 @@ declare_mlir_python_sources( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" GEN_ENUM_BINDINGS SOURCES - extras/dialects/transform/__init__.py) + dialects/transform/extras/__init__.py) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 7ae4fefbac412..175634c7d458f 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -6,6 +6,7 @@ from .._transform_ops_gen import * from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * +from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType try: from ...ir import * diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py similarity index 87% rename from mlir/python/mlir/extras/dialects/transform/__init__.py rename to mlir/python/mlir/dialects/transform/extras/__init__.py index 9e313324318aa..8c69f12e54e36 100644 --- a/mlir/python/mlir/extras/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -6,8 +6,8 @@ from typing import Callable, Optional, Sequence from .... import ir -from ....dialects import transform -from ....dialects.transform import structured +from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp +from .. import structured class Handle(ir.Value): @@ -33,8 +33,8 @@ def __init__( self.children = children if children is not None else [] -@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) -@ir.register_value_caster(transform.OperationType.get_static_typeid()) +@ir.register_value_caster(AnyOpType.get_static_typeid()) +@ir.register_value_caster(OperationType.get_static_typeid()) class OpHandle(Handle): """ Wrapper around a transform operation handle with methods to chain further @@ -70,7 +70,7 @@ def match_ops( if isinstance(ops, str): ops = structured.MatchInterfaceEnum[ops] match_op = structured.MatchOp( - transform.AnyOpType.get(), + AnyOpType.get(), self, interface=ops, ) @@ -78,15 +78,15 @@ def match_ops( # Handle op name(s), either given directly as string or given as op. else: if isinstance(ops, str): - op_type = transform.OperationType.get(ops) + op_type = OperationType.get(ops) op_names = [ops] elif isinstance(ops, Sequence): - op_type = transform.AnyOpType.get() + op_type = AnyOpType.get() op_names = [ op if isinstance(op, str) else op.OPERATION_NAME for op in ops ] else: - op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_type = OperationType.get(ops.OPERATION_NAME) op_names = [ops.OPERATION_NAME] match_op = structured.MatchOp.match_op_names( op_type, @@ -137,12 +137,12 @@ def test_match_ops_single(module: OpHandle): with context, ir.Location.unknown(context): with insertion_point: - named_sequence_op = transform.NamedSequenceOp( - "__transform_main", [transform.AnyOpType.get()], [] + named_sequence_op = NamedSequenceOp( + "__transform_main", [AnyOpType.get()], [] ) with ir.InsertionPoint(named_sequence_op.body): script(named_sequence_op.bodyTarget) - transform.YieldOp([]) + YieldOp([]) if dump_script: print(named_sequence_op) From bbf7f34d7aa7e41d1b802269a3cb8b9cb67d4c62 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 20 Dec 2023 15:19:40 -0600 Subject: [PATCH 2/3] fix tests --- mlir/test/python/dialects/transform_extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/python/dialects/transform_extras.py b/mlir/test/python/dialects/transform_extras.py index dbfa8a2dc73c4..e7b43ea63c31c 100644 --- a/mlir/test/python/dialects/transform_extras.py +++ b/mlir/test/python/dialects/transform_extras.py @@ -4,7 +4,7 @@ from mlir import ir from mlir.dialects import scf from mlir.dialects.transform import structured -from mlir.extras.dialects.transform import OpHandle, insert_transform_script +from mlir.dialects.transform.extras import OpHandle, insert_transform_script def build_transform_script(script: Callable[[OpHandle], None]): From ba34bd2dbe21979d9009313c68bb230b9af99537 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 20 Dec 2023 15:22:49 -0600 Subject: [PATCH 3/3] replace pipes for type hints --- .../dialects/transform/extras/__init__.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index 8c69f12e54e36..c715dac1ef7eb 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -2,8 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from __future__ import annotations -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Union from .... import ir from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp @@ -25,8 +24,8 @@ def __init__( self, v: ir.Value, *, - parent: Optional[Handle] = None, - children: Optional[Sequence[Handle]] = None, + parent: Optional["Handle"] = None, + children: Optional[Sequence["Handle"]] = None, ): super().__init__(v) self.parent = parent @@ -52,11 +51,13 @@ def __init__( def match_ops( self, - ops: str - | ir.OpView - | structured.MatchInterfaceEnum - | Sequence[str | ir.OpView], - ) -> OpHandle: + ops: Union[ + str, + ir.OpView, + structured.MatchInterfaceEnum, + Sequence[Union[str, ir.OpView]], + ], + ) -> "OpHandle": """ Emits a `transform.structured.MatchOp`. Returns a handle to payload ops that match the given names, types, or @@ -100,7 +101,7 @@ def match_ops( def insert_transform_script( - block_or_insertion_point: ir.Block | ir.InsertionPoint, + block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], script: Callable[[OpHandle], None], dump_script: bool = False, ) -> None: