diff --git a/ingress/Torch-MLIR/convert-kernel-bench-to-mlir.py b/ingress/Torch-MLIR/convert-kernel-bench-to-mlir.py
new file mode 100755
index 0000000..87b8714
--- /dev/null
+++ b/ingress/Torch-MLIR/convert-kernel-bench-to-mlir.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python3
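+
+# Converts the KernelBench level-1 and level-2 PyTorch kernels to MLIR via
+# lighthouse.ingress.torch and caches the results under cache/<level>/<kernel>.mlir.
+# Kernels in ignore_list below are skipped because of known conversion failures
+# (recorded in the inline comments next to each entry).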
+
+from pathlib import Path
+
+from mlir import ir, passmanager
+from lighthouse.ingress import torch as torch_ingress
+
+
+kernels_as_pytorch_folder = Path(__file__).parent / "KernelBench" / "KernelBench"
+kernels_as_pytorch_level1 = kernels_as_pytorch_folder / "level1"
+kernels_as_pytorch_level2 = kernels_as_pytorch_folder / "level2"
+
+kernels_as_mlir_folder = Path(__file__).parent / "cache"
+kernels_as_mlir_level1 = kernels_as_mlir_folder / "level1"
+kernels_as_mlir_level1.mkdir(parents=True, exist_ok=True)
+kernels_as_mlir_level2 = kernels_as_mlir_folder / "level2"
+kernels_as_mlir_level2.mkdir(parents=True, exist_ok=True)
+
+level1, level2 = Path("level1"), Path("level2")
+ignore_list = [
+    level1 / "12_Matmul_with_diagonal_matrices_.py",  # torch.operator "torch.aten.diag"
+    level1
+    / "34_InstanceNorm.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (93898875033000)
+    level1
+    / "72_conv_transposed_3D_asymmetric_input_asymmetric_kernel___strided_padded_grouped_.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
+    level1
+    / "89_cumsum.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
+    level1
+    / "90_cumprod.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
+    level1
+    / "91_cumsum_reverse.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
+    level1
+    / "92_cumsum_exclusive.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
+    level1
+    / "93_masked_cumsum.py",  # Dialect `tm_tensor' not found for custom op 'tm_tensor.scan'
+    level1
+    / "95_CrossEntropyLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
+    level1
+    / "96_HuberLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
+    level1
+    / "97_ScaledDotProductAttention.py",  # AssertionError: Torch not compiled with CUDA enabled
+    level1
+    / "99_TripletMarginLoss.py",  # Bare exception during torch-backend-to-linalg-on-tensors-backend-pipeline
+    level2
+    / "17_Conv2d_InstanceNorm_Divide.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
+    level2
+    / "18_Matmul_Sum_Max_AvgPool_LogSumExp_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "22_Matmul_Scale_ResidualAdd_Clamp_LogSumExp_Mish.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "28_BMM_InstanceNorm_Sum_ResidualAdd_Multiply.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94899412484104)
+    level2
+    / "42_ConvTranspose2d_GlobalAvgPool_BiasAdd_LogSumExp_Sum_Multiply.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "43_Conv3d_Max_LogSumExp_ReLU.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "45_Gemm_Sigmoid_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "51_Gemm_Subtract_GlobalAvgPool_LogSumExp_GELU_ResidualAdd.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "52_Conv2d_Activation_BatchNorm.py",  # failed to legalize operation 'torch.operator'
+    level2 / "55_Matmul_MaxPool_Sum_Scale.py",  # MLIR file too big: 16G
+    level2 / "59_Matmul_Swish_Scaling.py",  # MLIR file too big: 16G
+    level2 / "56_Matmul_Sigmoid_Sum.py",  # MLIR file too big: 16G
+    level2 / "66_Matmul_Dropout_Softmax.py",  # MLIR file too big: 4G
+    level2 / "68_Matmul_Min_Subtract.py",  # MLIR file too big: 4G
+    level2 / "94_Gemm_BiasAdd_Hardtanh_Mish_GroupNorm.py",  # MLIR file too big: 1G
+    level2 / "33_Gemm_Scale_BatchNorm.py",  # MLIR file too big: 1G
+    level2 / "88_Gemm_GroupNorm_Swish_Multiply_Swish.py",  # MLIR file too big: 1G
+    level2 / "75_Gemm_GroupNorm_Min_BiasAdd.py",  # MLIR file too big: 1G
+    level2 / "84_Gemm_BatchNorm_Scaling_Softmax.py",  # MLIR file too big: 1G
+    level2 / "97_Matmul_BatchNorm_BiasAdd_Divide_Swish.py",  # MLIR file too big: 1G
+    level2 / "62_Matmul_GroupNorm_LeakyReLU_Sum.py",  # MLIR file too big: 1G
+    level2 / "30_Gemm_GroupNorm_Hardtanh.py",  # MLIR file too big: 1G
+    level2 / "95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py",  # MLIR file too big: 1G
+    level2 / "29_Matmul_Mish_Mish.py",  # MLIR file too big: 1G
+    level2 / "99_Matmul_GELU_Softmax.py",  # MLIR file too big: 1G
+    level2 / "98_Matmul_AvgPool_GELU_Scale_Max.py",  # MLIR file too big: 1G
+    level2 / "80_Gemm_Max_Subtract_GELU.py",  # MLIR file too big: 1G
+    level2 / "81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py",  # MLIR file too big: 1G
+    level2 / "12_Gemm_Multiply_LeakyReLU.py",  # MLIR file too big: 1G
+    level2 / "53_Gemm_Scaling_Hardtanh_GELU.py",  # MLIR file too big: 1G
+    level2 / "9_Matmul_Subtract_Multiply_ReLU.py",  # MLIR file too big: 1G
+    level2 / "70_Gemm_Sigmoid_Scaling_ResidualAdd.py",  # MLIR file too big: 1G
+    level2 / "86_Matmul_Divide_GELU.py",  # MLIR file too big: 1G
+    level2 / "63_Gemm_ReLU_Divide.py",  # MLIR file too big: 1G
+    level2 / "76_Gemm_Add_ReLU.py",  # MLIR file too big: 1G
+    level2 / "14_Gemm_Divide_Sum_Scaling.py",  # MLIR file too big: 1G
+    level2 / "39_Gemm_Scale_BatchNorm.py",  # MLIR file too big: 256M
+    level2 / "41_Gemm_BatchNorm_GELU_ReLU.py",  # MLIR file too big: 256M
+    level2 / "40_Matmul_Scaling_ResidualAdd.py",  # MLIR file too big: 256M
+    level2 / "37_Matmul_Swish_Sum_GroupNorm.py",  # MLIR file too big: 64.3M
+    level2
+    / "58_ConvTranspose3d_LogSumExp_HardSwish_Subtract_Clamp.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "64_Gemm_LogSumExp_LeakyReLU_LeakyReLU_GELU_GELU.py",  # error: failed to legalize operation 'torch.constant.int'
+    level2
+    / "79_Conv3d_Multiply_InstanceNorm_Clamp_Multiply_Max.py",  # LLVM ERROR: SmallVector unable to grow. Requested capacity (94312016449768)
+    level2
+    / "92_Conv2d_GroupNorm_Tanh_HardSwish_ResidualAdd_LogSumExp.py",  # error: failed to legalize operation 'torch.constant.int'
+]
+
+
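+# Cleanup pipeline: rewrites linalg.generic ops into equivalent named linalg
+# ops where possible, keeping the cached MLIR easier to read.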
+ctx = ir.Context()
+pm = passmanager.PassManager(context=ctx)
+pm.add("linalg-specialize-generic-ops")
+
+for pytorch_level, mlir_level in (
+    (kernels_as_pytorch_level1, kernels_as_mlir_level1),
+    (kernels_as_pytorch_level2, kernels_as_mlir_level2),
+):
+    for kernel_pytorch_file in pytorch_level.iterdir():
+        level_and_kernel = (
+            Path(kernel_pytorch_file.parent.name) / kernel_pytorch_file.name
+        )
+        if level_and_kernel in ignore_list or not kernel_pytorch_file.is_file():
+            print(
+                f"Skipping: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
+            )
+            continue
+
+        kernel_name = kernel_pytorch_file.stem
+
+        kernel_as_mlir_path = mlir_level / (kernel_name + ".mlir")
+        if kernel_as_mlir_path.exists():
+            print(
+                f"Already in cache: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
+            )
+            continue
+        print(
+            f"Processing: {kernel_pytorch_file.parent.name}/{kernel_pytorch_file.name}"
+        )
+        mlir_kernel = torch_ingress.import_from_file(
+            kernel_pytorch_file, ir_context=ctx
+        )
+
+        # Keep a commented-out copy of the raw Torch-MLIR output for reference.
+        before_clean_up = "//" + str(mlir_kernel)[:-1].replace("\n", "\n//") + "\n"
+        try:
+            pm.run(mlir_kernel.operation)  # cleanup
+        except Exception as e:
+            print(f"Error: got the following error cleaning up {kernel_name}")
+            raise e
+
+        with kernel_as_mlir_path.open("w") as f:
+            print("// Torch-MLIR output:", file=f)
+            print(before_clean_up, file=f)
+            print("// MLIR output after clean-up:", file=f)
+            print(mlir_kernel, file=f)
diff --git a/ingress/Torch-MLIR/generate-mlir.py b/ingress/Torch-MLIR/generate-mlir.py
deleted file mode 100644
index 888e6dd..0000000
--- a/ingress/Torch-MLIR/generate-mlir.py
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/usr/bin/env python3
-
-import argparse
-import os
-import torch
-import torch.nn as nn
-from torch_mlir import fx
-from torch_mlir.fx import OutputType
-
-# Parse arguments for selecting which model to load and which MLIR dialect to generate
-def parse_args():
-    parser = argparse.ArgumentParser(description="Generate MLIR for Torch-MLIR models.")
-    parser.add_argument(
-        "--model",
-        type=str,
-        required=True,
-        help="Path to the Torch model file.",
-    )
-    parser.add_argument(
-        "--dialect",
-        type=str,
-        choices=["torch", "linalg", "stablehlo", "tosa"],
-        default="linalg",
-        help="MLIR dialect to generate.",
-    )
-    return parser.parse_args()
-
-# Functin to load the Torch model
-def load_torch_model(model_path):
-
-    if not os.path.exists(model_path):
-        raise FileNotFoundError(f"Model file {model_path} does not exist.")
-
-    model = torch.load(model_path)
-    return model
-
-# Function to generate MLIR from the Torch model
-# See: https://github.com/MrSidims/PytorchExplorer/blob/main/backend/server.py#L237
-def generate_mlir(model, dialect):
-
-    # Convert the Torch model to MLIR
-    output_type = None
-    if dialect == "torch":
-        output_type = OutputType.TORCH
-    elif dialect == "linalg":
-        output_type = OutputType.LINALG
-    elif dialect == "stablehlo":
-        output_type = OutputType.STABLEHLO
-    elif dialect == "tosa":
-        output_type = OutputType.TOSA
-    else:
-        raise ValueError(f"Unsupported dialect: {dialect}")
-
-    module = fx.export_and_import(model, "", output_type=output_type)
-    return module
-
-# Main function to execute the script
-def main():
-    args = parse_args()
-
-    # Load the Torch model
-    model = load_torch_model(args.model)
-
-    # Generate MLIR from the model
-    mlir_module = generate_mlir(model, args.dialect)
-
-    # Print or save the MLIR module
-    print(mlir_module)
-
-# Entry point for the script
-if __name__ == "__main__":
-    main()
\ No newline at end of file
diff --git a/ingress/Torch-MLIR/generate-mlir.sh b/ingress/Torch-MLIR/generate-mlir.sh
deleted file mode 100755
index 0a079c6..0000000
--- a/ingress/Torch-MLIR/generate-mlir.sh
+++ /dev/null
@@ -1,42 +0,0 @@
-#!/usr/bin/env bash
-
-# Command line argument for model to load and MLIR dialect to generate
-while getopts "m:d:" opt; do
-    case $opt in
-        m)
-            MODEL=$OPTARG
-            ;;
-        d)
-            DIALECT=$OPTARG
-            ;;
-        *)
-            echo "Usage: $0 [-m model] [-d dialect]"
-            exit 1
-            ;;
-    esac
-done
-if [ -z "$MODEL" ]; then
-    echo "Model not specified. Please provide a model using -m option."
-    exit 1
-fi
-if [ -z "$DIALECT" ]; then
-    DIALECT="linalg"
-fi
-
-# Enable local virtualenv created by install-virtualenv.sh
-if [ ! -d "torch-mlir-venv" ]; then
-    echo "Virtual environment not found. Please run install-virtualenv.sh first."
-    exit 1
-fi
-source torch-mlir-venv/bin/activate
-
-# Find script directory
-SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
-
-# Use the Python script to generate MLIR
-echo "Generating MLIR for model '$MODEL' with dialect '$DIALECT'..."
-python $SCRIPT_DIR/generate-mlir.py --model "$MODEL" --dialect "$DIALECT"
-if [ $? -ne 0 ]; then
-    echo "Failed to generate MLIR for model '$MODEL'."
-    exit 1
-fi
diff --git a/python/examples/ingress/torch/01-dummy-mlp-from-file.py b/python/examples/ingress/torch/01-dummy-mlp-from-file.py
new file mode 100644
index 0000000..bf7c15c
--- /dev/null
+++ b/python/examples/ingress/torch/01-dummy-mlp-from-file.py
@@ -0,0 +1,63 @@
+"""
+Example demonstrating how to load a PyTorch model to MLIR using Lighthouse
+without instantiating the model on the user's side.
+
+The script uses the 'lighthouse.ingress.torch.import_from_file' function, which
+takes a path to a Python file containing the model definition, along with the
+names of the functions that provide the model's init arguments and sample inputs.
+The function imports the model class on its own, instantiates it, and passes it
+to torch_mlir to get an MLIR module in the specified dialect.
+
+The script uses the model from 'DummyMLP/model.py' as an example.
+""" + +import os +from pathlib import Path + +# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) +import mlir.dialects.func as func +from mlir import ir, passmanager + +# Lighthouse imports +from lighthouse.ingress.torch import import_from_file + +# Step 1: Set up paths to locate the model definition file +script_dir = Path(os.path.dirname(os.path.abspath(__file__))) +model_path = script_dir / "DummyMLP" / "model.py" + +ir_context = ir.Context() + +# Step 2: Convert PyTorch model to MLIR +# Conversion step where Lighthouse: +# - Loads the DummyMLP class and instantiates it with arguments obtained from 'get_init_inputs()' +# - Calls get_sample_inputs() to get sample input tensors for shape inference +# - Converts PyTorch model to linalg-on-tensors dialect operations using torch_mlir +mlir_module_ir: ir.Module = import_from_file( + model_path, # Path to the Python file containing the model + model_class_name="DummyMLP", # Name of the PyTorch nn.Module class to convert + init_args_fn_name="get_init_inputs", # Function that returns args for model.__init__() + sample_args_fn_name="get_sample_inputs", # Function that returns sample inputs to pass to 'model(...)' + dialect="linalg-on-tensors", # Target MLIR dialect (linalg ops on tensor types) + ir_context=ir_context # MLIR context for the conversion +) + +# The PyTorch model is now converted to MLIR at this point. You can now convert +# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. +# +# The following optional MLIR-processing steps are to give you an idea of what can +# also be done with the MLIR module. + +# Step 3: Extract the main function operation from the MLIR module and print its metadata +func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] +print(f"entry-point name: {func_op.name}") +print(f"entry-point type: {func_op.type}") + +# Step 4: Apply some MLIR passes using a PassManager +pm = passmanager.PassManager(context=ir_context) +pm.add("linalg-specialize-generic-ops") +pm.add("one-shot-bufferize") +pm.run(mlir_module_ir.operation) + +# Step 5: Output the final MLIR +print("\n\nModule dump after running the pipeline:") +mlir_module_ir.dump() diff --git a/python/examples/ingress/torch/02-dummy-mlp-from-model.py b/python/examples/ingress/torch/02-dummy-mlp-from-model.py new file mode 100644 index 0000000..246deda --- /dev/null +++ b/python/examples/ingress/torch/02-dummy-mlp-from-model.py @@ -0,0 +1,56 @@ +""" +Example demonstrating how to load an already instantiated PyTorch model +to MLIR using Lighthouse. + +The script uses the 'lighthouse.ingress.torch.import_from_model' function that +takes a PyTorch model that has already been instantiated, along with its sample inputs. +The function passes the model to torch_mlir to get a MLIR module in the +specified dialect. + +The script uses a model from 'DummyMLP/model.py' as an example. 
+""" + +import torch + +# MLIR infrastructure imports (only needed if you want to manipulate the MLIR module) +import mlir.dialects.func as func +from mlir import ir, passmanager + +# Lighthouse imports +from lighthouse.ingress.torch import import_from_model + +# Import a sample model definition +from DummyMLP.model import DummyMLP + +# Step 1: Instantiate a model and prepare sample input +model = DummyMLP() +sample_input = torch.randn(1, 10) + +ir_context = ir.Context() +# Step 2: Convert the PyTorch model to MLIR +mlir_module_ir: ir.Module = import_from_model( + model, + sample_args=(sample_input,), + ir_context=ir_context +) + +# The PyTorch model is now converted to MLIR at this point. You can now convert +# the MLIR module to a text form (e.g. 'str(mlir_module_ir)') and save it to a file. +# +# The following optional MLIR-processing steps are to give you an idea of what can +# also be done with the MLIR module. + +# Step 3: Extract the main function operation from the MLIR module and print its metadata +func_op: func.FuncOp = mlir_module_ir.operation.regions[0].blocks[0].operations[0] +print(f"entry-point name: {func_op.name}") +print(f"entry-point type: {func_op.type}") + +# Step 4: Apply some MLIR passes using a PassManager +pm = passmanager.PassManager(context=ir_context) +pm.add("linalg-specialize-generic-ops") +pm.add("one-shot-bufferize") +pm.run(mlir_module_ir.operation) + +# Step 5: Output the final MLIR +print("\n\nModule dump after running the pipeline:") +mlir_module_ir.dump() diff --git a/python/examples/ingress/torch/DummyMLP/model.py b/python/examples/ingress/torch/DummyMLP/model.py new file mode 100644 index 0000000..905ecc8 --- /dev/null +++ b/python/examples/ingress/torch/DummyMLP/model.py @@ -0,0 +1,33 @@ +"""Defines a simple PyTorch model to be used in lighthouse's ingress examples.""" + +import torch +import torch.nn as nn + +import os + +class DummyMLP(nn.Module): + def __init__(self): + super().__init__() + self.net = nn.Sequential( + nn.Linear(10, 32), + nn.ReLU(), + nn.Linear(32, 2) + ) + + def forward(self, x): + return self.net(x) + + +def get_init_inputs(): + """Function to return args to pass to DummyMLP.__init__()""" + return () + + +def get_sample_inputs(): + """Arguments to pass to DummyMLP.forward()""" + return (torch.randn(1, 10),) + + +if __name__ == "__main__": + script_dir = os.path.dirname(os.path.abspath(__file__)) + torch.save(DummyMLP().state_dict(), os.path.join(script_dir, "dummy_mlp.pth")) diff --git a/python/lighthouse/ingress/README.md b/python/lighthouse/ingress/README.md new file mode 100644 index 0000000..61236ef --- /dev/null +++ b/python/lighthouse/ingress/README.md @@ -0,0 +1,10 @@ +# Lighthouse Ingress + +The `lighthouse.ingress` module converts various input formats to MLIR modules. + +## Supported Formats + +#### Torch +Converts PyTorch models to MLIR using `lighthouse.ingress.torch`. 
diff --git a/python/lighthouse/ingress/__init__.py b/python/lighthouse/ingress/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/python/lighthouse/ingress/torch/__init__.py b/python/lighthouse/ingress/torch/__init__.py
new file mode 100644
index 0000000..c3a6e72
--- /dev/null
+++ b/python/lighthouse/ingress/torch/__init__.py
@@ -0,0 +1 @@
+from .torch_import import import_from_file, import_from_model
diff --git a/python/lighthouse/ingress/torch/torch_import.py b/python/lighthouse/ingress/torch/torch_import.py
new file mode 100644
index 0000000..dfaccd0
--- /dev/null
+++ b/python/lighthouse/ingress/torch/torch_import.py
@@ -0,0 +1,226 @@
+import importlib
+import importlib.util
+from pathlib import Path
+from typing import Iterable, Mapping
+
+from lighthouse.ingress.torch.utils import load_and_run_callable, maybe_load_and_run_callable
+
+try:
+    import torch
+    import torch.nn as nn
+except ImportError as e:
+    raise ImportError(
+        "PyTorch is required to use the torch import functionality. "
+        "Please run 'uv pip install .[torch-mlir]'"
+    ) from e
+
+try:
+    from torch_mlir import fx
+    from torch_mlir.fx import OutputType
+except ImportError as e:
+    raise ImportError(
+        "torch-mlir is required to use the torch import functionality. "
+        "Please run 'uv pip install .[torch-mlir]'"
+    ) from e
+
+from mlir import ir
+
+
+def import_from_model(
+    model: nn.Module,
+    sample_args: Iterable,
+    sample_kwargs: Mapping | None = None,
+    dialect: OutputType | str = OutputType.LINALG_ON_TENSORS,
+    ir_context: ir.Context | None = None,
+    **kwargs,
+) -> str | ir.Module:
+    """Import a PyTorch nn.Module into MLIR.
+
+    The function uses torch-mlir's FX importer to convert the given PyTorch model
+    into an MLIR module in the specified dialect. The user has to provide sample
+    input arguments (e.g. a torch.Tensor with the correct shape).
+
+    Args:
+        model (nn.Module): The PyTorch model to import.
+        sample_args (Iterable): Sample input arguments to the model.
+        sample_kwargs (Mapping, optional): Sample keyword arguments to the model.
+        dialect (torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}):
+            The target dialect for the imported MLIR module. Defaults to
+            ``OutputType.LINALG_ON_TENSORS``.
+        ir_context (ir.Context, optional): An optional MLIR context to use for parsing
+            the module. If not provided, the module is returned as a string.
+        **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
+
+    Returns:
+        str | ir.Module: The imported MLIR module as a string, or an ir.Module if `ir_context` is provided.
+
+    Examples:
+        >>> import torch
+        >>> import torch.nn as nn
+        >>> from lighthouse.ingress.torch import import_from_model
+        >>> class SimpleModel(nn.Module):
+        ...     def __init__(self):
+        ...         super().__init__()
+        ...         self.fc = nn.Linear(10, 5)
+        ...     def forward(self, x):
+        ...         return self.fc(x)
+        >>> model = SimpleModel()
+        >>> sample_input = (torch.randn(1, 10),)
+        >>> #
+        >>> # option 1: get the MLIR module as a string
+        >>> mlir_module: str = import_from_model(model, sample_input, dialect="linalg-on-tensors")
+        >>> print(mlir_module)  # prints the MLIR module in the linalg-on-tensors dialect
+        >>> # option 2: get the MLIR module as an ir.Module
+        >>> ir_context = ir.Context()
+        >>> mlir_module_ir: ir.Module = import_from_model(model, sample_input, dialect="tosa", ir_context=ir_context)
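+        >>> # the textual form can be written straight to a file (path is illustrative)
+        >>> from pathlib import Path
+        >>> Path("simple_model.mlir").write_text(mlir_module)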
+    """
+    if dialect == "linalg":
+        raise ValueError(
+            "Dialect 'linalg' is not supported. Did you mean 'linalg-on-tensors'?"
+        )
+
+    if sample_kwargs is None:
+        sample_kwargs = {}
+
+    model.eval()
+    module = fx.export_and_import(
+        model, *sample_args, output_type=dialect, **sample_kwargs, **kwargs
+    )
+
+    text_module = str(module)
+    if ir_context is None:
+        return text_module
+    # Cross the boundary from torch-mlir's bundled MLIR bindings to the
+    # environment's MLIR bindings by round-tripping the module through text.
+    return ir.Module.parse(text_module, context=ir_context)
+
+
+def import_from_file(
+    filepath: str | Path,
+    model_class_name: str = "Model",
+    init_args_fn_name: str | None = "get_init_inputs",
+    init_kwargs_fn_name: str | None = None,
+    sample_args_fn_name: str = "get_inputs",
+    sample_kwargs_fn_name: str | None = None,
+    state_path: str | Path | None = None,
+    dialect: OutputType | str = OutputType.LINALG_ON_TENSORS,
+    ir_context: ir.Context | None = None,
+    **kwargs,
+) -> str | ir.Module:
+    """Load a PyTorch nn.Module from a file and import it into MLIR.
+
+    The function takes a `filepath` to a Python file containing the model definition,
+    along with the names of functions that provide the model's init arguments and
+    sample inputs. The function imports the model class on its own, instantiates it,
+    and passes it to ``torch_mlir`` to get an MLIR module in the specified `dialect`.
+
+    Args:
+        filepath (str | Path): Path to the Python file containing the model definition.
+        model_class_name (str, optional): The name of the model class in the file.
+            Defaults to "Model".
+        init_args_fn_name (str | None, optional): The name of the function in the file
+            that returns the arguments for initializing the model. If None, the model
+            is initialized without arguments. Defaults to "get_init_inputs".
+        init_kwargs_fn_name (str | None, optional): The name of the function in the file
+            that returns the keyword arguments for initializing the model. If None, the
+            model is initialized without keyword arguments.
+        sample_args_fn_name (str, optional): The name of the function in the file that
+            returns the sample input arguments for the model. Defaults to "get_inputs".
+        sample_kwargs_fn_name (str | None, optional): The name of the function in the file
+            that returns the sample keyword input arguments for the model. Defaults to None.
+        state_path (str | Path | None, optional): Optional path to a file containing
+            the model's ``state_dict``. Defaults to None.
+        dialect (torch_mlir.fx.OutputType | {"linalg-on-tensors", "torch", "tosa"}, optional):
+            The target dialect for the imported MLIR module. Defaults to
+            ``OutputType.LINALG_ON_TENSORS``.
+        ir_context (ir.Context, optional): An optional MLIR context to use for parsing
+            the module. If not provided, the module is returned as a string.
+        **kwargs: Additional keyword arguments passed to the ``torch_mlir.fx.export_and_import`` function.
+
+    Returns:
+        str | ir.Module: The imported MLIR module as a string, or an ir.Module if `ir_context` is provided.
+
+    Examples:
+        Given a file `path/to/model_file.py` with the following content:
+        ```python
+        import torch
+        import torch.nn as nn
+
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(10, 5)
+            def forward(self, x):
+                return self.fc(x)
+
+        def get_inputs():
+            return (torch.randn(1, 10),)
+        ```
+
+        The import script would look like:
+        >>> from lighthouse.ingress.torch import import_from_file
+        >>> # option 1: get the MLIR module as a string
+        >>> mlir_module: str = import_from_file(
+        ...     "path/to/model_file.py",
+        ...     model_class_name="MyModel",
+        ...     init_args_fn_name=None,
+        ...     dialect="linalg-on-tensors"
+        ... )
+        >>> print(mlir_module)  # prints the MLIR module in the linalg-on-tensors dialect
+        >>> # option 2: get the MLIR module as an ir.Module
+        >>> ir_context = ir.Context()
+        >>> mlir_module_ir: ir.Module = import_from_file(
+        ...     "path/to/model_file.py",
+        ...     model_class_name="MyModel",
+        ...     init_args_fn_name=None,
+        ...     dialect="linalg-on-tensors",
+        ...     ir_context=ir_context
+        ... )
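+        >>> # option 3: additionally restore trained weights before the import
+        >>> # (the .pth path is illustrative)
+        >>> mlir_module_trained: str = import_from_file(
+        ...     "path/to/model_file.py",
+        ...     model_class_name="MyModel",
+        ...     init_args_fn_name=None,
+        ...     state_path="path/to/weights.pth",
+        ... )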
+    """
+    if isinstance(filepath, str):
+        filepath = Path(filepath)
+    module_name = filepath.stem
+
+    spec = importlib.util.spec_from_file_location(module_name, filepath)
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+    model = getattr(module, model_class_name, None)
+    if model is None:
+        raise ValueError(f"Model class '{model_class_name}' not found in {filepath}")
+
+    model_init_args = maybe_load_and_run_callable(
+        module,
+        init_args_fn_name,
+        default=tuple(),
+        error_msg=f"Init args function '{init_args_fn_name}' not found in {filepath}",
+    )
+    model_init_kwargs = maybe_load_and_run_callable(
+        module,
+        init_kwargs_fn_name,
+        default={},
+        error_msg=f"Init kwargs function '{init_kwargs_fn_name}' not found in {filepath}",
+    )
+    sample_args = load_and_run_callable(
+        module,
+        sample_args_fn_name,
+        f"Sample args function '{sample_args_fn_name}' not found in {filepath}",
+    )
+    sample_kwargs = maybe_load_and_run_callable(
+        module,
+        sample_kwargs_fn_name,
+        default={},
+        error_msg=f"Sample kwargs function '{sample_kwargs_fn_name}' not found in {filepath}",
+    )
+
+    nn_model: nn.Module = model(*model_init_args, **model_init_kwargs)
+    if state_path is not None:
+        state_dict = torch.load(state_path)
+        nn_model.load_state_dict(state_dict)
+
+    return import_from_model(
+        nn_model,
+        sample_args=sample_args,
+        sample_kwargs=sample_kwargs,
+        dialect=dialect,
+        ir_context=ir_context,
+        **kwargs,
+    )
diff --git a/python/lighthouse/ingress/torch/utils.py b/python/lighthouse/ingress/torch/utils.py
new file mode 100644
index 0000000..9a464b3
--- /dev/null
+++ b/python/lighthouse/ingress/torch/utils.py
@@ -0,0 +1,50 @@
+from types import ModuleType
+from typing import Any
+
+
+def load_and_run_callable(
+    module: ModuleType,
+    symbol_name: str,
+    error_msg: str | None = None,
+):
+    """Helper to load and run a callable from a module by its symbol name.
+
+    Args:
+        module (ModuleType): The Python module to load the callable from.
+        symbol_name (str): The name of the callable symbol to load.
+        error_msg (str | None): Custom error message to use when raising an error
+            for a missing symbol. If not provided, a default message is used.
+
+    Returns:
+        Any: The result of calling the loaded callable.
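+
+    Example:
+        A minimal sketch using a standard-library module:
+        >>> import os
+        >>> load_and_run_callable(os, "getcwd") == os.getcwd()
+        True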
+ """ + func = getattr(module, symbol_name, None) + if func is None: + if error_msg: + raise ValueError(error_msg) + raise ValueError( + f"Symbol '{symbol_name}' not found in module '{module.__name__}'" + ) + if not callable(func): + raise ValueError(f"Symbol '{symbol_name}' is not callable") + return func() + + +def maybe_load_and_run_callable( + module: ModuleType, + symbol_name: str | None, + default: Any, + error_msg: str | None = None, +): + """Helper to conditionally load and run a callable from a module by its symbol name. + + If `symbol_name` is None, the function returns the provided default value. Otherwise + it calls ``load_and_run_callable`` with the provided arguments. + """ + if symbol_name is None: + return default + return load_and_run_callable( + module, + symbol_name, + error_msg=error_msg + )