Skip to content

Commit 3ef5385

Browse files
Gregory Robertsyaugenst-flex
authored andcommitted
fix(adjoint): disallow autograd with element_mappings for a ComponentModeler
1 parent 7b25de0 commit 3ef5385

File tree

4 files changed

+50
-12
lines changed

4 files changed

+50
-12
lines changed

tests/test_plugins/smatrix/test_component_modeler_autograd.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,35 @@ def objective(scale: float) -> float:
232232
assert not np.isclose(g, 0.0)
233233

234234

235+
def test_component_modeler_autograd_error_with_element_mappings(
236+
patch_web_autograd_emulator, tmp_path
237+
):
238+
"""Verify that we get the expected error when running a component modeler with `element_mappings` in autograd."""
239+
td.config.logging_level = "ERROR"
240+
td.config.log_suppression = True
241+
242+
def objective(scale: float) -> float:
243+
modeler = build_modal_modeler(scale)
244+
# set up symmetry (port names only; directions are handled internally)
245+
left = ("p1", 0)
246+
right = ("p2", 0)
247+
248+
element_mappings = [
249+
((left, right), (right, left), 1.0),
250+
]
251+
252+
modeler_with_mappings = modeler.updated_copy(element_mappings=element_mappings)
253+
s = modeler_with_mappings.run(
254+
path_dir=str(tmp_path),
255+
verbose=False,
256+
local_gradient=True,
257+
)
258+
return anp.real(anp.sum(s.data))
259+
260+
g = ag.grad(objective)(1.0)
261+
assert np.isfinite(g)
262+
263+
235264
def test_component_modeler_autograd_tracing_modeler_run(patch_web_autograd_emulator, tmp_path):
236265
td.config.logging_level = "ERROR"
237266
td.config.log_suppression = True

tidy3d/plugins/smatrix/component_modelers/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ def run(
227227
):
228228
log.warning(
229229
"'ComponentModeler.run()' is deprecated and will be removed in a future release. "
230-
"Use 'web.run(modeler)' instead.",
230+
"Use web.run(modeler) instead. 'web.run' returns a 'ComponentModelerData' object; "
231+
"get the scattering matrix via 'data.smatrix()'.",
231232
log_once=True,
232233
)
233234
from tidy3d.plugins.smatrix.run import _run_local

tidy3d/plugins/smatrix/run.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from tidy3d.components.base import Tidy3dBaseModel
66
from tidy3d.components.data.index import SimulationDataMap
7+
from tidy3d.log import log
78
from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler
89
from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
910
from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType
@@ -179,15 +180,24 @@ def _run_local(
179180
# autograd path if any sim is valid for autograd
180181
from tidy3d.web.api.autograd import autograd as web_ag
181182

182-
sims = getattr(modeler, "sim_dict", None) or {}
183+
sims = modeler.sim_dict
183184
if any(web_ag.is_valid_for_autograd(sim) for sim in sims.values()):
185+
if len(modeler.element_mappings) > 0:
186+
log.warning(
187+
"Element mappings are used to populate S-matrix values, but autograd gradients "
188+
"are computed only for simulated elements. Gradients for mapped elements are not "
189+
"included. For optimization with autograd, prefer enforcing symmetry in geometry/"
190+
"objective functions and use 'run_only' to select unique sources.",
191+
log_once=True,
192+
)
193+
184194
from tidy3d.web.api.autograd.autograd import _run_async
185195

186196
kwargs.setdefault("folder_name", "default")
187197
kwargs.setdefault("simulation_type", "tidy3d_autograd_async")
188198
kwargs.setdefault("path_dir", path_dir)
189199

190-
sim_data_map = _run_async(simulations=modeler.sim_dict, **kwargs)
200+
sim_data_map = _run_async(simulations=sims, **kwargs)
191201

192202
return compose_modeler_data_from_batch_data(modeler=modeler, batch_data=sim_data_map)
193203

tidy3d/web/api/autograd/autograd.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from tidy3d.components.autograd.derivative_utils import DerivativeInfo
2424
from tidy3d.components.data.data_array import DataArray
2525
from tidy3d.components.grid.grid_spec import GridSpec
26+
from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType
2627
from tidy3d.exceptions import AdjointError
2728
from tidy3d.web.api.asynchronous import DEFAULT_DATA_DIR
2829
from tidy3d.web.api.asynchronous import run_async as run_async_webapi
@@ -94,7 +95,7 @@ def is_valid_for_autograd_async(simulations: dict[str, td.Simulation]) -> bool:
9495

9596

9697
def run(
97-
simulation: typing.Any,
98+
simulation: WorkflowType,
9899
task_name: str,
99100
folder_name: str = "default",
100101
path: str = "simulation_data.hdf5",
@@ -111,14 +112,14 @@ def run(
111112
reduce_simulation: typing.Literal["auto", True, False] = "auto",
112113
pay_type: typing.Union[PayType, str] = PayType.AUTO,
113114
priority: typing.Optional[int] = None,
114-
) -> td.SimulationData:
115+
) -> WorkflowDataType:
115116
"""
116117
Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads,
117118
and loads results as a :class:`.WorkflowDataType` object.
118119
119120
Parameters
120121
----------
121-
simulation : Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]
122+
simulation : Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`, :class:`.ModalComponentModeler`, :class:`.TerminalComponentModeler`]
122123
Simulation to upload to server.
123124
task_name : str
124125
Name of task.
@@ -154,8 +155,8 @@ def run(
154155
Task priority for vGPU queue (1=lowest, 10=highest).
155156
Returns
156157
-------
157-
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
158-
Object containing solver results for the supplied simulation.
158+
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`, :class:`.ModalComponentModelerData`, :class:`.TerminalComponentModelerData`]
159+
Object containing solver results for the supplied input.
159160
160161
Notes
161162
-----
@@ -203,10 +204,7 @@ def run(
203204
from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType
204205

205206
if isinstance(simulation, typing.get_args(ComponentModelerType)):
206-
sims = getattr(simulation, "sim_dict", None)
207-
if local_gradient or (
208-
isinstance(sims, dict) and any(is_valid_for_autograd(s) for s in sims.values())
209-
):
207+
if any(is_valid_for_autograd(s) for s in simulation.sim_dict.values()):
210208
from tidy3d.plugins.smatrix import run as smatrix_run
211209

212210
path_dir = dirname(path) or "."

0 commit comments

Comments
 (0)