Skip to content

Commit 145dd95

Browse files
QingengWeiyaugenst-flex
authored andcommitted
feat(webapi): add priority to Bach, Job, and run_async
1 parent b2b9bef commit 145dd95

File tree

6 files changed

+80
-24
lines changed

6 files changed

+80
-24
lines changed

tests/test_web/test_webapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def track_to_file(self, fname):
673673
return original_to_file(self, fname)
674674

675675
# mock start to interrupt run() after upload and to_file
676-
def mock_start_interrupt(self):
676+
def mock_start_interrupt(self, *args, **kwargs):
677677
# at this point, upload() and to_file() should have been called
678678
assert batch_file_saved["saved"], "Batch file should be saved before start()"
679679
assert batch_file_saved["has_task_ids"], "Batch file should have task_ids"

tidy3d/web/api/asynchronous.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def run_async(
2323
parent_tasks: Optional[dict[str, list[str]]] = None,
2424
reduce_simulation: Literal["auto", True, False] = "auto",
2525
pay_type: Union[PayType, str] = PayType.AUTO,
26+
priority: Optional[int] = None,
2627
) -> BatchData:
2728
"""Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server,
2829
starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object.
@@ -52,7 +53,9 @@ def run_async(
5253
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
5354
pay_type: Union[PayType, str] = PayType.AUTO
5455
Specify the payment method.
55-
56+
priority: int = None
57+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
58+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
5659
Returns
5760
------
5861
:class:`BatchData`
@@ -90,5 +93,5 @@ def run_async(
9093
pay_type=pay_type,
9194
)
9295

93-
batch_data = batch.run(path_dir=path_dir)
96+
batch_data = batch.run(path_dir=path_dir, priority=priority)
9497
return batch_data

tidy3d/web/api/autograd/autograd.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def run(
218218
local_gradient=local_gradient,
219219
max_num_adjoint_per_fwd=max_num_adjoint_per_fwd,
220220
pay_type=pay_type,
221+
priority=priority,
221222
)
222223

223224
return run_webapi(
@@ -253,6 +254,7 @@ def run_async(
253254
max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD,
254255
reduce_simulation: typing.Literal["auto", True, False] = "auto",
255256
pay_type: typing.Union[PayType, str] = PayType.AUTO,
257+
priority: typing.Optional[int] = None,
256258
) -> BatchData:
257259
"""Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server,
258260
starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object.
@@ -303,6 +305,10 @@ def run_async(
303305
:class:`Batch`
304306
Interface for submitting several :class:`Simulation` objects to sever.
305307
"""
308+
# validate priority if specified
309+
if priority is not None and (priority < 1 or priority > 10):
310+
raise ValueError("Priority must be between '1' and '10' if specified.")
311+
306312
if is_valid_for_autograd_async(simulations):
307313
return _run_async(
308314
simulations=simulations,
@@ -317,6 +323,7 @@ def run_async(
317323
local_gradient=local_gradient,
318324
max_num_adjoint_per_fwd=max_num_adjoint_per_fwd,
319325
pay_type=pay_type,
326+
priority=priority,
320327
)
321328

322329
return run_async_webapi(
@@ -331,6 +338,7 @@ def run_async(
331338
parent_tasks=parent_tasks,
332339
reduce_simulation=reduce_simulation,
333340
pay_type=pay_type,
341+
priority=priority,
334342
)
335343

336344

@@ -1272,10 +1280,11 @@ def _run_tidy3d(
12721280
verbose = run_kwargs.get("verbose", False)
12731281
upload_sim_fields_keys(run_kwargs["sim_fields_keys"], task_id=job.task_id, verbose=verbose)
12741282
path = run_kwargs.get("path", DEFAULT_DATA_PATH)
1283+
priority = run_kwargs.get("priority")
12751284
if task_name.endswith("_adjoint"):
12761285
path_parts = basename(path).split(".")
12771286
path = join(dirname(path), path_parts[0] + "_adjoint." + ".".join(path_parts[1:]))
1278-
data = job.run(path)
1287+
data = job.run(path, priority=priority)
12791288
return data, job.task_id
12801289

12811290

@@ -1286,6 +1295,7 @@ def _run_async_tidy3d(
12861295

12871296
batch_init_kwargs = parse_run_kwargs(**run_kwargs)
12881297
path_dir = run_kwargs.pop("path_dir", None)
1298+
priority = run_kwargs.get("priority")
12891299
batch = Batch(simulations=simulations, **batch_init_kwargs)
12901300
td.log.info(f"running {batch.simulation_type} batch with '_run_async_tidy3d()'")
12911301

@@ -1305,9 +1315,9 @@ def _run_async_tidy3d(
13051315
upload_sim_fields_keys(sim_fields_keys, task_id=task_id, verbose=verbose)
13061316

13071317
if path_dir:
1308-
batch_data = batch.run(path_dir)
1318+
batch_data = batch.run(path_dir, priority=priority)
13091319
else:
1310-
batch_data = batch.run()
1320+
batch_data = batch.run(priority=priority)
13111321

13121322
task_ids = {key: job.task_id for key, job in batch.jobs.items()}
13131323
return batch_data, task_ids
@@ -1324,7 +1334,8 @@ def _run_async_tidy3d_bwd(
13241334
batch = Batch(simulations=simulations, **batch_init_kwargs)
13251335
td.log.info(f"running {batch.simulation_type} batch with '_run_async_tidy3d_bwd()'")
13261336

1327-
batch.start()
1337+
priority = run_kwargs.get("priority")
1338+
batch.start(priority=priority)
13281339
batch.monitor()
13291340

13301341
vjp_traced_fields_dict = {}

tidy3d/web/api/container.py

Lines changed: 52 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -227,21 +227,28 @@ def to_file(self, fname: str) -> None:
227227
self = self.updated_copy(task_id_cached=task_id_cached)
228228
super(Job, self).to_file(fname=fname) # noqa: UP008
229229

230-
def run(self, path: str = DEFAULT_DATA_PATH) -> WorkflowDataType:
230+
def run(
231+
self, path: str = DEFAULT_DATA_PATH, priority: Optional[int] = None
232+
) -> WorkflowDataType:
231233
"""Run :class:`Job` all the way through and return data.
232234
233235
Parameters
234236
----------
235-
path_dir : str = "./simulation_data.hdf5"
236-
Base directory where data will be downloaded, by default current working directory.
237-
237+
path : str = "./simulation_data.hdf5"
238+
Path to download results file (.hdf5), including filename.
239+
priority: int = None
240+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
241+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
238242
Returns
239243
-------
240-
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
244+
:class:`WorkflowDataType`
241245
Object containing simulation results.
242246
"""
243247
self.upload()
244-
self.start()
248+
if priority is None:
249+
self.start()
250+
else:
251+
self.start(priority=priority)
245252
self.monitor()
246253
return self.load(path=path)
247254

@@ -280,14 +287,25 @@ def status(self):
280287
"""Return current status of :class:`Job`."""
281288
return self.get_info().status
282289

283-
def start(self) -> None:
290+
def start(self, priority: Optional[int] = None) -> None:
284291
"""Start running a :class:`Job`.
285292
293+
Parameters
294+
----------
295+
296+
priority: int = None
297+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
298+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
286299
Note
287300
----
288301
To monitor progress of the :class:`Job`, call :meth:`Job.monitor` after started.
289302
"""
290-
web.start(self.task_id, solver_version=self.solver_version, pay_type=self.pay_type)
303+
web.start(
304+
self.task_id,
305+
solver_version=self.solver_version,
306+
pay_type=self.pay_type,
307+
priority=priority,
308+
)
291309

292310
def get_run_info(self) -> RunInfo:
293311
"""Return information about the running :class:`Job`.
@@ -581,14 +599,20 @@ class Batch(WebContainer):
581599

582600
_job_type = Job
583601

584-
def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
602+
def run(
603+
self,
604+
path_dir: str = DEFAULT_DATA_DIR,
605+
priority: Optional[int] = None,
606+
) -> BatchData:
585607
"""Upload and run each simulation in :class:`Batch`.
586608
587609
Parameters
588610
----------
589611
path_dir : str
590612
Base directory where data will be downloaded, by default current working directory.
591-
613+
priority: int = None
614+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
615+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
592616
Returns
593617
------
594618
:class:`BatchData`
@@ -612,7 +636,10 @@ def run(self, path_dir: str = DEFAULT_DATA_DIR) -> BatchData:
612636
self._check_path_dir(path_dir)
613637
self.upload()
614638
self.to_file(self._batch_path(path_dir=path_dir))
615-
self.start()
639+
if priority is None:
640+
self.start()
641+
else:
642+
self.start(priority=priority)
616643
self.monitor()
617644
return self.load(path_dir=path_dir)
618645

@@ -715,9 +742,18 @@ def get_info(self) -> dict[TaskName, TaskInfo]:
715742
info_dict[task_name] = task_info
716743
return info_dict
717744

718-
def start(self) -> None:
745+
def start(
746+
self,
747+
priority: Optional[int] = None,
748+
) -> None:
719749
"""Start running all tasks in the :class:`Batch`.
720750
751+
Parameters
752+
----------
753+
754+
priority: int = None
755+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
756+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
721757
Note
722758
----
723759
To monitor the running simulations, can call :meth:`Batch.monitor`.
@@ -728,7 +764,10 @@ def start(self) -> None:
728764

729765
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
730766
for _, job in self.jobs.items():
731-
executor.submit(job.start)
767+
if priority is None:
768+
executor.submit(job.start)
769+
else:
770+
executor.submit(job.start, priority=priority)
732771

733772
def get_run_info(self) -> dict[TaskName, RunInfo]:
734773
"""get information about a each of the tasks in the :class:`Batch`.

tidy3d/web/api/webapi.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ def run(
137137
reduce_simulation : Literal["auto", True, False] = "auto"
138138
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
139139
pay_type: Union[PayType, str] = PayType.AUTO
140-
Which method to pay the simulation.
140+
Which method to pay the simulation.
141141
priority: int = None
142-
Task priority for vGPU queue (1=lowest, 10=highest).
142+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
143+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
143144
Returns
144145
-------
145146
Union[:class:`.SimulationData`, :class:`.HeatSimulationData`, :class:`.EMESimulationData`]
@@ -450,7 +451,8 @@ def start(
450451
pay_type: Union[PayType, str] = PayType.AUTO
451452
Which method to pay the simulation
452453
priority: int = None
453-
Task priority for vGPU queue (1=lowest, 10=highest).
454+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
455+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
454456
Note
455457
----
456458
To monitor progress, can call :meth:`monitor` after starting simulation.

tidy3d/web/core/task_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,8 @@ def submit(
477477
pay_type: Union[PayType, str] = PayType.AUTO
478478
Which method to pay the simulation.
479479
priority: int = None
480-
Task priority for vGPU queue (1=lowest, 10=highest).
480+
Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
481+
It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
481482
"""
482483
pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
483484

0 commit comments

Comments
 (0)