@@ -17,7 +17,7 @@
 from collections import deque
 import time
 import threading
-from typing import Callable, Dict, List, Optional, Tuple, Any
+from typing import Callable, Dict, List, Optional, Tuple, Any, Union
 import functools
 import itertools
 import inspect
@@ -39,7 +39,7 @@
 from sagemaker.remote_function import logging_config
 from sagemaker.utils import name_from_base, base_from_name
 from sagemaker.remote_function.spark_config import SparkConfig
-from sagemaker.remote_function.workdir_config import WorkdirConfig
+from sagemaker.remote_function.custom_file_filter import CustomFileFilter

 _API_CALL_LIMIT = {
     "SubmittingIntervalInSecs": 1,
@@ -66,7 +66,7 @@ def remote(
     environment_variables: Dict[str, str] = None,
     image_uri: str = None,
     include_local_workdir: bool = False,
-    workdir_config: WorkdirConfig = None,
+    custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
     instance_count: int = 1,
     instance_type: str = None,
     job_conda_env: str = None,
@@ -87,7 +87,6 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
-    custom_file_filter: Optional[Callable[[str, List], List]] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -195,10 +194,12 @@ def remote(
             methods that are not available via PyPI or conda. Only python files are included.
             Default value is ``False``.

-        workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
-            local directories and files to be included in the remote function.
-            workdir_config takes precedence over include_local_workdir.
-            Default value is ``None``.
+        custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
+            that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
+            that specifies the local directories and files to be included in the remote function.
+            If a callable is passed in, that function is passed to the ``ignore`` argument of
+            ``shutil.copytree``. Defaults to ``None``, which means only python
+            files are accepted and uploaded to S3.

         instance_count (int): The number of instances to use. Defaults to 1.
             NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -274,11 +275,6 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
-
-        custom_file_filter (Callable[[str, List], List]): A function that filters job
-            dependencies to be uploaded to S3. This function is passed to the ``ignore``
-            argument of ``shutil.copytree``. Defaults to ``None``, which means only python
-            files are accepted.
     """

     def _remote(func):
@@ -290,7 +286,7 @@ def _remote(func):
             environment_variables=environment_variables,
             image_uri=image_uri,
             include_local_workdir=include_local_workdir,
-            workdir_config=workdir_config,
+            custom_file_filter=custom_file_filter,
             instance_count=instance_count,
             instance_type=instance_type,
             job_conda_env=job_conda_env,
@@ -311,7 +307,6 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
-            custom_file_filter=custom_file_filter,
         )

         @functools.wraps(func)
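A quick usage sketch of the callable form of the consolidated parameter, assuming only what the new docstring states: the filter follows the `shutil.copytree` ``ignore`` contract, receiving a directory path and its entry names and returning the names to exclude. The filter body, instance type, and `train` function below are illustrative, not part of this change.

```python
from sagemaker.remote_function import remote

# Hypothetical ignore-style filter: given a directory and the names of its
# entries, return the subset of names to EXCLUDE from the workdir upload.
# This is the signature shutil.copytree expects for its ``ignore`` argument.
def skip_logs_and_tests(directory: str, contents: list) -> list:
    return [name for name in contents if name.endswith(".log") or name == "tests"]

@remote(
    include_local_workdir=True,
    custom_file_filter=skip_logs_and_tests,
    instance_type="ml.m5.xlarge",  # example instance type
)
def train(x: int) -> int:
    return x * 2
```

Note the contract is inverted relative to an allow-list: the callable names what to drop, and anything it does not return is uploaded with the job.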
@@ -501,7 +496,7 @@ def __init__(
         environment_variables: Dict[str, str] = None,
         image_uri: str = None,
         include_local_workdir: bool = False,
-        workdir_config: WorkdirConfig = None,
+        custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
         instance_count: int = 1,
         instance_type: str = None,
         job_conda_env: str = None,
@@ -523,7 +518,6 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        custom_file_filter: Optional[Callable[[str, List], List]] = None,
     ):
         """Constructor for RemoteExecutor

@@ -628,10 +622,12 @@ def __init__(
                 local directories. Set to ``True`` if the remote function code imports local modules
                 and methods that are not available via PyPI or conda. Default value is ``False``.

-            workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
-                local directories and files to be included in the remote function.
-                workdir_config takes precedence over include_local_workdir.
-                Default value is ``None``.
+            custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
+                that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
+                that specifies the local directories and files to be included in the remote function.
+                If a callable is passed in, that function is passed to the ``ignore`` argument of
+                ``shutil.copytree``. Defaults to ``None``, which means only python
+                files are accepted and uploaded to S3.

             instance_count (int): The number of instances to use. Defaults to 1.
                 NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -715,11 +711,6 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot training
                 job to complete. Defaults to ``None``.
-
-            custom_file_filter (Callable[[str, List], List]): A function that filters job
-                dependencies to be uploaded to S3. This function is passed to the ``ignore``
-                argument of ``shutil.copytree``. Defaults to ``None``, which means only python
-                files are accepted.
         """
         self.max_parallel_jobs = max_parallel_jobs

@@ -739,7 +730,7 @@ def __init__(
             environment_variables=environment_variables,
             image_uri=image_uri,
             include_local_workdir=include_local_workdir,
-            workdir_config=workdir_config,
+            custom_file_filter=custom_file_filter,
             instance_count=instance_count,
             instance_type=instance_type,
             job_conda_env=job_conda_env,
@@ -760,7 +751,6 @@ def __init__(
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
-            custom_file_filter=custom_file_filter,
         )

         self._state_condition = threading.Condition()
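And a sketch of the ``CustomFileFilter`` form with ``RemoteExecutor``. The import path and the `custom_file_filter=` keyword come from this diff; the ``ignore_name_patterns`` constructor argument is an assumption about the new class, which this diff does not show.

```python
from sagemaker.remote_function import RemoteExecutor
from sagemaker.remote_function.custom_file_filter import CustomFileFilter

def train(x: int) -> int:
    return x * 2

# Assumed constructor argument: glob-style name patterns to drop from the
# uploaded working directory (not confirmed by this diff).
file_filter = CustomFileFilter(ignore_name_patterns=["*.ipynb", "data"])

with RemoteExecutor(
    max_parallel_jobs=1,
    instance_type="ml.m5.xlarge",  # example instance type
    include_local_workdir=True,
    custom_file_filter=file_filter,
) as executor:
    future = executor.submit(train, 21)
    print(future.result())  # 42, computed in the SageMaker training job
```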