diff --git a/docs/index.md b/docs/index.md index 7a4acfa..f912567 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,4 +18,5 @@ api :hidden: examples/vectoradd_jax/README.md +examples/jax_fem/README.md ``` diff --git a/examples/jax_fem/README.md b/examples/jax_fem/README.md new file mode 100644 index 0000000..3b38c9a --- /dev/null +++ b/examples/jax_fem/README.md @@ -0,0 +1,93 @@ +# Using JAX with finite element methods (JAX-FEM) + +In this example, you'll generate a Streamlit app from a Tesseract which models a structure and computes its compliance with finite element methods. +This is based on the example for [shape optimisation in Tesseract-JAX](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/examples/fem-shapeopt/demo.html) using JAX-FEM. +This, of course, uses `tesseract-streamlit` to automatically generate an interactive web app, this time with an interactive PyVista plot of the structure! ⚡ + +--- + +## 📥 Step 1: Download the Example Code + +We've written a custom Tesseract for this example, mashing up the Design Tesseract and FEM Tesseract from the original Tesseract JAX tutorial, so clone `tesseract-streamlit` like so: + +```shell +git clone --depth 1 https://github.com/pasteurlabs/tesseract-streamlit.git ~/Downloads/tesseract-streamlit +``` + +--- + +## 📦 Step 2: Install Requirements + +Enter the example directory, and install the required packages: + +```bash +cd ~/Documents/tesseract-streamlit/examples/jax_fem +pip install -r requirements.txt +``` + +--- + +## 🛠️ Step 3: Build and Serve the Tesseract + +Use the Tesseract CLI to build and serve `jax_fem`: + +```bash +tesseract build ~/Documents/tesseract-streamlit/examples/jax_fem +tesseract serve jax_fem +``` + +> [!NOTE] +> Make note of the `PORT` and `PROJECT ID` printed to stdout — you'll need them shortly. + +--- + +## ⚡ Step 4: Generate the Streamlit App + +With `tesseract-streamlit` installed, generate a ready-to-run Streamlit app: + +```bash +tesseract-streamlit --user-code udf.py "http://localhost:" app.py +``` + +`udf.py` can be found in under `tesseract-streamlit/examples/jax_fem/`. +It contains a custom function that takes the Tesseract's inputs to render a PyVista plot of the design structure directly in the UI! ⚙️ +Check out the [source code to see how it works](https://github.com/pasteurlabs/tesseract-streamlit/examples/jax_fem/udf.py). + +--- + +## ▶️ Step 5: Launch the App + +Run your new app with: + +```bash +cd ~/Documents/tesseract-streamlit/examples/jax_fem +streamlit run app.py +``` + +This will launch a web interface for submitting inputs, running the Tesseract, and visualising the results. + +The form is populated from sensible defaults defined in `tesseract_api.py`. +To easily provide the input parameters for the structure itself, you can upload the `bar_params.json` file in the current directory. + +--- + +## 🖼️ Screenshots + + +| | | +| --- | --- | +| | | + +--- + +## 🧹 Step 6: Clean Up + +When you're done, you can stop the Tesseract server with: + +```bash +tesseract teardown +``` + +--- + +🎉 That’s it — you've transformed a running Tesseract into a beautiful Streamlit web app with interactive plots, with minimal effort from the command line! diff --git a/examples/jax_fem/bar_params.json b/examples/jax_fem/bar_params.json new file mode 100644 index 0000000..45462c6 --- /dev/null +++ b/examples/jax_fem/bar_params.json @@ -0,0 +1 @@ +[[[-30.0,-5.0,0.0],[-18.0,-5.0,0.0],[-6.0,-5.0,0.0],[6.000000953674316,-5.0,0.0],[18.0,-5.0,0.0],[30.0,-5.0,0.0]],[[-30.0,-2.5,0.0],[-18.0,-2.5,0.0],[-6.0,-2.5,0.0],[6.000000953674316,-2.5,0.0],[18.0,-2.5,0.0],[30.0,-2.5,0.0]],[[-30.0,0.0,0.0],[-18.0,0.0,0.0],[-6.0,0.0,0.0],[6.000000953674316,0.0,0.0],[18.0,0.0,0.0],[30.0,0.0,0.0]],[[-30.0,2.5,0.0],[-18.0,2.5,0.0],[-6.0,2.5,0.0],[6.000000953674316,2.5,0.0],[18.0,2.5,0.0],[30.0,2.5,0.0]]] diff --git a/examples/jax_fem/requirements.txt b/examples/jax_fem/requirements.txt new file mode 100644 index 0000000..6d221c4 --- /dev/null +++ b/examples/jax_fem/requirements.txt @@ -0,0 +1,2 @@ +pyvista==0.45.2 +numpy==2.2.5 diff --git a/examples/jax_fem/run.sh b/examples/jax_fem/run.sh new file mode 100644 index 0000000..6b77bad --- /dev/null +++ b/examples/jax_fem/run.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# the parent dir of this script: +scriptdir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +workdir="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)" + +# a temporary dir to store the downloads for the example: +tmpdir=$(mktemp -d) + +if [ "$(basename $workdir)" != "tesseract-streamlit" ]; then + echo "Path mismatch: please contact the developers." + echo $workdir + exit 1 +fi + +# install requirements for the udf.py module: +pip install -r "${scriptdir}/requirements.txt" + +# build and serve the vectoradd_jax example tesseract: +example=jax_fem +tesseract build "${workdir}/examples/${example}" +tessinfo=$(tesseract serve $example) +tessid=$(echo $tessinfo | jq -r '.project_id') +tessport=$(echo $tessinfo | jq -r '.containers[0].port') + +# automatically generate the Streamlit app from the served tesseract: +tesseract-streamlit --user-code "${scriptdir}/udf.py" "http://localhost:${tessport}" "${tmpdir}/app.py" + +# launch the web-app: +streamlit run "${tmpdir}/app.py" + +# stop serving the tesseract +tesseract teardown $tessid + +# clean up the temporary directory: +rm -rf $tmpdir + +exit 0 diff --git a/examples/jax_fem/tesseract_api.py b/examples/jax_fem/tesseract_api.py new file mode 100644 index 0000000..5233e8f --- /dev/null +++ b/examples/jax_fem/tesseract_api.py @@ -0,0 +1,390 @@ +# Copyright 2025 Pasteur Labs. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Tesseract API module for jax-fem +# Generated by tesseract 0.9.0 on 2025-07-03T12:17:32.779918 + + +from collections.abc import Callable +from functools import lru_cache + +import jax +import jax.numpy as jnp +import numpy as np +import pyvista as pv +from jax_fem.generate_mesh import Mesh, get_meshio_cell_type, rectangle_mesh +from jax_fem.problem import Problem +from jax_fem.solver import ad_wrapper +from pydantic import BaseModel, Field +from tesseract_core.runtime import Array, Differentiable, Float32 + +# +# Schemas +# + + +class InputSchema(BaseModel): + bar_params: Differentiable[ + Array[ + (None, None, 3), + Float32, + ] + ] = Field( + description=( + "Vertex positions of the bar geometry. " + "The shape is (num_bars, num_vertices, 3), where num_bars is the number of bars " + "and num_vertices is the number of vertices per bar. The last dimension represents " + "the x, y, z coordinates of each vertex." + ) + ) + + bar_radius: float = Field( + default=1.5, + description=( + "Radius of the bars in the geometry. " + "This is a scalar value that defines the thickness of the bars." + ), + ) + + Lx: float = Field( + default=60.0, + description=( + "Length of the plane in the x direction. " + "This is a scalar value that defines the size of the plane along the x-axis." + ), + ) + Ly: float = Field( + default=30.0, + description=( + "Length of the plane in the y direction. " + "This is a scalar value that defines the size of the plane along the y-axis." + ), + ) + Nx: int = Field( + default=60, + description=( + "Number of points in the x direction. " + "This is an integer value that defines the resolution of the plane along the x-axis." + ), + ) + Ny: int = Field( + default=30, + description=( + "Number of points in the y direction. " + "This is an integer value that defines the resolution of the plane along the y-axis." + ), + ) + epsilon: float = Field( + default=1e-5, + description=( + "Epsilon value for finite difference approximation of the Jacobian. " + "This is a small scalar value used to compute the numerical gradient." + ), + ) + + +class OutputSchema(BaseModel): + compliance: Differentiable[ + Array[ + (), + Float32, + ] + ] = Field(description="Compliance of the structure, a measure of stiffness") + + von_mises_stress: Differentiable[ + Array[ + (None,), + Float32, + ] + ] = Field(description="The average von Mises stress in each element") + + +# +# Helper functions +# + + +def build_geometry( + params: np.ndarray, + radius: float, +) -> list[pv.PolyData]: + """Build a pyvista geometry from the parameters. + + The parameters are expected to be of shape (n_chains, n_edges_per_chain + 1, 3), + """ + n_chains = params.shape[0] + geometry = [] + + for chain in range(n_chains): + tube = pv.Spline(points=params[chain]).tube(radius=radius, capping=False) + geometry.append(tube) + + return geometry + + +def compute_sdf( + params: np.ndarray, + radius: float, + Lx: float, + Ly: float, + Nx: int, + Ny: int, +) -> pv.PolyData: + """Create a pyvista plane that has the SDF values stored as a vertex attribute. + + The SDF field is computed based on the geometry defined by the parameters. + """ + grid_coords = pv.Plane( + center=(0, 0, 0), + direction=(0, 0, 1), + i_size=Lx, + j_size=Ly, + i_resolution=Nx - 1, + j_resolution=Ny - 1, + ) + grid_coords = grid_coords.triangulate() + + geometries = build_geometry( + params, + radius=radius, + ) + + sdf_field = None + + for geometry in geometries: + # Compute the implicit distance from the geometry to the grid coordinates. + # The implicit distance is a signed distance field, where positive values + # are outside the geometry and negative values are inside. + this_sdf = grid_coords.compute_implicit_distance(geometry.triangulate()) + if sdf_field is None: + sdf_field = this_sdf + else: + sdf_field["implicit_distance"] = np.minimum( + sdf_field["implicit_distance"], this_sdf["implicit_distance"] + ) + + return sdf_field + + +def sdf_to_rho( + sdf: jnp.ndarray, scale: float = 4.0, offset: float = 1.0 +) -> jnp.ndarray: + """Convert signed distance function to material density using sigmoid. + + Args: + sdf: Signed distance function values. + scale: Sigmoid steepness (higher = sharper transition). + offset: SDF value where density = 0.5. + + Returns: + Material density field in [0,1]. + """ + return 1 / (1 + jnp.exp(scale * sdf - offset)) + + +# Define constitutive relationship +# Adapted from JAX-FEM +# https://github.com/deepmodeling/jax-fem/blob/1bdbf060bb32951d04ed9848c238c9a470fee1b4/demos/topology_optimization/example.py +class Elasticity(Problem): + def custom_init(self): + self.fe = self.fes[0] + self.fe.flex_inds = jnp.arange(len(self.fe.cells)) + + def get_tensor_map(self): + def stress(u_grad, theta): + Emax = 70.0e3 + Emin = 1e-3 * Emax + nu = 0.3 + penal = 3.0 + E = Emin + (Emax - Emin) * theta[0] ** penal + epsilon = 0.5 * (u_grad + u_grad.T) + eps11 = epsilon[0, 0] + eps22 = epsilon[1, 1] + eps12 = epsilon[0, 1] + sig11 = E / (1 + nu) / (1 - nu) * (eps11 + nu * eps22) + sig22 = E / (1 + nu) / (1 - nu) * (nu * eps11 + eps22) + sig12 = E / (1 + nu) * eps12 + sigma = jnp.array([[sig11, sig12], [sig12, sig22]]) + return sigma + + return stress + + def get_surface_maps(self): + def surface_map(u, x): + return jnp.array([0.0, 100.0]) + + return [surface_map] + + def set_params(self, params): + # Override base class method. + full_params = jnp.ones((self.fe.num_cells, params.shape[1])) + full_params = full_params.at[self.fe.flex_inds].set(params) + thetas = jnp.repeat(full_params[:, None, :], self.fe.num_quads, axis=1) + self.full_params = full_params + self.internal_vars = [thetas] + + def compute_compliance(self, sol): + # Surface integral + boundary_inds = self.boundary_inds_list[0] + _, nanson_scale = self.fe.get_face_shape_grads(boundary_inds) + u_face = ( + sol[self.fe.cells][boundary_inds[:, 0]][:, None, :, :] + * self.fe.face_shape_vals[boundary_inds[:, 1]][:, :, :, None] + ) + + u_face = jnp.sum(u_face, axis=2) + subset_quad_points = self.physical_surface_quad_points[0] + neumann_fn = self.get_surface_maps()[0] + traction = -jax.vmap(jax.vmap(neumann_fn))(u_face, subset_quad_points) + val = jnp.sum(traction * u_face * nanson_scale[:, :, None]) + return val + + def get_von_mises_stress_fn(self): + def vm_stress_fn_helper(sigma): + dim = 2 + s_dev = sigma - 1.0 / dim * jnp.trace(sigma) * np.eye(dim) + vm_s = jnp.sqrt(3.0 / 2.0 * np.sum(s_dev * s_dev)) + return vm_s + + def vm_stress_fn(u_grad, theta): + sigma = self.get_tensor_map()(u_grad, theta) + return vm_stress_fn_helper(sigma) + + return vm_stress_fn + + # reference: + # https://github.com/deepmodeling/jax-fem/blob/ac0aace6537cfd3f44183d760fdfa201cee8ab46/docs/source/learn/linear_elasticity.ipynb#L300 + # https://github.com/deepmodeling/jax-fem/blob/ac0aace6537cfd3f44183d760fdfa201cee8ab46/applications/outdated/top_opt/fem_model.py#L121 + def compute_von_mises_stress(self, sol, density): + """Compute element-average von Mises stress.""" + # (num_cells, num_quads, num_nodes, vec, dim) + u_grads = self.fe.sol_to_grad(sol) + + vm_stress_fn = self.get_von_mises_stress_fn() + vm_stress = jax.vmap(jax.vmap(vm_stress_fn))( + u_grads, *self.internal_vars + ) # (num_cells, num_quads) + + cells_JxW = self.JxW[:, 0, :] # (num_cells, num_quads) + volume_avg_vm_stress = np.sum(vm_stress * cells_JxW, axis=1) / np.sum( + cells_JxW, axis=1 + ) # (num_cells,) + return volume_avg_vm_stress + + +# Memoize the setup function to avoid expensive recomputation +@lru_cache(maxsize=1) +def setup( + Nx: int = 60, + Ny: int = 30, + Lx: float = 60.0, + Ly: float = 30.0, +) -> tuple[Elasticity, Callable]: + # Specify mesh-related information. We use first-order quadrilateral element. + ele_type = "QUAD4" + cell_type = get_meshio_cell_type(ele_type) + meshio_mesh = rectangle_mesh(Nx=Nx, Ny=Ny, domain_x=Lx, domain_y=Ly) + mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) + + # Define boundary conditions and values. + def fixed_location(point): + return jnp.isclose(point[0], 0.0, atol=1e-5) + + def load_location(point): + return jnp.logical_and( + jnp.isclose(point[0], Lx, atol=1e-5), + jnp.isclose(point[1], 0.0, atol=0.1 * Ly + 1e-5), + ) + + def dirichlet_val(point): + return 0.0 + + dirichlet_bc_info = [[fixed_location] * 2, [0, 1], [dirichlet_val] * 2] + + location_fns = [load_location] + + # Define forward problem + problem = Elasticity( + mesh, + vec=2, + dim=2, + ele_type=ele_type, + dirichlet_bc_info=dirichlet_bc_info, + location_fns=location_fns, + ) + + # Apply the automatic differentiation wrapper + # This is a critical step that makes the problem solver differentiable + fwd_pred = ad_wrapper( + problem, + solver_options={"umfpack_solver": {}}, + adjoint_solver_options={"umfpack_solver": {}}, + ) + return problem, fwd_pred + + +# +# Required endpoints +# + + +def apply(inputs: InputSchema) -> OutputSchema: + """Computes the compliance of a structure. + + Gridded signed distance function (SDF) is computed from a set of + shape parameters, which are taken as inputs. + Parameters define control points and radii of piecewise 3D tubes. + + JAX-FEM evaluates compliance of the structure from density as a + function of the SDF field. + """ + Nx, Ny, Lx, Ly = inputs.Nx, inputs.Ny, inputs.Lx, inputs.Ly + sdf_geom = compute_sdf( + params=inputs.bar_params, + radius=inputs.bar_radius, + Lx=Lx, + Ly=Ly, + Nx=Nx, + Ny=Ny, + )["implicit_distance"] + sdf = sdf_geom.reshape((Ny, Nx)).T + density = jnp.reshape(sdf_to_rho(sdf), (Nx * Ny, 1)) + problem, fwd_pred = setup( + Nx=Nx, + Ny=Ny, + Lx=Lx, + Ly=Ly, + ) + sol_list = fwd_pred(density) + compliance = problem.compute_compliance(sol_list[0]) + vm = problem.compute_von_mises_stress(sol_list[0], density) + return OutputSchema(compliance=compliance, von_mises_stress=vm) + + +# +# Optional endpoints +# + +# import numpy as np + +# def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]): +# return {} + +# def jacobian_vector_product( +# inputs: InputSchema, +# jvp_inputs: set[str], +# jvp_outputs: set[str], +# tangent_vector: dict[str, np.typing.ArrayLike] +# ) -> dict[str, np.typing.ArrayLike]: +# return {} + +# def vector_jacobian_product( +# inputs: InputSchema, +# vjp_inputs: set[str], +# vjp_outputs: set[str], +# cotangent_vector: dict[str, np.typing.ArrayLike] +# ) -> dict[str, np.typing.ArrayLike]: +# return {} + +# def abstract_eval(abstract_inputs): +# return {} diff --git a/examples/jax_fem/tesseract_config.yaml b/examples/jax_fem/tesseract_config.yaml new file mode 100644 index 0000000..75248c7 --- /dev/null +++ b/examples/jax_fem/tesseract_config.yaml @@ -0,0 +1,38 @@ +# Tesseract configuration file +# Generated by tesseract 0.9.0 on 2025-07-03T12:17:32.779918 + +name: "jax_fem" +version: "0.1.0" +description: | + Tesseract that computes the compliance of a structure from its gridded signed + distance function (SDF), computed from a set of shape parameters. + + Parameters are expected to define the control points and radii of piecewise + linear tubes in 3D space. + + +build_config: + # Base image to use for the container, must be Ubuntu or Debian-based + # base_image: "debian:bookworm-slim" + + # Platform to build the container for. In general, images can only be executed + # on the platform they were built for. + target_platform: "native" + base_image: "condaforge/miniforge3:latest" + requirements: + provider: conda + extra_packages: + - libgl1 + + # Additional packages to install in the container (via apt-get) + # extra_packages: + # - package_name + + # Data to copy into the container, relative to the project root + # package_data: + # - [path/to/source, path/to/destination] + + # Additional Dockerfile commands to run during the build process + # custom_build_steps: + # - | + # RUN echo "Hello, World!" diff --git a/examples/jax_fem/tesseract_environment.yaml b/examples/jax_fem/tesseract_environment.yaml new file mode 100644 index 0000000..9b5db26 --- /dev/null +++ b/examples/jax_fem/tesseract_environment.yaml @@ -0,0 +1,22 @@ +name: jax-fem-env +channels: + - conda-forge +dependencies: + - python==3.12 + - numpy==1.26.4 + - scipy==1.15.2 + - matplotlib==3.10.3 + - meshio==5.3.5 + - petsc4py==3.23.3 + - fenics==2019.1.0 + - gmsh==4.13.1 + - python-gmsh==4.13.1 + - pip + - pyvista=0.45.2 + - pip: + - setuptools + - wheel + - fenics-basix==0.9.0 + - pyfiglet==1.0.3 + - jax[cpu]==0.5.3 + - jax-fem==0.0.9 diff --git a/examples/jax_fem/udf.py b/examples/jax_fem/udf.py new file mode 100644 index 0000000..42c4193 --- /dev/null +++ b/examples/jax_fem/udf.py @@ -0,0 +1,89 @@ +import typing + +import numpy as np +import pyvista as pv + + +def _build_geometry( + params: np.ndarray, + radius: float, +) -> list[pv.PolyData]: + """Build a pyvista geometry from the parameters. + + The parameters are expected to be of shape (n_chains, n_edges_per_chain + 1, 3), + """ + n_chains = params.shape[0] + geometry = [] + + for chain in range(n_chains): + tube = pv.Spline(points=params[chain]).tube(radius=radius, capping=False) + geometry.append(tube) + + return geometry + + +def _compute_sdf( + params: np.ndarray, + radius: float, + Lx: float, + Ly: float, + Nx: int, + Ny: int, +) -> pv.PolyData: + """Create a pyvista plane that has the SDF values stored as a vertex attribute. + + The SDF field is computed based on the geometry defined by the parameters. + """ + grid_coords = pv.Plane( + center=(0, 0, 0), + direction=(0, 0, 1), + i_size=Lx, + j_size=Ly, + i_resolution=Nx - 1, + j_resolution=Ny - 1, + ) + grid_coords = grid_coords.triangulate() + + geometries = _build_geometry( + params, + radius=radius, + ) + + sdf_field = None + + for geometry in geometries: + # Compute the implicit distance from the geometry to the grid coordinates. + # The implicit distance is a signed distance field, where positive values + # are outside the geometry and negative values are inside. + this_sdf = grid_coords.compute_implicit_distance(geometry.triangulate()) + if sdf_field is None: + sdf_field = this_sdf + else: + sdf_field["implicit_distance"] = np.minimum( + sdf_field["implicit_distance"], this_sdf["implicit_distance"] + ) + + return sdf_field + + +def input_geometry(inputs: dict[str, typing.Any]) -> pv.Plotter: + """Display the geometry defined by the parameters. + + Shows the chains formed of bars, and the signed distance field + around them. + """ + bar_params, bar_radius = np.array(inputs["bar_params"]), inputs["bar_radius"] + Lx, Ly, Nx, Ny = inputs["Lx"], inputs["Ly"], inputs["Nx"], inputs["Ny"] + geometries = _build_geometry(bar_params, bar_radius) + # Concatenate all pipe geometries into a single PolyData object + geometry = sum(geometries, start=pv.PolyData()) + sdf = _compute_sdf(bar_params, radius=bar_radius, Lx=Lx, Ly=Ly, Nx=Nx, Ny=Ny) + isoval = sdf.contour(isosurfaces=[0.0], scalars="implicit_distance") + + plotter = pv.Plotter() + plotter.add_mesh(geometry, color="lightblue", show_edges=True, edge_color="black") + plotter.add_mesh( + sdf, scalars="implicit_distance", cmap="coolwarm", show_edges=False + ) + plotter.add_mesh(isoval, color="red", show_edges=True, line_width=2) + return plotter