Skip to content

Commit ff82241

Browse files
feat(adjoint): Add autograd support for component modelers through web.run
Co-authored-by: groberts-flex <[email protected]>
1 parent c643e74 commit ff82241

File tree

7 files changed

+475
-17
lines changed

7 files changed

+475
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
- Added `TerminalComponentModelerData`, `ComponentModelerData`, `MicrowaveSMatrixData`, and introduced multiple DataArrays for modeler workflow data structures.
2929
- Added autograd support for dispersive material models: `Sellmeier`, `Drude`, `Lorentz`, `Debye` and their custom medium variants.
3030
- Added check and exception for NaN data in the adjoint pipeline to raise issue to user before adjoint source creation failure.
31+
- Added autograd support for `TerminalComponentModeler` and `ModalComponentModeler`.
3132

3233
### Changed
3334
- Validate mode solver object for large number of grid points on the modal plane.

tests/test_components/autograd/test_autograd.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
# whether to include a call to `objective(params)` in addition to gradient
6969
CALL_OBJECTIVE = False
7070

71+
# import sys
72+
# sys.stdout = sys.stderr
73+
7174

7275
# --- helpers for custom dispersive tests ---
7376
def _patch_cmp_to_const(monkeypatch, cls, dJ_const):
Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
from __future__ import annotations
2+
3+
import autograd as ag
4+
import autograd.numpy as anp
5+
import numpy as np
6+
import pytest
7+
8+
import tidy3d as td
9+
from tidy3d.plugins.smatrix.analysis import terminal as terminal_analysis
10+
from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler
11+
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
12+
from tidy3d.plugins.smatrix.data.data_array import TerminalPortDataArray
13+
from tidy3d.plugins.smatrix.ports.modal import Port as ModalPort
14+
from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort as RectLumpedPort
15+
from tidy3d.plugins.smatrix.run import run as smatrix_run
16+
from tidy3d.web.api.autograd import autograd as web_ag
17+
18+
19+
def _run_emulated_minimal(simulation: td.Simulation, path=None, **kwargs) -> td.SimulationData:
20+
"""Very small offline emulator used by autograd tests.
21+
22+
- Supports ModeMonitor (amps + n_complex)
23+
- Supports FieldMonitor (Ex, Ey, Ez, Hx, Hy, Hz)
24+
- Supports PermittivityMonitor (eps_xx, eps_yy, eps_zz)
25+
"""
26+
27+
rng = np.random.default_rng(42)
28+
29+
def _coords_for_monitor(sim: td.Simulation, mnt: td.Monitor):
30+
grid = sim.discretize_monitor(mnt)
31+
bounds = grid.boundaries.dict()
32+
33+
def centers(arr):
34+
arr = np.asarray(arr)
35+
if arr.size < 2:
36+
return arr
37+
return 0.5 * (arr[:-1] + arr[1:])
38+
39+
xyz = {}
40+
for ax, dim in enumerate("xyz"):
41+
if mnt.size[ax] == 0:
42+
xyz[dim] = [mnt.center[ax]]
43+
else:
44+
arr = np.asarray(bounds[dim])
45+
if arr.size < 2:
46+
xyz[dim] = [mnt.center[ax]]
47+
else:
48+
xyz[dim] = centers(arr)
49+
50+
# ensure at least two points along any nonzero-size axis to avoid empty-grid interpolation
51+
for ax, dim in enumerate("xyz"):
52+
if mnt.size[ax] != 0 and len(xyz[dim]) < 2:
53+
c = float(mnt.center[ax])
54+
half = float(mnt.size[ax]) / 2.0
55+
if half == 0:
56+
half = 1e-6
57+
eps = max(half * 1e-3, 1e-6)
58+
xyz[dim] = [c - eps, c + eps]
59+
return xyz, grid
60+
61+
data_items = []
62+
63+
for mnt in simulation.monitors:
64+
if isinstance(mnt, td.ModeMonitor):
65+
f = list(mnt.freqs)
66+
mode_index = np.arange(mnt.mode_spec.num_modes)
67+
directions = np.array(["+", "-"])
68+
69+
amps_vals = (1 + 0.1j) * rng.random((len(directions), len(f), len(mode_index)))
70+
n_vals = (1 + 0.05j) * rng.random((len(f), len(mode_index)))
71+
72+
amps = td.ModeAmpsDataArray(
73+
amps_vals,
74+
coords={"direction": directions, "f": f, "mode_index": mode_index},
75+
)
76+
n_complex = td.ModeIndexDataArray(n_vals, coords={"f": f, "mode_index": mode_index})
77+
78+
data_items.append(td.ModeData(monitor=mnt, amps=amps, n_complex=n_complex))
79+
80+
elif isinstance(mnt, td.FieldMonitor):
81+
xyz, grid = _coords_for_monitor(simulation, mnt)
82+
f = list(mnt.freqs)
83+
shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f))
84+
85+
def cfield(shape=shape, xyz=xyz, f=f, rng=rng):
86+
vals = (1 + 0.2j) * rng.random(shape)
87+
return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f})
88+
89+
data_items.append(
90+
td.FieldData(
91+
monitor=mnt,
92+
grid_expanded=grid,
93+
Ex=cfield(),
94+
Ey=cfield(),
95+
Ez=cfield(),
96+
Hx=cfield(),
97+
Hy=cfield(),
98+
Hz=cfield(),
99+
symmetry=(0, 0, 0),
100+
symmetry_center=simulation.center,
101+
)
102+
)
103+
104+
elif isinstance(mnt, td.PermittivityMonitor):
105+
xyz, grid = _coords_for_monitor(simulation, mnt)
106+
f = list(mnt.freqs)
107+
shape = (len(xyz["x"]), len(xyz["y"]), len(xyz["z"]), len(f))
108+
109+
def rfield(shape=shape, xyz=xyz, f=f, rng=rng):
110+
vals = rng.random(shape)
111+
return td.ScalarFieldDataArray(vals, coords={**xyz, "f": f})
112+
113+
data_items.append(
114+
td.PermittivityData(
115+
monitor=mnt,
116+
grid_expanded=grid,
117+
eps_xx=rfield(),
118+
eps_yy=rfield(),
119+
eps_zz=rfield(),
120+
)
121+
)
122+
123+
return td.SimulationData(simulation=simulation, data=tuple(data_items))
124+
125+
126+
def _emulated_run_async_tidy3d(simulations, **kwargs):
127+
"""Batch wrapper around the minimal emulator."""
128+
sim_data_map = {}
129+
for task_name, sim in simulations.items():
130+
sim_data_map[task_name] = _run_emulated_minimal(sim)
131+
132+
class _BatchLike(dict):
133+
def __getitem__(self, key):
134+
return sim_data_map[key]
135+
136+
return _BatchLike(sim_data_map), {}
137+
138+
139+
@pytest.fixture
140+
def patch_web_autograd_emulator(monkeypatch):
141+
"""Patch web autograd internals to use the local minimal emulator."""
142+
143+
monkeypatch.setattr(web_ag, "_run_async_tidy3d", _emulated_run_async_tidy3d)
144+
yield
145+
146+
147+
def _build_base_sim(scale: float) -> td.Simulation:
148+
"""Shared base Simulation for modal and terminal modelers."""
149+
return td.Simulation(
150+
size=(4.0, 4.0, 4.0),
151+
run_time=1e-13,
152+
grid_spec=td.GridSpec.uniform(dl=0.2),
153+
boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True),
154+
structures=[
155+
td.Structure(
156+
geometry=td.Box(size=(1.0 + scale, 1.0, 1.0), center=(0.0, 0.0, 0.0)),
157+
medium=td.Medium(permittivity=2.0 + 0.1 * scale),
158+
)
159+
],
160+
sources=[],
161+
monitors=[],
162+
)
163+
164+
165+
def build_modal_modeler(scale: float) -> ModalComponentModeler:
166+
sim = _build_base_sim(scale)
167+
168+
# two modal ports on +/- z sides
169+
port_size = (2.0, 2.0, 0.0)
170+
p1 = ModalPort(
171+
center=(0.0, 0.0, -1.5),
172+
size=port_size,
173+
direction="+",
174+
mode_spec=td.ModeSpec(num_modes=1),
175+
name="p1",
176+
)
177+
p2 = ModalPort(
178+
center=(0.0, 0.0, 1.5),
179+
size=port_size,
180+
direction="-",
181+
mode_spec=td.ModeSpec(num_modes=1),
182+
name="p2",
183+
)
184+
185+
freqs = [2.0e14]
186+
return ModalComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs)
187+
188+
189+
def build_terminal_modeler(scale: float) -> TerminalComponentModeler:
190+
sim = _build_base_sim(scale)
191+
192+
# two lumped ports on +/- z sides; injection axis is z
193+
port_size = (1.0, 1.0, 0.0)
194+
p1 = RectLumpedPort(
195+
center=(0.0, 0.0, -1.5),
196+
size=port_size,
197+
voltage_axis=1,
198+
name="lp1",
199+
impedance=50.0,
200+
)
201+
p2 = RectLumpedPort(
202+
center=(0.0, 0.0, 1.5),
203+
size=port_size,
204+
voltage_axis=1,
205+
name="lp2",
206+
impedance=50.0,
207+
)
208+
209+
freqs = [2.0e14]
210+
return TerminalComponentModeler(simulation=sim, ports=(p1, p2), freqs=freqs)
211+
212+
213+
def test_component_modeler_autograd_tracing(patch_web_autograd_emulator, tmp_path):
214+
td.config.logging_level = "ERROR"
215+
td.config.log_suppression = True
216+
217+
def objective(scale: float) -> float:
218+
modeler = build_modal_modeler(scale)
219+
modeler_data = smatrix_run(
220+
modeler,
221+
path_dir=str(tmp_path),
222+
verbose=False,
223+
local_gradient=True,
224+
)
225+
s = modeler_data.smatrix() # ModalPortDataArray
226+
return anp.real(anp.sum(s.data))
227+
228+
# verify that gradients propagate without error
229+
g = ag.grad(objective)(1.0)
230+
assert np.isfinite(g)
231+
assert not np.isclose(g, 0.0)
232+
233+
234+
def test_component_modeler_autograd_tracing_modeler_run(patch_web_autograd_emulator, tmp_path):
235+
td.config.logging_level = "ERROR"
236+
td.config.log_suppression = True
237+
238+
def objective(scale: float) -> float:
239+
modeler = build_modal_modeler(scale)
240+
modeler_data = modeler.run(
241+
path_dir=str(tmp_path),
242+
verbose=False,
243+
local_gradient=True,
244+
)
245+
s = modeler_data.smatrix()
246+
return anp.real(anp.sum(s.data))
247+
248+
g = ag.grad(objective)(1.0)
249+
assert np.isfinite(g)
250+
assert not np.isclose(g, 0.0)
251+
252+
253+
def test_terminal_component_modeler_autograd_tracing_stubbed(
254+
patch_web_autograd_emulator, monkeypatch, tmp_path
255+
):
256+
"""Autograd plumbing test for TerminalComponentModeler using a minimal S-matrix stub.
257+
258+
This test verifies that web.run autograd integration (forward/adjoint batching and result
259+
composition) works for terminal modelers when terminal_construct_smatrix is replaced with a
260+
simple function of FieldData. The full terminal analysis relies on voltage/current integrals
261+
with interpolation and Yee-grid snapping, which are not autograd-compatible; hence the use of
262+
a stub to keep the test fast and robust.
263+
"""
264+
265+
td.config.logging_level = "ERROR"
266+
td.config.log_suppression = True
267+
268+
# Minimal stub: reduce first available FieldData per port over space (keep f), place on diagonal
269+
def _fake_terminal_construct_smatrix(
270+
modeler_data, assume_ideal_excitation=False, s_param_def="pseudo"
271+
):
272+
ports = list(modeler_data.modeler.network_dict.keys())
273+
freqs = list(modeler_data.modeler.freqs)
274+
f_len = len(freqs)
275+
n = len(ports)
276+
277+
diag_vals = []
278+
for p in ports:
279+
sim_data = modeler_data.data[p]
280+
vals_f = None
281+
for d in sim_data.data:
282+
if isinstance(d, td.FieldData):
283+
arr = next(iter(d.field_components.values()))
284+
val_comp = arr.sum(dim=[c for c in arr.dims if c != "f"]).astype(complex)
285+
if "f" not in val_comp.dims:
286+
val_comp = td.FreqDataArray(
287+
np.ones((f_len,), dtype=complex), coords={"f": freqs}
288+
)
289+
vals_f = val_comp
290+
break
291+
if vals_f is None:
292+
vals_f = td.FreqDataArray(np.ones((f_len,), dtype=complex), coords={"f": freqs})
293+
diag_vals.append(vals_f)
294+
295+
data = anp.zeros((f_len, n, n), dtype=complex)
296+
for i, s in enumerate(diag_vals):
297+
Ei = np.zeros((n, n))
298+
Ei[i, i] = 1.0
299+
data = data + anp.einsum("f,ij->fij", s.values, Ei)
300+
301+
return TerminalPortDataArray(data, coords={"f": freqs, "port_out": ports, "port_in": ports})
302+
303+
monkeypatch.setattr(
304+
terminal_analysis, "terminal_construct_smatrix", _fake_terminal_construct_smatrix
305+
)
306+
307+
def objective(scale: float) -> float:
308+
modeler = build_terminal_modeler(scale)
309+
modeler_data = smatrix_run(
310+
modeler,
311+
path_dir=str(tmp_path),
312+
verbose=False,
313+
local_gradient=True,
314+
)
315+
s_vals = modeler_data.smatrix().data.data
316+
return anp.real(anp.sum(s_vals))
317+
318+
g = ag.grad(objective)(1.0)
319+
assert np.isfinite(g)
320+
assert not np.isclose(g, 0.0)
321+
322+
323+
def test_terminal_component_modeler_autograd_tracing_modeler_run_stubbed(
324+
patch_web_autograd_emulator, monkeypatch, tmp_path
325+
):
326+
"""Same as terminal autograd test, but running via modeler.run()."""
327+
328+
td.config.logging_level = "ERROR"
329+
td.config.log_suppression = True
330+
331+
def _fake_terminal_construct_smatrix(
332+
modeler_data, assume_ideal_excitation=False, s_param_def="pseudo"
333+
):
334+
ports = list(modeler_data.modeler.network_dict.keys())
335+
freqs = list(modeler_data.modeler.freqs)
336+
f_len = len(freqs)
337+
n = len(ports)
338+
339+
diag_vals = []
340+
for p in ports:
341+
sim_data = modeler_data.data[p]
342+
vals_f = None
343+
for d in sim_data.data:
344+
if isinstance(d, td.FieldData):
345+
arr = next(iter(d.field_components.values()))
346+
val_comp = arr.sum(dim=[c for c in arr.dims if c != "f"]).astype(complex)
347+
if "f" not in val_comp.dims:
348+
val_comp = td.FreqDataArray(
349+
np.ones((f_len,), dtype=complex), coords={"f": freqs}
350+
)
351+
vals_f = val_comp
352+
break
353+
if vals_f is None:
354+
vals_f = td.FreqDataArray(np.ones((f_len,), dtype=complex), coords={"f": freqs})
355+
diag_vals.append(vals_f)
356+
357+
data = anp.zeros((f_len, n, n), dtype=complex)
358+
for i, s in enumerate(diag_vals):
359+
Ei = np.zeros((n, n))
360+
Ei[i, i] = 1.0
361+
data = data + anp.einsum("f,ij->fij", s.values, Ei)
362+
363+
return TerminalPortDataArray(data, coords={"f": freqs, "port_out": ports, "port_in": ports})
364+
365+
monkeypatch.setattr(
366+
terminal_analysis, "terminal_construct_smatrix", _fake_terminal_construct_smatrix
367+
)
368+
369+
def objective(scale: float) -> float:
370+
modeler = build_terminal_modeler(scale)
371+
modeler_data = modeler.run(
372+
path_dir=str(tmp_path),
373+
verbose=False,
374+
local_gradient=True,
375+
)
376+
s_vals = modeler_data.smatrix().data.data
377+
return anp.real(anp.sum(s_vals))
378+
379+
g = ag.grad(objective)(1.0)
380+
assert np.isfinite(g)
381+
assert not np.isclose(g, 0.0)

0 commit comments

Comments
 (0)