|
23 | 23 | import time |
24 | 24 |
|
25 | 25 | from abc import ABC |
| 26 | +from typing import Union, Optional |
26 | 27 |
|
27 | 28 | import attr |
28 | 29 |
|
29 | 30 | import smdebug_rulesconfig as rule_configs |
30 | 31 |
|
31 | 32 | from sagemaker import image_uris |
32 | 33 | from sagemaker.utils import build_dict |
| 34 | +from sagemaker.workflow.entities import PipelineVariable |
33 | 35 |
|
34 | 36 | framework_name = "debugger" |
35 | 37 | DEBUGGER_FLAG = "USE_SMDEBUG" |
@@ -311,10 +313,10 @@ def sagemaker( |
311 | 313 | @classmethod |
312 | 314 | def custom( |
313 | 315 | cls, |
314 | | - name, |
315 | | - image_uri, |
316 | | - instance_type, |
317 | | - volume_size_in_gb, |
| 316 | + name: str, |
| 317 | + image_uri: Union[str, PipelineVariable], |
| 318 | + instance_type: Union[str, PipelineVariable], |
| 319 | + volume_size_in_gb: Union[int, PipelineVariable], |
318 | 320 | source=None, |
319 | 321 | rule_to_invoke=None, |
320 | 322 | container_local_output_path=None, |
@@ -610,7 +612,7 @@ class DebuggerHookConfig(object): |
610 | 612 |
|
611 | 613 | def __init__( |
612 | 614 | self, |
613 | | - s3_output_path=None, |
| 615 | + s3_output_path: Optional[Union[str, PipelineVariable]] = None, |
614 | 616 | container_local_output_path=None, |
615 | 617 | hook_parameters=None, |
616 | 618 | collection_configs=None, |
@@ -679,7 +681,9 @@ def _to_request_dict(self): |
679 | 681 | class TensorBoardOutputConfig(object): |
680 | 682 | """Create a tensor ouput configuration object for debugging visualizations on TensorBoard.""" |
681 | 683 |
|
682 | | - def __init__(self, s3_output_path, container_local_output_path=None): |
| 684 | + def __init__( |
| 685 | + self, s3_output_path: Union[str, PipelineVariable], container_local_output_path=None |
| 686 | + ): |
683 | 687 | """Initialize the TensorBoardOutputConfig instance. |
684 | 688 |
|
685 | 689 | Args: |
|
0 commit comments