148 changes: 148 additions & 0 deletions ingress/Torch-MLIR/convert-kernel-bench-to-mlir.py
@@ -0,0 +1,148 @@
#!/usr/bin/env python3

from pathlib import Path

from mlir import ir, passmanager
from lighthouse.ingress import torch as torch_ingress


kernels_as_pytorch_folder = Path(__file__).parent / "KernelBench" / "KernelBench"
Member commented:
since this depends on where the git was cloned in the bash script, perhaps that last step (clone) could be done in this script as well?

@rolfmorel (Contributor Author) replied on Oct 8, 2025:
I am not sure.

Doing a git clone in either script feels unclean. I also don't like the idea of making it a submodule, as that seems to imply you have to clone KernelBench to do anything useful with lighthouse. It seems to me KernelBench will be just one source of ingress compute graphs of interest, so it may make sense to let users/CI opt in to which paths they want to run tests with. What's the right mechanism for that? I am not sure.
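
One possible shape for such an opt-in (a sketch only; the environment variable, helper name, and repository URL below are assumptions, not part of this PR) is to gate a clone-on-demand step behind an env var, so users/CI choose which ingress sources to materialize:

# Hypothetical sketch: opt-in, clone-on-demand KernelBench checkout.
import os
import subprocess
from pathlib import Path

KERNELBENCH_URL = "https://github.com/ScalingIntelligence/KernelBench"  # assumed upstream

def ensure_kernelbench(root: Path) -> Path | None:
    """Return the KernelBench checkout under `root`, cloning it if the user opted in."""
    if os.environ.get("LIGHTHOUSE_ENABLE_KERNELBENCH") != "1":  # hypothetical opt-in knob
        return None  # not opted in: callers skip KernelBench-based tests
    checkout = root / "KernelBench"
    if not checkout.exists():
        subprocess.run(
            ["git", "clone", "--depth=1", KERNELBENCH_URL, str(checkout)],
            check=True,
        )
    return checkout

Whether an env var, a pytest marker, or a config file is the right knob is exactly the open question here.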

Member replied:
KernelBench is NOT an ingress. Torch-MLIR is.

We now have three PRs that work with the FX importer, none reusing the others' code. We should have one FX importer script that the others use.
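
For reference, a minimal sketch of what a single shared importer could look like (assuming torch_mlir's fx API is available; the helper name and its module location are hypothetical):

# Hypothetical shared helper, e.g. lighthouse/ingress/fx_import.py, that the
# KernelBench script and the example scripts could all reuse.
import torch
from torch_mlir import fx

def import_via_fx(
    model: torch.nn.Module,
    sample_args: tuple,
    output_type: str = "linalg-on-tensors",
):
    """Single FX-importer entry point: torch.export the model, then import to MLIR."""
    model.eval()  # export in inference mode
    return fx.export_and_import(model, *sample_args, output_type=output_type)

With something like this in tree, the script below would only carry the KernelBench-specific caching and ignore-list logic.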

kernels_as_pytorch_level1 = kernels_as_pytorch_folder / "level1"
kernels_as_pytorch_level2 = kernels_as_pytorch_folder / "level2"

kernels_as_mlir_folder = Path(__file__).parent / "cache"
kernels_as_mlir_level1 = kernels_as_mlir_folder / "level1"
kernels_as_mlir_level1.mkdir(parents=True, exist_ok=True)
kernels_as_mlir_level2 = kernels_as_mlir_folder / "level2"
kernels_as_mlir_level2.mkdir(parents=True, exist_ok=True)

level1, level2 = Path("level1"), Path("level2")
ignore_list = [
    level1 / "12_Matmul_with_diagonal_matrices_.py",  # torch.operator "torch.aten.diag"
    level1 / "34_InstanceNorm.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (93898875033000)
    level1 / "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1 / "89_cumsum.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1 / "90_cumprod.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1 / "91_cumsum_reverse.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1 / "92_cumsum_exclusive.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1 / "93_masked_cumsum.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
    level1 / "95_CrossEntropyLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1 / "96_HuberLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level1 / "97_ScaledDotProductAttention.py",  # AssertionError: Torch not compiled with CUDA enabled
    level1 / "99_TripletMarginLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
    level2 / "17_Conv2d_InstanceNorm_Divide.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
    level2 / "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
    level2 / "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "43_Conv3d_Max_LogSumExp_ReLU.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "45_Gemm_Sigmoid_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "52_Conv2d_Activation_BatchNorm.py",  # failed to legalize operation 'torch.operator'
    level2 / "55_Matmul_MaxPool_Sum_Scale.py",  # MLIR file too big: 16G
    level2 / "59_Matmul_Swish_Scaling.py",  # MLIR file too big: 16G
    level2 / "56_Matmul_Sigmoid_Sum.py",  # MLIR file too big: 16G
    level2 / "66_Matmul_Dropout_Softmax.py",  # MLIR file too big: 4G
    level2 / "68_Matmul_Min_Subtract.py",  # MLIR file too big: 4G
    level2 / "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm.py",  # MLIR file too big: 1G
    level2 / "33_Gemm_Scale_BatchNorm.py",  # MLIR file too big: 1G
    level2 / "88_Gemm_GroupNorm_Swish_Multiply_Swish.py",  # MLIR file too big: 1G
    level2 / "75_Gemm_GroupNorm_Min_BiasAdd.py",  # MLIR file too big: 1G
    level2 / "84_Gemm_BatchNorm_Scaling_Softmax.py",  # MLIR file too big: 1G
    level2 / "97_Matmul_BatchNorm_BiasAdd_Divide_Swish.py",  # MLIR file too big: 1G
    level2 / "62_Matmul_GroupNorm_LeakyReLU_Sum.py",  # MLIR file too big: 1G
    level2 / "30_Gemm_GroupNorm_Hardtanh.py",  # MLIR file too big: 1G
    level2 / "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py",  # MLIR file too big: 1G
    level2 / "29_Matmul_Mish_Mish.py",  # MLIR file too big: 1G
    level2 / "99_Matmul_GELU_Softmax.py",  # MLIR file too big: 1G
    level2 / "98_Matmul_AvgPool_GELU_Scale_Max.py",  # MLIR file too big: 1G
    level2 / "80_Gemm_Max_Subtract_GELU.py",  # MLIR file too big: 1G
    level2 / "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py",  # MLIR file too big: 1G
    level2 / "12_Gemm_Multiply_LeakyReLU.py",  # MLIR file too big: 1G
    level2 / "53_Gemm_Scaling_Hardtanh_GELU.py",  # MLIR file too big: 1G
    level2 / "9_Matmul_Subtract_Multiply_ReLU.py",  # MLIR file too big: 1G
    level2 / "70_Gemm_Sigmoid_Scaling_ResidualAdd.py",  # MLIR file too big: 1G
    level2 / "86_Matmul_Divide_GELU.py",  # MLIR file too big: 1G
    level2 / "63_Gemm_ReLU_Divide.py",  # MLIR file too big: 1G
    level2 / "76_Gemm_Add_ReLU.py",  # MLIR file too big: 1G
    level2 / "14_Gemm_Divide_Sum_Scaling.py",  # MLIR file too big: 1G
    level2 / "39_Gemm_Scale_BatchNorm.py",  # MLIR file too big: 256M
    level2 / "41_Gemm_BatchNorm_GELU_ReLU.py",  # MLIR file too big: 256M
    level2 / "40_Matmul_Scaling_ResidualAdd.py",  # MLIR file too big: 256M
    level2 / "37_Matmul_Swish_Sum_GroupNorm.py",  # MLIR file too big: 64.3M
    level2 / "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU.py",  # error: failed to legalize operation 'torch.constant.int'
    level2 / "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94312016449768)
    level2 / "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
]


ctx = ir.Context()
pm = passmanager.PassManager(context=ctx)
pm.add("linalg-specialize-generic-ops")

for pytorch_level, mlir_level in (
    (kernels_as_pytorch_level1, kernels_as_mlir_level1),
    (kernels_as_pytorch_level2, kernels_as_mlir_level2),
):
    for kernel_pytorch_file in pytorch_level.iterdir():
        level_and_kernel = (
            Path(kernel_pytorch_file.parent.name) / kernel_pytorch_file.name
        )
        if level_and_kernel in ignore_list or not kernel_pytorch_file.is_file():
            # Review comment (Member): print to stderr?
            print(
                f"Skipping: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
            )
            continue

        kernel_name = kernel_pytorch_file.stem

        kernel_as_mlir_path = mlir_level / (kernel_name + ".mlir")
        if kernel_as_mlir_path.exists():
            print(
                f"Already in cache: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
            )
            continue
        print(
            f"Processing: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
        )
        mlir_kernel = torch_ingress.import_from_file(
            kernel_pytorch_file, ir_context=ctx
        )

        # Keep the raw Torch-MLIR output around (as "//" MLIR comments) for reference.
        before_clean_up = "//" + str(mlir_kernel)[:-1].replace("\n", "\n//") + "\n"
        try:
            pm.run(mlir_kernel.operation)  # cleanup
        except Exception:
            print(f"Error: cleanup pipeline failed on {kernel_name}")
            raise

        with kernel_as_mlir_path.open("w") as f:
            print("// Torch-MLIR output:", file=f)
            print(before_clean_up, file=f)
            print("// MLIR output after clean-up:", file=f)
            print(mlir_kernel, file=f)
72 changes: 0 additions & 72 deletions ingress/Torch-MLIR/generate-mlir.py

This file was deleted.

42 changes: 0 additions & 42 deletions ingress/Torch-MLIR/generate-mlir.sh

This file was deleted.

63 changes: 63 additions & 0 deletions python/examples/ingress/torch/01-dummy-mlp-from-file.py
@@ -0,0 +1,63 @@
"""
Example demonstrating how to load a PyTorch model to MLIR using Lighthouse
without instantiating the model on the user's side.

The script uses the 'lighthouse.ingress.torch.import_from_file' function, which
takes a path to a Python file containing the model definition, along with the
names of the functions that provide the model's init arguments and sample inputs.
The function imports the model class on its own, instantiates it, and passes it
to torch_mlir to obtain an MLIR module in the specified dialect.

The script uses the model from 'DummyMLP/model.py' as an example.
"""

from pathlib import Path

# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
import mlir.dialects.func as func
from mlir import ir, passmanager

# Lighthouse imports
from lighthouse.ingress.torch import import_from_file

# Step 1: Set up paths to locate the model definition file
script_dir = Path(__file__).resolve().parent
model_path = script_dir / "DummyMLP" / "model.py"

ir_context = ir.Context()

# Step 2: Convert PyTorch model to MLIR
# Conversion step where Lighthouse:
# - Loads the DummyMLP class and instantiates it with arguments obtained from 'get_init_inputs()'
# - Calls get_sample_inputs() to get sample input tensors for shape inference
# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir
mlir_module_ir: ir.Module = import_from_file(
    model_path,  # Path to the Python file containing the model
    model_class_name="DummyMLP",  # Name of the PyTorch nn.Module class to convert
    init_args_fn_name="get_init_inputs",  # Function that returns args for model.__init__()
    sample_args_fn_name="get_sample_inputs",  # Function that returns sample inputs to pass to 'model(...)'
    dialect="linalg-on-tensors",  # Target MLIR dialect (linalg ops on tensor types)
    ir_context=ir_context,  # MLIR context for the conversion
)

# At this point the PyTorch model has been converted to MLIR. You can convert
# the MLIR module to text form (e.g. 'str(mlir_module_ir)') and save it to a file.
#
# The following optional MLIR-processing steps give you an idea of what else can
# be done with the MLIR module.
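# For instance, the textual IR could be persisted like this (the file name is
# illustrative only):
#
#     Path("dummy_mlp.linalg.mlir").write_text(str(mlir_module_ir))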

# Step 3: Extract the main function operation from the MLIR module and print its metadata
func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
print(f"entry-point name: {func_op.name}")
print(f"entry-point type: {func_op.type}")

# Step 4: Apply some MLIR passes using a PassManager
pm = passmanager.PassManager(context=ir_context)
pm.add("linalg-specialize-generic-ops")
pm.add("one-shot-bufferize")
pm.run(mlir_module_ir.operation)

# Step 5: Output the final MLIR
print("\n\nModule dump after running the pipeline:")
mlir_module_ir.dump()
56 changes: 56 additions & 0 deletions python/examples/ingress/torch/02-dummy-mlp-from-model.py
@@ -0,0 +1,56 @@
"""
Example demonstrating how to load an already instantiated PyTorch model
to MLIR using Lighthouse.

The script uses the 'lighthouse.ingress.torch.import_from_model' function, which
takes an already-instantiated PyTorch model along with its sample inputs.
The function passes the model to torch_mlir to obtain an MLIR module in the
specified dialect.

The script uses a model from 'DummyMLP/model.py' as an example.
"""

import torch

# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module)
import mlir.dialects.func as func
from mlir import ir, passmanager

# Lighthouse imports
from lighthouse.ingress.torch import import_from_model

# Import a sample model definition
from DummyMLP.model import DummyMLP

# Step 1: Instantiate a model and prepare sample input
model = DummyMLP()
sample_input = torch.randn(1, 10)

ir_context = ir.Context()
# Step 2: Convert the PyTorch model to MLIR
mlir_module_ir: ir.Module = import_from_model(
    model,
    sample_args=(sample_input,),
    ir_context=ir_context,
)

# At this point the PyTorch model has been converted to MLIR. You can convert
# the MLIR module to text form (e.g. 'str(mlir_module_ir)') and save it to a file.
#
# The following optional MLIR-processing steps give you an idea of what else can
# be done with the MLIR module.

# Step 3: Extract the main function operation from the MLIR module and print its metadata
func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0]
print(f"entry-point name: {func_op.name}")
print(f"entry-point type: {func_op.type}")

# Step 4: Apply some MLIR passes using a PassManager
pm = passmanager.PassManager(context=ir_context)
pm.add("linalg-specialize-generic-ops")
pm.add("one-shot-bufferize")
pm.run(mlir_module_ir.operation)

# Step 5: Output the final MLIR
print("\n\nModule dump after running the pipeline:")
mlir_module_ir.dump()