|
13 | 13 | """Contains the SageMaker Experiment class.""" |
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
| 16 | +import time |
| 17 | + |
16 | 18 | from sagemaker.apiutils import _base_types |
| 19 | +from sagemaker.experiments.trial import _Trial |
| 20 | +from sagemaker.experiments.trial_component import _TrialComponent |
17 | 21 |
|
18 | 22 |
|
19 | 23 | class _Experiment(_base_types.Record): |
@@ -44,6 +48,8 @@ class _Experiment(_base_types.Record): |
44 | 48 | _boto_update_members = ["experiment_name", "description", "display_name"] |
45 | 49 | _boto_delete_members = ["experiment_name"] |
46 | 50 |
|
| 51 | + _MAX_DELETE_ALL_ATTEMPTS = 3 |
| 52 | + |
47 | 53 | def save(self): |
48 | 54 | """Save the state of this Experiment to SageMaker. |
49 | 55 |
|
@@ -160,3 +166,72 @@ def _load_or_create( |
160 | 166 | sagemaker_session=sagemaker_session, |
161 | 167 | ) |
162 | 168 | return experiment |
| 169 | + |
| 170 | + def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): |
| 171 | + """List trials in this experiment matching the specified criteria. |
| 172 | +
|
| 173 | + Args: |
| 174 | + created_before (datetime.datetime): Return trials created before this instant |
| 175 | + (default: None). |
| 176 | + created_after (datetime.datetime): Return trials created after this instant |
| 177 | + (default: None). |
| 178 | + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' |
| 179 | + (default: None). |
| 180 | + sort_order (str): One of 'Ascending', or 'Descending' (default: None). |
| 181 | +
|
| 182 | + Returns: |
| 183 | + collections.Iterator[experiments._api_types.TrialSummary] : |
| 184 | + An iterator over trials matching the criteria. |
| 185 | + """ |
| 186 | + return _Trial.list( |
| 187 | + experiment_name=self.experiment_name, |
| 188 | + created_before=created_before, |
| 189 | + created_after=created_after, |
| 190 | + sort_by=sort_by, |
| 191 | + sort_order=sort_order, |
| 192 | + sagemaker_session=self.sagemaker_session, |
| 193 | + ) |
| 194 | + |
| 195 | + def delete_all(self, action): |
| 196 | + """Force to delete the experiment and associated trials, trial components. |
| 197 | +
|
| 198 | + Args: |
| 199 | + action (str): The string '--force' is required to pass in to confirm recursively |
| 200 | + delete the experiments, and all its trials and trial components. |
| 201 | + """ |
| 202 | + if action != "--force": |
| 203 | + raise ValueError( |
| 204 | + "Must confirm with string '--force' in order to delete the experiment and " |
| 205 | + "associated trials, trial components." |
| 206 | + ) |
| 207 | + |
| 208 | + delete_attempt_count = 0 |
| 209 | + last_exception = None |
| 210 | + while True: |
| 211 | + if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS: |
| 212 | + raise Exception("Failed to delete, please try again.") from last_exception |
| 213 | + try: |
| 214 | + for trial_summary in self.list_trials(): |
| 215 | + trial = _Trial.load( |
| 216 | + sagemaker_session=self.sagemaker_session, |
| 217 | + trial_name=trial_summary.trial_name, |
| 218 | + ) |
| 219 | + for ( |
| 220 | + trial_component_summary |
| 221 | + ) in trial.list_trial_components(): # pylint: disable=no-member |
| 222 | + tc = _TrialComponent.load( |
| 223 | + sagemaker_session=self.sagemaker_session, |
| 224 | + trial_component_name=trial_component_summary.trial_component_name, |
| 225 | + ) |
| 226 | + tc.delete(force_disassociate=True) |
| 227 | + # to prevent throttling |
| 228 | + time.sleep(1.2) |
| 229 | + trial.delete() # pylint: disable=no-member |
| 230 | + # to prevent throttling |
| 231 | + time.sleep(1.2) |
| 232 | + self.delete() |
| 233 | + break |
| 234 | + except Exception as ex: # pylint: disable=broad-except |
| 235 | + last_exception = ex |
| 236 | + finally: |
| 237 | + delete_attempt_count = delete_attempt_count + 1 |
0 commit comments