
Conversation

@dchigarev (Contributor) commented Sep 17, 2025

This PR adds two utility functions to Lighthouse's Python package for converting Torch models to an MLIR module using torch_mlir.

A user can import one of the functions (import_from_model or import_from_file) and get an mlir.ir.Module that they can run passes on or simply write out to a file.

Some use cases:

1. Import from an instance of a model:

```python
import torch
from lighthouse.ingress.torch import import_from_model
from mlir import ir

ctx = ir.Context()

module: ir.Module = import_from_model(
    torch_model_instance, sample_args=(torch.rand(1, 10),), ir_context=ctx
)

# can now run some passes on the module
```
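For instance, a minimal sketch of what "running some passes" could look like, using the PassManager API that appears later in this thread (the pass choice is illustrative):

```python
from mlir import passmanager

pm = passmanager.PassManager(context=ctx)
pm.add("linalg-specialize-generic-ops")
pm.run(module.operation)
```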

2. Import from a file where a torch model is defined:

Imagine we want to import a model from KernelBench. It ships models as Python files where models and their arguments are defined uniformly.

```python
from pathlib import Path

from lighthouse.ingress.torch import import_from_file
from mlir import ir

ctx = ir.Context()
kernel_bench_root = Path(...)

module: ir.Module = import_from_file(
    filepath=kernel_bench_root / "level1" / "10_3D_tensor_matrix_multiplication.py",
    ir_context=ctx,
)

# can now run some passes on the module
```

The utility functions use torch_mlir and mlir installed in the Python env following #6 (and not from Lighthouse's own bindings as suggested in #3).

@rolfmorel (Contributor) left a comment:

Thanks @dchigarev for making progress on PyTorch ingress!

In general it all seems to make sense to me! My comments are on relatively small matters.

Having said that, I would say the majority of the PR is about enabling the cmdline interface, which I expect to also be the most contentious part. Personally, I am not a fan of such interfaces and prefer the scripting approach. If other people are in favour though, I am not opposed to the code being included.

Do you happen to have examples of similar cmdline interfaces being used for enabling PyTorch lowerings in other projects?

@dchigarev (Contributor, Author) replied:

> Do you happen to have examples of similar cmdline interfaces being used for enabling PyTorch lowerings in other projects?

@rolfmorel thanks for your time and feedback!

No, I haven't seen such a cmdline approach anywhere (I wasn't looking too deep, though). From a skim of IREE's and Blade's documentation I could only find the user-script approach. So even if they have a cmdline option, they don't seem to promote it very prominently.

@banach-space left a comment:

This is great, thank you so much for working on this 🙏🏻

I have a few high-level suggestions.

**Keep this PR simple and restrict it to the required minimum.**

The cmdline interface looks complex and is merely a "wrapper" around the script logic. We can't avoid having a script, but we can avoid the cmdline interface. And with a complex cmdline interface like this, I would end up wrapping it in yet another script anyway. My suggestion - drop the interface for now. This will allow us to focus on the core logic instead.

**Consistent filenames and hyphenation.**

generate-mlir.py vs py_src vs dummy_mlp_factory.py vs export_bash.py. LLVM seems to prefer - over _. Whichever one we choose, let's use it consistently.

**Use docstrings consistently.**

Let's use (function + module) docstrings consistently (instead of mixing docstrings and plain Python comments starting with #).

**Do we need all the Bash scripts?**

There seems to be a fair bit of duplication, e.g. export_bash.sh vs export_py.sh vs export.py. It's not clear to me what all these scripts do and whether we need them. My suggestion - less is more.

**Naming.**

This PR renames torch-mlir/generate-mlir.py -> torch-mlir/py_src/main.py.

IIUC, generate-mlir.py was misleading - no MLIR is generated. Instead, the script "exports" MLIR, right? To me, a generator would be something like https://github.com/libxsmm/tpp-mlir/blob/main/tools/mlir-gen/mlir-gen.cpp.

While main.py is an improvement (i.e. not misleading), it's a bit too enigmatic - why not export.py? Or export-mlir-from-pytorch.py? Basically, something descriptive. That said, naming is hard 🤷🏻

**Final thoughts.**

Really fantastic to see this; I'm just a bit concerned that this PR is trying to achieve too many things in one go. I recommend trimming it - I'd much rather focus on the core part and also make sure that we establish a consistent way of naming, structuring, and implementing things.

I've some other, more specific comments inline.

Thanks again for working on this! 🙏🏻



```python
def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"):
    # Convert the Torch model to MLIR
```


Could we use docstrings consistently throughout this project?
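For illustration, a hedged sketch of what a consistent docstring for this function might look like (the wording is assumed, not taken from the PR):

```python
def generate_mlir(model, sample_args, sample_kwargs=None, dialect="linalg"):
    """Convert a Torch model to MLIR.

    Illustrative parameter descriptions (assumed): `model` is the PyTorch
    model to convert, `sample_args`/`sample_kwargs` are example inputs used
    for export, and `dialect` selects the output dialect (e.g. "linalg").
    """
```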

```
@@ -0,0 +1,16 @@
#!/usr/bin/env bash
```


DOCUMENTME - what is the purpose of this script and how do I use it?

@Groverkss (Member) commented:

Building wrapper scripts around torch-mlir is not scalable at all. torch-mlir is a library to build things with, not a tool to build scripts around. The proper way of doing this is shipping fx_importer as part of the bindings (#3, ready for review) and then building export over it and shipping it as part of the Python package. I'm going to send a PR today on building an AOT export for torch and onnx around that, to give an idea of how it should be done.
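For context, a rough sketch of that direction, assuming torch-mlir's fx-importer API (the import path and signature may differ across torch-mlir versions; the model here is hypothetical):

```python
import torch
from torch_mlir import fx  # assumed location of the fx importer

class TinyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def forward(self, x):
        return x * 2

# export via torch.export and import the result into MLIR
mlir_module = fx.export_and_import(TinyModel(), torch.rand(1, 10))
print(mlir_module)
```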

@rolfmorel (Contributor) commented Oct 8, 2025

Given the distinct though related ingress PRs currently up, I thought that delineating their separate purposes might be helpful:

  • Add python bindings for lighthouse #3
    • -> Set up a build system / initialization approach for making both mlir and torch-mlir available to be used from Python, so as to enable ingress and lowering
    • IMO, this would initially just enable ingress and lowering as a bunch of Python programs, i.e. all we need is that the Python packages are available when the ingress and compiler scripts get called.
  • [ingress][torch-mlir][RFC] Initial version of fx-importer script using torch-mlir #4
    • -> Develop "the one" conversion script for going from PyTorch to mlir.
    • IMO, as the cmdline interface is more controversial, this is probably best factored out to a separate PR so that the Python script-based approach can be merged.
  • [ingress][pytorch] Basic KernelBench to MLIR conversion #5
    • -> One source of PyTorch kernels that we want to use to demonstrate working pipelines on.
    • Currently it invokes the couple of lines of Python necessary to get torch-mlir to do the conversion, though it should use the general importer/conversion mechanism that #4 will provide and rely on #3 to make sure the dependencies are taken care of (excepting KernelBench, which is a distinct kind of dependency, in need of a general mechanism for how to deal with such dependencies).

Regarding the interaction between the importer script and the scripts that deal with input sources: My feeling is that providing a little python package that can be used as a utility by the separate input-processing scripts might be most helpful.

  • For example, from lighthouse.ingress import pytorch_module_converter could be used in, e.g., #5 to replace the conversion code in that script, e.g. if its signature was pytorch_module_converter(python_module: str | module) -> mlir.ir.Module (where the str type is for paths to scripts with a certain interface and the (Python) module type is for already-imported modules with that same interface). The "interface" of the python_module could be like that of KernelBench, i.e. the module should have attributes Model: torch.nn.Module and gen_inputs: () -> List[torch.Tensor] and gen_init_inputs: () -> List[Any] (i.e. the "standard" arguments for Model.__init__); see the sketch below.
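A hedged sketch of an input module matching that interface (pytorch_module_converter is the proposal above, not an existing API; everything else is illustrative):

```python
# my_kernel.py -- hypothetical input module with the KernelBench-like interface
from typing import Any, List

import torch

class Model(torch.nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.linear = torch.nn.Linear(10, hidden)

    def forward(self, x):
        return self.linear(x)

def gen_inputs() -> List[torch.Tensor]:
    return [torch.rand(1, 10)]  # sample arguments for Model.forward

def gen_init_inputs() -> List[Any]:
    return [32]  # the "standard" arguments for Model.__init__
```

A consumer could then call the proposed helper as pytorch_module_converter("my_kernel.py") (or pass the already-imported module object) to obtain an mlir.ir.Module.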

As the PyTorch and torch-mlir libs need to live in the same process anyway, I do not see much benefit coming from trying to separate out the importer/converter code to a script that actually runs in a separate process from the code that deals with the input sources.

@dchigarev force-pushed the dchigarev/fx_importer branch from a08bcc5 to b984314 on October 20, 2025 10:05
@dchigarev (Contributor, Author) commented Oct 20, 2025

@rolfmorel

I've updated the PR following your suggestions. Lighthouse now has a python/lighthouse/* folder that is meant to be an installable Python package (installable after #7), providing utility ingress functions (e.g. import_from_file/import_from_model).

P.S. The conversion still uses torch_mlir from the user env (as inspired by #6).

@dchigarev marked this pull request as ready for review on October 20, 2025 10:34
@rolfmorel (Contributor) left a comment:

Many thanks for the revision (and the commitment to the PR)! It is looking good!

Have left a number of minor comments. I will soon try to rebase #5 on this branch and confirm that that works as expected. Happy to approve once both are sorted 👍

Comment on lines 51 to 53:

```
    ir_context : ir.Context, optional
        An optional MLIR context to use for parsing the module.
        If not provided, the module is returned as a string.
```
@rolfmorel (Contributor) commented Oct 20, 2025:

Just to leave a note: this is somewhat surprising to me, though as I don't currently have a better suggestion to expose the same functionality (i.e. returning before conversion to environment's mlir) I am okay with it.
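To make that behavior concrete, a hedged sketch of the two call modes the docstring describes (shown with import_from_file; the model path is hypothetical):

```python
from pathlib import Path

from lighthouse.ingress.torch import import_from_file
from mlir import ir

model_path = Path("model.py")  # hypothetical
ctx = ir.Context()

mlir_text = import_from_file(filepath=model_path)                    # no ir_context: returns str
mlir_module = import_from_file(filepath=model_path, ir_context=ctx)  # with ir_context: returns ir.Module
```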

rolfmorel pushed a commit that referenced this pull request Oct 21, 2025
The PR modifies `pyproject.toml` to make the content of `python/lighthouse` installable via `uv pip install .`. For now the package is empty, but after #4 is merged, users will be able to access ingress helper functions as part of the package:
```python
from lighthouse.ingress.torch import import_from_file

...
```

* install lighthouse on 'uv sync'
* use dynamic version in pyproject.toml
```python
inputs_args_fn = getattr(module, inputs_args_fn_name, None)
if inputs_args_fn is None:
    raise ValueError(f"Inputs args function '{inputs_args_fn_name}' not found in {filepath}")
model_init_args = maybe_load_and_run_callable(
```

I believe the following works and would mean a bit less abstraction:

```python
try:
    model_init_args = init_args_fn_name and getattr(module, init_args_fn_name)() or tuple()
except AttributeError:
    raise ValueError(f"Init args function '{init_args_fn_name}' not found in {filepath}")
```

I know not everyone will find using the boolean operators this way intuitive though (and technically it's not completely right, in the sense that falsy return values will get replaced by tuple()). Nonetheless, thought to suggest it.
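A self-contained demo of that falsy-value caveat (names hypothetical):

```python
def gen_init_args():
    return []  # a legitimate but falsy return value

# the and/or chain silently replaces the falsy [] with tuple():
init_args = "gen_init_args" and gen_init_args() or tuple()
print(init_args)  # prints (), not []
```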

@rolfmorel (Contributor) left a comment:

Thanks @dchigarev ! That's it for my comments. IMO this is looking great.

Am approving now, though with the caveat that I will still rebase #5 on this branch before we merge this PR (to check that the API works there as expected). Will report back on that by EOD tomorrow at the latest.

@banach-space, would you like to give this a final pass before it goes in?

@rolfmorel (Contributor) commented:

Can confirm that this is working as expected for #5: https://github.com/llvm/lighthouse/pull/5/files/0912de15b458a78a88f111045fe0e54618ae83a1..63b82406443cdd449bbb1100a1639853b7417160

Barring one or two outstanding comments, this is good to go IMO 👍

@banach-space left a comment:

Thanks for the updates and for pushing on this 🙏🏻

Approving as is - looks very clean and clear. I have some suggestions for minor improvements, but this is already great, so feel free to ignore.

If folks agree with my suggestions but have no bandwidth for PRs, I can upload something myself. Thanks!

```python
# Step 4: Apply some MLIR passes using a PassManager
pm = passmanager.PassManager(context=ir_context)
pm.add("linalg-specialize-generic-ops")
pm.add("one-shot-bufferize")
```


Why do we bufferize? Bufferization is quite an involved transformation and IMHO, we should only do the bare minimum here. Specifically, these are two orthogonal things to me:

  • importing a PyTorch model into MLIR,
  • running transformations on MLIR.

WDYT? Thinking in terms of "separation of concerns".

@adam-smnk commented Oct 22, 2025:

+1 to removing bufferization specifically. It can have many flavors and generally we want to stay at tensor longer.

OTOH, it's just an arbitrary example, so it's fine. Alternatively, an extra comment spelling out the message or motivation here could help clarify the intent.

@dchigarev (Contributor, Author) commented Oct 22, 2025:

Right, I also think that applying bufferization in an ingress example could be too much :)

But do you think we should remove the PassManager case from the ingress examples completely? I also believe that ingress and running a pipeline are two separate things; we could leave users a hint, though, in the form of an in-code comment on what to do next with an imported MLIR module, e.g.

```python
....
# Step 4: output the imported MLIR module
print("\n\nModule dump after running the pipeline:")
mlir_module_ir.dump()

# You can alternatively write the MLIR module to a file:
# with open("output.mlir", "w") as f:
#     f.write(str(mlir_module_ir))
#
# Or apply some MLIR passes using a PassManager:
# pm = passmanager.PassManager(context=ir_context)
# pm.add("linalg-specialize-generic-ops")
# pm.add(...)
# pm.run(mlir_module_ir.operation)
```


It'd make sense to completely skip these parts and focus only on ingress.
We can make other examples that focus on lowering later.

@dchigarev (Contributor, Author) commented Oct 22, 2025:

I agree, let's remove it completely

Comment on lines 2 to 3:

```
Example demonstrating how to load an already instantiated PyTorch model
to MLIR using Lighthouse.
```


The difference between this file and 01-dummy-mlir-from-model.py is "instantiation" (here the model is "already instantiated"). But what does "model instantiation" mean? Genuine question - I think that it would be good to capture this somewhere :)

@dchigarev (Contributor, Author) replied:

by "instantiation" I meant an instantiation (creation) of the model's class :)

changed to "model initialization", means the same but sound simplier


I think that "model" is just a bit too overloaded term and that's a potential source of confusion (I remember getting quite confused first time I played with PyTorch).

[nit] Could you specify that you mean the PyTorch "model" (class)?

@dchigarev (Contributor, Author) replied:

Tried to make the docstring clearer on what a PyTorch model means.

@adam-smnk left a comment:

Great work 😎

print(f"entry-point name: {func_op.name}")
print(f"entry-point type: {func_op.type}")

# Step 4: Apply some MLIR passes using a PassManager


Going to category could be a more useful or more broadly applicable default but TBD

@dchigarev requested a review from adam-smnk on October 22, 2025 09:05