diff --git a/ingress/Torch-MLIR/generate-mlir.py b/ingress/Torch-MLIR/generate-mlir.py deleted file mode 100644 index 888e6dd..0000000 --- a/ingress/Torch-MLIR/generate-mlir.py +++ /dev/null @@ -1,72 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import os -import torch -import torch.nn as nn -from torch_mlir import fx -from torch_mlir.fx import OutputType - -# Parse arguments for selecting which model to load and which MLIR dialect to generate -def parse_args(): - parser = argparse.ArgumentParser(description="Generate MLIR for Torch-MLIR models.") - parser.add_argument( - "--model", - type=str, - required=True, - help="Path to the Torch model file.", - ) - parser.add_argument( - "--dialect", - type=str, - choices=["torch", "linalg", "stablehlo", "tosa"], - default="linalg", - help="MLIR dialect to generate.", - ) - return parser.parse_args() - -# Functin to load the Torch model -def load_torch_model(model_path): - - if not os.path.exists(model_path): - raise FileNotFoundError(f"Model file {model_path} does not exist.") - - model = torch.load(model_path) - return model - -# Function to generate MLIR from the Torch model -# See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237 -def generate_mlir(model, dialect): - - # Convert the Torch model to MLIR - output_type = None - if dialect == "torch": - output_type = OutputType.TORCH - elif dialect == "linalg": - output_type = OutputType.LINALG - elif dialect == "stablehlo": - output_type = OutputType.STABLEHLO - elif dialect == "tosa": - output_type = OutputType.TOSA - else: - raise ValueError(f"Unsupported dialect: {dialect}") - - module = fx.export_and_import(model, "", output_type=output_type) - return module - -# Main function to execute the script -def main(): - args = parse_args() - - # Load the Torch model - model = load_torch_model(args.model) - - # Generate MLIR from the model - mlir_module = generate_mlir(model, args.dialect) - - # Print or save the MLIR module - 
print(mlir_module) - -# Entry point for the script -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/ingress/Torch-MLIR/generate-mlir.sh b/ingress/Torch-MLIR/generate-mlir.sh deleted file mode 100755 index 0a079c6..0000000 --- a/ingress/Torch-MLIR/generate-mlir.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/usr/bin/env bash - -# Command line argument for model to load and MLIR dialect to generate -while getopts "m:d:" opt; do - case $opt in - m) - MODEL=$OPTARG - ;; - d) - DIALECT=$OPTARG - ;; - *) - echo "Usage: $0 [-m model] [-d dialect]" - exit 1 - ;; - esac -done -if [ -z "$MODEL" ]; then - echo "Model not specified. Please provide a model using -m option." - exit 1 -fi -if [ -z "$DIALECT" ]; then - DIALECT="linalg" -fi - -# Enable local virtualenv created by install-virtualenv.sh -if [ ! -d "torch-mlir-venv" ]; then - echo "Virtual environment not found. Please run install-virtualenv.sh first." - exit 1 -fi -source torch-mlir-venv/bin/activate - -# Find script directory -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") - -# Use the Python script to generate MLIR -echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..." -python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT" -if [ $? -ne 0 ]; then - echo "Failed to generate MLIR for model '$MODEL'." 
- exit 1 -fi diff --git a/python/examples/ingress/torch/MLPModel/model.py b/python/examples/ingress/torch/MLPModel/model.py new file mode 100644 index 0000000..5bbdbbe --- /dev/null +++ b/python/examples/ingress/torch/MLPModel/model.py @@ -0,0 +1,28 @@ +"""Defines a simple PyTorch model to be used in lighthouse's ingress examples.""" + +import torch +import torch.nn as nn + +import os + +class MLPModel(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Sequential( + nn.Linear(10, 32), + nn.ReLU(), + nn.Linear(32, 2) + ) + + def forward(self, x): + return self.net(x) + + +def get_init_inputs(): + """Function to return args to pass to MLPModel.__init__()""" + return () + + +def get_sample_inputs(): + """Arguments to pass to MLPModel.forward()""" + return (torch.randn(1, 10),) diff --git a/python/examples/ingress/torch/mlp_from_file.py b/python/examples/ingress/torch/mlp_from_file.py new file mode 100644 index 0000000..01edeec --- /dev/null +++ b/python/examples/ingress/torch/mlp_from_file.py @@ -0,0 +1,61 @@ +""" +Example demonstrating how to load a PyTorch model to MLIR using Lighthouse +without initializing the model class on the user's side. + +The script uses 'lighthouse.ingress.torch.import_from_file' function that +takes a path to a Python file containing the model definition (a Python class derived from 'nn.Module'), +along with the names of functions to get model init arguments and sample inputs. The function +imports the model class on its own, initializes it, and passes it to torch_mlir +to get a MLIR module in the specified dialect. + +The script uses the model from 'MLPModel/model.py' as an example. 
+""" + +import os +from pathlib import Path + +# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) +import mlir.dialects.func as func +from mlir import ir + +# Lighthouse imports +from lighthouse.ingress.torch import import_from_file + +# Step 1: Set up paths to locate the model definition file +script_dir = Path(os.path.dirname(os.path.abspath(__file__))) +model_path = script_dir / "MLPModel" / "model.py" + +ir_context = ir.Context() + +# Step 2: Convert PyTorch model to MLIR +# Conversion step where Lighthouse: +# - Loads the MLPModel class and instantiates it with arguments obtained from 'get_init_inputs()' +# - Calls get_sample_inputs() to get sample input tensors for shape inference +# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir +mlir_module_ir: ir.Module = import_from_file( + model_path, # Path to the Python file containing the model + model_class_name="MLPModel", # Name of the PyTorch nn.Module class to convert + init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__() + sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)' + dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types) + ir_context=ir_context # MLIR context for the conversion +) + +# The PyTorch model is now converted to MLIR at this point. You can now convert +# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. +# +# The following optional MLIR-processing steps are to give you an idea of what can +# also be done with the MLIR module. 
+ +# Step 3: Extract the main function operation from the MLIR module and print its metadata +func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] +print(f"entry-point name: {func_op.name}") +print(f"entry-point type: {func_op.type}") + +# Step 4: output the imported MLIR module +print("\n\nModule dump:") +mlir_module_ir.dump() + +# You can alternatively write the MLIR module to a file: +# with open("output.mlir", "w") as f: +# f.write(str(mlir_module_ir)) diff --git a/python/examples/ingress/torch/mlp_from_model.py b/python/examples/ingress/torch/mlp_from_model.py new file mode 100644 index 0000000..b51af80 --- /dev/null +++ b/python/examples/ingress/torch/mlp_from_model.py @@ -0,0 +1,54 @@ +""" +Example demonstrating how to load an already initialized PyTorch model +to MLIR using Lighthouse. + +The script uses the 'lighthouse.ingress.torch.import_from_model' function that +takes an initialized PyTorch model (an instance of a Python class derived from 'nn.Module'), +along with its sample inputs. The function passes the model to torch_mlir +to get a MLIR module in the specified dialect. + +The script uses a model from 'MLPModel/model.py' as an example. +""" + +import torch + +# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) +import mlir.dialects.func as func +from mlir import ir + +# Lighthouse imports +from lighthouse.ingress.torch import import_from_model + +# Import a sample model definition +from MLPModel.model import MLPModel + +# Step 1: Instantiate a model class and prepare sample input +model = MLPModel() +sample_input = torch.randn(1, 10) + +ir_context = ir.Context() +# Step 2: Convert the PyTorch model to MLIR +mlir_module_ir: ir.Module = import_from_model( + model, + sample_args=(sample_input,), + ir_context=ir_context +) + +# The PyTorch model is now converted to MLIR at this point. You can now convert +# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. 
+# +# The following optional MLIR-processing steps are to give you an idea of what can +# also be done with the MLIR module. + +# Step 3: Extract the main function operation from the MLIR module and print its metadata +func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] +print(f"entry-point name: {func_op.name}") +print(f"entry-point type: {func_op.type}") + +# Step 4: output the imported MLIR module +print("\n\nModule dump:") +mlir_module_ir.dump() + +# You can alternatively write the MLIR module to a file: +# with open("output.mlir", "w") as f: +# f.write(str(mlir_module_ir)) diff --git a/python/lighthouse/ingress/README.md b/python/lighthouse/ingress/README.md new file mode 100644 index 0000000..61236ef --- /dev/null +++ b/python/lighthouse/ingress/README.md @@ -0,0 +1,10 @@ +# Lighthouse Ingress + +The `lighthouse.ingress` module converts various input formats to MLIR modules. + +## Supported Formats + +#### Torch +Converts PyTorch models to MLIR using `lighthouse.ingress.torch`. 
+ +**Examples:** [torch examples](https://github.com/llvm/lighthouse/tree/main/python/examples/ingress/torch) diff --git a/python/lighthouse/ingress/__init__.py b/python/lighthouse/ingress/__init__.py new file mode 100644 index 0000000..2aa3859 --- /dev/null +++ b/python/lighthouse/ingress/__init__.py @@ -0,0 +1 @@ +"""Provides functions to convert source objects (code, models, designs) into MLIR files that the MLIR project can consume""" diff --git a/python/lighthouse/ingress/torch/__init__.py b/python/lighthouse/ingress/torch/__init__.py new file mode 100644 index 0000000..d73f426 --- /dev/null +++ b/python/lighthouse/ingress/torch/__init__.py @@ -0,0 +1,3 @@ +"""Provides functions to convert PyTorch models to MLIR.""" + +from .importer import import_from_file, import_from_model diff --git a/python/lighthouse/ingress/torch/importer.py b/python/lighthouse/ingress/torch/importer.py new file mode 100644 index 0000000..87c7655 --- /dev/null +++ b/python/lighthouse/ingress/torch/importer.py @@ -0,0 +1,226 @@ +import importlib +import importlib.util +from pathlib import Path +from typing import Iterable, Mapping + +from lighthouse.ingress.torch.utils import load_and_run_callable, maybe_load_and_run_callable + +try: + import torch + import torch.nn as nn +except ImportError as e: + raise ImportError( + "PyTorch is required to use the torch import functionality. " + "Make sure to install ingress-torch dependencies e.g. 'uv sync --extra ingress-torch-cpu'" + ) from e + +try: + from torch_mlir import fx + from torch_mlir.fx import OutputType +except ImportError as e: + raise ImportError( + "torch-mlir is required to use the torch import functionality. " + "Make sure to install ingress-torch dependencies e.g. 
'uv sync --extra ingress-torch-cpu'" + ) from e + +from mlir import ir + +def import_from_model( + model: nn.Module, + sample_args: Iterable, + sample_kwargs: Mapping = None, + dialect: OutputType | str = OutputType.LINALG_ON_TENSORS, + ir_context: ir.Context | None = None, + **kwargs, +) -> str | ir.Module: + """Import a PyTorch nn.Module into MLIR. + + The function uses torch-mlir's FX importer to convert the given PyTorch model + into an MLIR module in the specified dialect. The user has to provide sample + input arguments (e.g. a torch.Tensor with the correct shape). + + Args: + model (nn.Module): The PyTorch model to import. + sample_args (Iterable): Sample input arguments to the model. + sample_kwargs (Mapping, optional): Sample keyword arguments to the model. + dialect (torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}): + The target dialect for the imported MLIR module. Defaults to + ``OutputType.LINALG_ON_TENSORS``. + ir_context (ir.Context, optional): An optional MLIR context to use for parsing + the module. If not provided, the module is returned as a string. + **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function. + + Returns: + str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided. + + Examples: + >>> import torch + >>> import torch.nn as nn + >>> from lighthouse.ingress.torch import import_from_model + >>> class SimpleModel(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.fc = nn.Linear(10, 5) + ... def forward(self, x): + ... 
return self.fc(x) + >>> model = SimpleModel() + >>> sample_input = (torch.randn(1, 10),) + >>> # + >>> # option 1: get MLIR module as a string + >>> mlir_module : str = import_from_model(model, sample_input, dialect="linalg-on-tensors") + >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect + >>> # option 2: get MLIR module as an ir.Module + >>> ir_context = ir.Context() + >>> mlir_module_ir : ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context) + """ + if dialect == "linalg": + raise ValueError( + "Dialect 'linalg' is not supported. Did you mean 'linalg-on-tensors'?" + ) + + if sample_kwargs is None: + sample_kwargs = {} + + model.eval() + module = fx.export_and_import( + model, *sample_args, output_type=dialect, **sample_kwargs, **kwargs + ) + + text_module = str(module) + if ir_context is None: + return text_module + # Cross boundary from torch-mlir's mlir to environment's mlir + return ir.Module.parse(text_module, context=ir_context) + + +def import_from_file( + filepath: str | Path, + model_class_name: str = "Model", + init_args_fn_name: str | None = "get_init_inputs", + init_kwargs_fn_name: str | None = None, + sample_args_fn_name: str = "get_inputs", + sample_kwargs_fn_name: str | None = None, + state_path: str | Path | None = None, + dialect: OutputType | str = OutputType.LINALG_ON_TENSORS, + ir_context: ir.Context | None = None, + **kwargs, +) -> str | ir.Module: + """Load a PyTorch nn.Module from a file and import it into MLIR. + + The function takes a `filepath` to a Python file containing the model definition, + along with the names of functions to get model init arguments and sample inputs. + The function imports the model class on its own, instantiates it, and passes + it to ``torch_mlir`` to get a MLIR module in the specified `dialect`. + + Args: + filepath (str | Path): Path to the Python file containing the model definition. 
+ model_class_name (str, optional): The name of the model class in the file. + Defaults to "Model". + init_args_fn_name (str | None, optional): The name of the function in the file + that returns the arguments for initializing the model. If None, the model + is initialized without arguments. Defaults to "get_init_inputs". + init_kwargs_fn_name (str | None, optional): The name of the function in the file + that returns the keyword arguments for initializing the model. If None, the model + is initialized without keyword arguments. Defaults to None. + sample_args_fn_name (str, optional): The name of the function in the file that + returns the sample input arguments for the model. Defaults to "get_inputs". + sample_kwargs_fn_name (str | None, optional): The name of the function in the file that + returns the sample keyword input arguments for the model. Defaults to None. + state_path (str | Path | None, optional): Optional path to a file containing + the model's ``state_dict``. Defaults to None. + dialect (torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, optional): + The target dialect for the imported MLIR module. Defaults to + ``OutputType.LINALG_ON_TENSORS``. + ir_context (ir.Context, optional): An optional MLIR context to use for parsing + the module. If not provided, the module is returned as a string. + **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function. + + Returns: + str | ir.Module: The imported MLIR module as a string or an ir.Module if `ir_context` is provided. 
+ + Examples: + Given a file `path/to/model_file.py` with the following content: + ```python + import torch + import torch.nn as nn + + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(10, 5) + def forward(self, x): + return self.fc(x) + + def get_inputs(): + return (torch.randn(1, 10),) + ``` + + The import script would look like: + >>> from lighthouse.ingress.torch import import_from_file + >>> # option 1: get MLIR module as a string + >>> mlir_module : str = import_from_file( + ... "path/to/model_file.py", + ... model_class_name="MyModel", + ... init_args_fn_name=None, + ... dialect="linalg-on-tensors" + ... ) + >>> print(mlir_module) # prints the MLIR module in linalg-on-tensors dialect + >>> # option 2: get MLIR module as an ir.Module + >>> ir_context = ir.Context() + >>> mlir_module_ir : ir.Module = import_from_file( + ... "path/to/model_file.py", + ... model_class_name="MyModel", + ... init_args_fn_name=None, + ... dialect="linalg-on-tensors", + ... ir_context=ir_context + ... 
) + """ + if isinstance(filepath, str): + filepath = Path(filepath) + module_name = filepath.stem + + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + model = getattr(module, model_class_name, None) + if model is None: + raise ValueError(f"Model class '{model_class_name}' not found in {filepath}") + + model_init_args = maybe_load_and_run_callable( + module, + init_args_fn_name, + default=tuple(), + error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}" + ) + model_init_kwargs = maybe_load_and_run_callable( + module, + init_kwargs_fn_name, + default={}, + error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}" + ) + sample_args = load_and_run_callable( + module, + sample_args_fn_name, + f"Sample args function '{sample_args_fn_name}' not found in {filepath}" + ) + sample_kwargs = maybe_load_and_run_callable( + module, + sample_kwargs_fn_name, + default={}, + error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}" + ) + + nn_model: nn.Module = model(*model_init_args, **model_init_kwargs) + if state_path is not None: + state_dict = torch.load(state_path) + nn_model.load_state_dict(state_dict) + + return import_from_model( + nn_model, + sample_args=sample_args, + sample_kwargs=sample_kwargs, + dialect=dialect, + ir_context=ir_context, + **kwargs, + ) diff --git a/python/lighthouse/ingress/torch/utils.py b/python/lighthouse/ingress/torch/utils.py new file mode 100644 index 0000000..9a464b3 --- /dev/null +++ b/python/lighthouse/ingress/torch/utils.py @@ -0,0 +1,50 @@ +from types import ModuleType +from typing import Any + + +def load_and_run_callable( + module: ModuleType, + symbol_name: str, + error_msg: str | None = None, +): + """Helper to load and run a callable from a module by its symbol name. + + Args: + module (ModuleType): The python module to load the callable from. 
+ symbol_name (str): The name of the callable symbol to load. + error_msg (str | None): Custom error message to use when raising an error + for missing symbol. If not provided, a default message will be used. + + Returns: + Any: The result of calling the loaded callable. + """ + func = getattr(module, symbol_name, None) + if func is None: + if error_msg: + raise ValueError(error_msg) + raise ValueError( + f"Symbol '{symbol_name}' not found in module '{module.__name__}'" + ) + if not callable(func): + raise ValueError(f"Symbol '{symbol_name}' is not callable") + return func() + + +def maybe_load_and_run_callable( + module: ModuleType, + symbol_name: str | None, + default: Any, + error_msg: str | None = None, +): + """Helper to conditionally load and run a callable from a module by its symbol name. + + If `symbol_name` is None, the function returns the provided default value. Otherwise + it calls ``load_and_run_callable`` with the provided arguments. + """ + if symbol_name is None: + return default + return load_and_run_callable( + module, + symbol_name, + error_msg=error_msg + )