|
24 | 24 |
|
25 | 25 | from abc import ABC |
26 | 26 |
|
| 27 | +from typing import Union, Optional, List, Dict |
| 28 | + |
27 | 29 | import attr |
28 | 30 |
|
29 | 31 | import smdebug_rulesconfig as rule_configs |
30 | 32 |
|
31 | 33 | from sagemaker import image_uris |
32 | 34 | from sagemaker.utils import build_dict |
| 35 | +from sagemaker.workflow.entities import PipelineVariable |
33 | 36 |
|
34 | 37 | framework_name = "debugger" |
35 | 38 | DEBUGGER_FLAG = "USE_SMDEBUG" |
@@ -311,17 +314,17 @@ def sagemaker( |
311 | 314 | @classmethod |
312 | 315 | def custom( |
313 | 316 | cls, |
314 | | - name, |
315 | | - image_uri, |
316 | | - instance_type, |
317 | | - volume_size_in_gb, |
318 | | - source=None, |
319 | | - rule_to_invoke=None, |
320 | | - container_local_output_path=None, |
321 | | - s3_output_path=None, |
322 | | - other_trials_s3_input_paths=None, |
323 | | - rule_parameters=None, |
324 | | - collections_to_save=None, |
| 317 | + name: str, |
| 318 | + image_uri: Union[str, PipelineVariable], |
| 319 | + instance_type: Union[str, PipelineVariable], |
| 320 | + volume_size_in_gb: Union[int, PipelineVariable], |
| 321 | + source: Optional[str] = None, |
| 322 | + rule_to_invoke: Optional[Union[str, PipelineVariable]] = None, |
| 323 | + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, |
| 324 | + s3_output_path: Optional[Union[str, PipelineVariable]] = None, |
| 325 | + other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None, |
| 326 | + rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 327 | + collections_to_save: Optional[List["CollectionConfig"]] = None, |
325 | 328 | actions=None, |
326 | 329 | ): |
327 | 330 | """Initialize a ``Rule`` object for a *custom* debugging rule. |
@@ -610,10 +613,10 @@ class DebuggerHookConfig(object): |
610 | 613 |
|
611 | 614 | def __init__( |
612 | 615 | self, |
613 | | - s3_output_path=None, |
614 | | - container_local_output_path=None, |
615 | | - hook_parameters=None, |
616 | | - collection_configs=None, |
| 616 | + s3_output_path: Optional[Union[str, PipelineVariable]] = None, |
| 617 | + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, |
| 618 | + hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 619 | + collection_configs: Optional[List["CollectionConfig"]] = None, |
617 | 620 | ): |
618 | 621 | """Initialize the DebuggerHookConfig instance. |
619 | 622 |
|
@@ -679,7 +682,11 @@ def _to_request_dict(self): |
679 | 682 | class TensorBoardOutputConfig(object): |
680 | 683 | """Create a tensor ouput configuration object for debugging visualizations on TensorBoard.""" |
681 | 684 |
|
682 | | - def __init__(self, s3_output_path, container_local_output_path=None): |
| 685 | + def __init__( |
| 686 | + self, |
| 687 | + s3_output_path: Union[str, PipelineVariable], |
| 688 | + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, |
| 689 | + ): |
683 | 690 | """Initialize the TensorBoardOutputConfig instance. |
684 | 691 |
|
685 | 692 | Args: |
@@ -708,7 +715,11 @@ def _to_request_dict(self): |
708 | 715 | class CollectionConfig(object): |
709 | 716 | """Creates tensor collections for SageMaker Debugger.""" |
710 | 717 |
|
711 | | - def __init__(self, name, parameters=None): |
| 718 | + def __init__( |
| 719 | + self, |
| 720 | + name: Union[str, PipelineVariable], |
| 721 | + parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 722 | + ): |
712 | 723 | """Constructor for collection configuration. |
713 | 724 |
|
714 | 725 | Args: |
|
0 commit comments