Description
Opening this issue on behalf of a SageMaker customer. The customer has pre-compressed and uploaded their source_dir to S3, and wants to set requirements_file to a relative path contained in the source.
The generic Frameworks Estimator allows source_dir to be an S3 location. In this case it skips validation/upload.
Skipping source_dir validation:
sagemaker-python-sdk/src/sagemaker/estimator.py
Lines 830 to 833 in 8b33a30
    # validate source dir will raise a ValueError if there is something wrong with the
    # source directory. We are intentionally not handling it because this is a critical error.
    if self.source_dir and not self.source_dir.lower().startswith('s3://'):
        validate_source_dir(self.entry_point, self.source_dir)
Skipping source_dir upload:
sagemaker-python-sdk/src/sagemaker/fw_utils.py
Lines 143 to 144 in 8b33a30
    If directory is an S3 URI, an UploadedCode object will be returned, but nothing will be
    uploaded to S3 (this allow reuse of code already in S3).
However, the TensorFlow Estimator validates requirements_file with a check that fails unless source_dir/requirements_file is a valid path on the local filesystem, which can never be true when source_dir is an S3 URI:
sagemaker-python-sdk/src/sagemaker/tensorflow/estimator.py
Lines 273 to 285 in 8b33a30
    def _validate_requirements_file(self, requirements_file):
        if not requirements_file:
            return
        if not self.source_dir:
            raise ValueError('Must specify source_dir along with a requirements file.')
        if os.path.isabs(requirements_file):
            raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
                requirements_file))
        if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
            raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
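To illustrate why the last check rejects an S3 source_dir, here is a minimal standalone sketch that mirrors only the final `os.path.exists` call. The helper name and the bucket/key are hypothetical and used for illustration; this is not SDK code.

```python
import os


def requirements_file_exists_locally(source_dir, requirements_file):
    """Mirror the final check in _validate_requirements_file (illustrative helper)."""
    return os.path.exists(os.path.join(source_dir, requirements_file))


# Hypothetical S3 location of a pre-compressed source archive.
s3_source_dir = 's3://my-bucket/code/sourcedir.tar.gz'

# os.path.join treats the URI as an ordinary path prefix, so the result is
# looked up on the local filesystem, where it does not exist.
joined = os.path.join(s3_source_dir, 'requirements.txt')
found = requirements_file_exists_locally(s3_source_dir, 'requirements.txt')
print(joined, found)
```

Because the joined string is resolved locally, the check reports the file as missing regardless of what the archive in S3 actually contains.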
Seems like it would be easy to skip the local path validation if source_dir is an S3 location. Something like this could be added to _validate_requirements_file, before the os.path.exists check:
    if self.source_dir.lower().startswith('s3://'):
        return
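For context, the proposed change could look roughly like the following when written as a standalone function. This is only a sketch under assumptions: the function takes source_dir as a parameter instead of reading `self.source_dir`, and placing the S3 check after the absolute-path check is my choice, so that an absolute requirements_file is still rejected for S3 sources.

```python
import os


def validate_requirements_file(source_dir, requirements_file):
    """Standalone sketch of _validate_requirements_file with an S3 skip (not SDK code)."""
    if not requirements_file:
        return
    if not source_dir:
        raise ValueError('Must specify source_dir along with a requirements file.')
    if os.path.isabs(requirements_file):
        raise ValueError('Requirements file {} is not a path relative to source_dir.'.format(
            requirements_file))
    # Proposed addition: code already in S3 cannot be inspected locally,
    # so skip the existence check, consistent with how source_dir itself is handled.
    if source_dir.lower().startswith('s3://'):
        return
    if not os.path.exists(os.path.join(source_dir, requirements_file)):
        raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
```

With this ordering, a call such as `validate_requirements_file('s3://my-bucket/code/sourcedir.tar.gz', 'requirements.txt')` (hypothetical bucket and key) returns without raising, while local source directories are validated exactly as before.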
I know that support for requirements.txt files is limited between the "legacy" TensorFlow container and the newer "script mode" version. But this is a small bug in the SDK which could easily be fixed.