diff --git a/.pylintrc b/.pylintrc index 442b7307bd..0a980d43b3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -88,10 +88,8 @@ disable= protected-access, # TODO: Fix access abstract-method, # TODO: Fix abstract methods wrong-import-order, # TODO: Fix import order - no-else-return, # TODO: Remove unnecessary elses useless-object-inheritance, # TODO: Remove unnecessary imports cyclic-import, # TODO: Resolve cyclic imports - no-else-raise, # TODO: Remove unnecessary elses no-self-use, # TODO: Convert methods to functions where appropriate inconsistent-return-statements, # TODO: Make returns consistent consider-merging-isinstance, # TODO: Merge isinstance where appropriate diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index 5a402a9c25..6bd5047a98 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -204,8 +204,8 @@ def read_recordio(f): def _resolve_type(dtype): if dtype == np.dtype(int): return "Int32" - elif dtype == np.dtype(float): + if dtype == np.dtype(float): return "Float64" - elif dtype == np.dtype("float32"): + if dtype == np.dtype("float32"): return "Float32" raise ValueError("Unsupported dtype {} on array".format(dtype)) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e558822dd7..287d203912 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -643,8 +643,7 @@ def get_vpc_config(self, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT): """ if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT: return vpc_utils.to_dict(self.subnets, self.security_group_ids) - else: - return vpc_utils.sanitize(vpc_config_override) + return vpc_utils.sanitize(vpc_config_override) def _ensure_latest_training_job( self, error_message="Estimator is not associated with a training job" @@ -1235,14 +1234,13 @@ def train_image(self): """ if self.image_name: return self.image_name - else: - return create_image_uri( - self.sagemaker_session.boto_region_name, - self.__framework_name__, - self.train_instance_type, - self.framework_version, # pylint: disable=no-member - py_version=self.py_version, # pylint: disable=no-member - ) + return create_image_uri( + self.sagemaker_session.boto_region_name, + self.__framework_name__, + self.train_instance_type, + self.framework_version, # pylint: disable=no-member + py_version=self.py_version, # pylint: disable=no-member + ) @classmethod def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="model"): @@ -1404,13 +1402,10 @@ def _s3_uri_without_prefix_from_input(input_data): for channel_name, channel_s3_uri in input_data.items(): response.update(_s3_uri_prefix(channel_name, channel_s3_uri)) return response - elif isinstance(input_data, str): + if isinstance(input_data, str): return _s3_uri_prefix("training", input_data) - elif isinstance(input_data, s3_input): + if isinstance(input_data, s3_input): return _s3_uri_prefix("training", input_data) - else: - raise ValueError( - "Unrecognized type for S3 input data config - not str or s3_input: {}".format( - input_data - ) - ) + raise ValueError( + "Unrecognized type for S3 input data config - not str or s3_input: {}".format(input_data) + ) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 9834d4eb65..3b90f9c006 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -87,8 +87,7 @@ def _is_merged_versions(framework, framework_version): lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework) if lowest_version_list: return is_version_equal_or_higher(lowest_version_list, framework_version) - else: - return False + return False def _using_merged_images(region, framework, py_version, accelerator_type, framework_version): @@ -101,8 +100,7 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew def _registry_id(region, framework, py_version, account, accelerator_type, framework_version): if _using_merged_images(region, framework, py_version, accelerator_type, framework_version): return "763104351884" - else: - return VALID_ACCOUNTS_BY_REGION.get(region, account) + return VALID_ACCOUNTS_BY_REGION.get(region, account) def create_image_uri( @@ -182,10 +180,7 @@ def create_image_uri( return "{}/{}:{}".format( get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag ) - else: - return "{}/sagemaker-{}:{}".format( - get_ecr_image_uri_prefix(account, region), framework, tag - ) + return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag) def _accelerator_type_valid_for_framework( @@ -324,30 +319,28 @@ def framework_name_from_image(image_name): sagemaker_match = sagemaker_pattern.match(image_name) if sagemaker_match is None: return None, None, None, None - else: - # extract framework, python version and image tag - # We must support both the legacy and current image name format. - name_pattern = re.compile( - r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 + # extract framework, python version and image tag + # We must support both the legacy and current image name format. + name_pattern = re.compile( + r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 + ) + legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$") + + name_match = name_pattern.match(sagemaker_match.group(9)) + legacy_match = legacy_name_pattern.match(sagemaker_match.group(9)) + + if name_match is not None: + fw, scriptmode, ver, device, py = ( + name_match.group(1), + name_match.group(2), + name_match.group(3), + name_match.group(4), + name_match.group(5), ) - legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$") - - name_match = name_pattern.match(sagemaker_match.group(9)) - legacy_match = legacy_name_pattern.match(sagemaker_match.group(9)) - - if name_match is not None: - fw, scriptmode, ver, device, py = ( - name_match.group(1), - name_match.group(2), - name_match.group(3), - name_match.group(4), - name_match.group(5), - ) - return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode - elif legacy_match is not None: - return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None) - else: - return None, None, None, None + return fw, py, "{}-{}-{}".format(ver, device, py), scriptmode + if legacy_match is not None: + return (legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None) + return None, None, None, None def framework_version_from_tag(image_tag): diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index ee2658fe02..ffe4ba31b0 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -140,25 +140,24 @@ def _convert_input_to_channel(channel_name, channel_s3_input): def _format_string_uri_input(uri_input, validate_uri=True, content_type=None, input_mode=None): if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): return s3_input(uri_input, content_type=content_type, input_mode=input_mode) - elif isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"): + if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"): return file_input(uri_input) - elif isinstance(uri_input, str) and validate_uri: + if isinstance(uri_input, str) and validate_uri: raise ValueError( 'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://"'.format(uri_input) ) - elif isinstance(uri_input, str): + if isinstance(uri_input, str): return s3_input(uri_input, content_type=content_type, input_mode=input_mode) - elif isinstance(uri_input, s3_input): + if isinstance(uri_input, s3_input): return uri_input - elif isinstance(uri_input, file_input): + if isinstance(uri_input, file_input): return uri_input - else: - raise ValueError( - "Cannot format input {}. Expecting one of str, s3_input, or file_input".format( - uri_input - ) + raise ValueError( + "Cannot format input {}. Expecting one of str, s3_input, or file_input".format( + uri_input ) + ) @staticmethod def _prepare_channel( @@ -171,7 +170,7 @@ def _prepare_channel( ): if not channel_uri: return - elif not channel_name: + if not channel_name: raise ValueError( "Expected a channel name if a channel URI {} is specified".format(channel_uri) ) @@ -197,23 +196,20 @@ def _format_model_uri_input(model_uri, validate_uri=True): distribution="FullyReplicated", content_type="application/x-sagemaker-model", ) - elif ( - isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://") - ): + if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://"): return file_input(model_uri) - elif isinstance(model_uri, string_types) and validate_uri: + if isinstance(model_uri, string_types) and validate_uri: raise ValueError( 'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://' ) - elif isinstance(model_uri, string_types): + if isinstance(model_uri, string_types): return s3_input( model_uri, input_mode="File", distribution="FullyReplicated", content_type="application/x-sagemaker-model", ) - else: - raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri)) + raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri)) @staticmethod def _format_record_set_list_input(inputs): diff --git a/src/sagemaker/local/data.py b/src/sagemaker/local/data.py index 88d5dde41b..62f41bf4e0 100644 --- a/src/sagemaker/local/data.py +++ b/src/sagemaker/local/data.py @@ -45,7 +45,7 @@ def get_data_source_instance(data_source, sagemaker_session): parsed_uri = urlparse(data_source) if parsed_uri.scheme == "file": return LocalFileDataSource(parsed_uri.netloc + parsed_uri.path) - elif parsed_uri.scheme == "s3": + if parsed_uri.scheme == "s3": return S3DataSource(parsed_uri.netloc, parsed_uri.path, sagemaker_session) @@ -62,12 +62,11 @@ def get_splitter_instance(split_type): """ if split_type is None: return NoneSplitter() - elif split_type == "Line": + if split_type == "Line": return LineSplitter() - elif split_type == "RecordIO": + if split_type == "RecordIO": return RecordIOSplitter() - else: - raise ValueError("Invalid Split Type: %s" % split_type) + raise ValueError("Invalid Split Type: %s" % split_type) def get_batch_strategy_instance(strategy, splitter): @@ -82,12 +81,9 @@ def get_batch_strategy_instance(strategy, splitter): """ if strategy == "SingleRecord": return SingleRecordStrategy(splitter) - elif strategy == "MultiRecord": + if strategy == "MultiRecord": return MultiRecordStrategy(splitter) - else: - raise ValueError( - 'Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"' - ) + raise ValueError('Invalid Batch Strategy: %s - Valid Strategies: "SingleRecord", "MultiRecord"') class DataSource(with_metaclass(ABCMeta, object)): @@ -129,8 +125,7 @@ def get_file_list(self): for f in os.listdir(self.root_path) if os.path.isfile(os.path.join(self.root_path, f)) ] - else: - return [self.root_path] + return [self.root_path] def get_root_dir(self): """Retrieve the absolute path to the root directory of this data source. @@ -140,8 +135,7 @@ def get_root_dir(self): """ if os.path.isdir(self.root_path): return self.root_path - else: - return os.path.dirname(self.root_path) + return os.path.dirname(self.root_path) class S3DataSource(DataSource): diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 30a6f47572..95a6277f87 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -665,7 +665,7 @@ def _aws_credentials(session): "AWS_ACCESS_KEY_ID=%s" % (str(access_key)), "AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)), ] - elif not _aws_credentials_available_in_metadata_service(): + if not _aws_credentials_available_in_metadata_service(): logger.warning( "Using the short-lived AWS credentials found in session. They might expire while running." ) @@ -674,11 +674,10 @@ def _aws_credentials(session): "AWS_SECRET_ACCESS_KEY=%s" % (str(secret_key)), "AWS_SESSION_TOKEN=%s" % (str(token)), ] - else: - logger.info( - "No AWS credentials found in session but credentials from EC2 Metadata Service are available." - ) - return None + logger.info( + "No AWS credentials found in session but credentials from EC2 Metadata Service are available." + ) + return None except Exception as e: # pylint: disable=broad-except logger.info("Could not get AWS credentials: %s", e) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 247dc17790..d2e6582ed4 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -107,8 +107,7 @@ def describe_training_job(self, TrainingJobName): } } raise ClientError(error_response, "describe_training_job") - else: - return LocalSagemakerClient._training_jobs[TrainingJobName].describe() + return LocalSagemakerClient._training_jobs[TrainingJobName].describe() def create_transform_job( self, @@ -132,8 +131,7 @@ def describe_transform_job(self, TransformJobName): } } raise ClientError(error_response, "describe_transform_job") - else: - return LocalSagemakerClient._transform_jobs[TransformJobName].describe() + return LocalSagemakerClient._transform_jobs[TransformJobName].describe() def create_model( self, ModelName, PrimaryContainer, *args, **kwargs @@ -152,13 +150,10 @@ def describe_model(self, ModelName): "Error": {"Code": "ValidationException", "Message": "Could not find local model"} } raise ClientError(error_response, "describe_model") - else: - return LocalSagemakerClient._models[ModelName].describe() + return LocalSagemakerClient._models[ModelName].describe() def describe_endpoint_config(self, EndpointConfigName): - if EndpointConfigName in LocalSagemakerClient._endpoint_configs: - return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe() - else: + if EndpointConfigName not in LocalSagemakerClient._endpoint_configs: error_response = { "Error": { "Code": "ValidationException", @@ -166,6 +161,7 @@ def describe_endpoint_config(self, EndpointConfigName): } } raise ClientError(error_response, "describe_endpoint_config") + return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe() def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None): LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig( @@ -178,8 +174,7 @@ def describe_endpoint(self, EndpointName): "Error": {"Code": "ValidationException", "Message": "Could not find local endpoint"} } raise ClientError(error_response, "describe_endpoint") - else: - return LocalSagemakerClient._endpoints[EndpointName].describe() + return LocalSagemakerClient._endpoints[EndpointName].describe() def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b89600576a..e2352c5f41 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -205,8 +205,7 @@ def check_neo_region(self, region): """ if region in NEO_IMAGE_ACCOUNT: return True - else: - return False + return False def _neo_image_account(self, region): if region not in NEO_IMAGE_ACCOUNT: diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index e7d43e1925..4bdf0ed665 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -195,10 +195,9 @@ def _serialize_row(data): if isinstance(data, np.ndarray): data = np.ndarray.flatten(data) if hasattr(data, "__len__"): - if len(data) > 0: - return _csv_serialize_python_array(data) - else: + if len(data) == 0: raise ValueError("Cannot serialize empty array") + return _csv_serialize_python_array(data) # files and buffers if hasattr(data, "read"): @@ -387,9 +386,9 @@ def __call__(self, stream, content_type=CONTENT_TYPE_NPY): return np.genfromtxt( codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype ) - elif content_type == CONTENT_TYPE_JSON: + if content_type == CONTENT_TYPE_JSON: return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype) - elif content_type == CONTENT_TYPE_NPY: + if content_type == CONTENT_TYPE_NPY: return np.load(BytesIO(stream.read())) finally: stream.close() diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 90e2e60a03..5836948c10 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -226,7 +226,7 @@ def create_model( from sagemaker.tensorflow.serving import Model as tfsModel return tfsModel(framework_version=self.framework_version, **base_args) - elif self.framework == RLFramework.MXNET.value: + if self.framework == RLFramework.MXNET.value: return MXNetModel( framework_version=self.framework_version, py_version=PYTHON_VERSION, **extended_args ) @@ -242,14 +242,13 @@ def train_image(self): """ if self.image_name: return self.image_name - else: - return fw_utils.create_image_uri( - self.sagemaker_session.boto_region_name, - self._image_framework(), - self.train_instance_type, - self._image_version(), - py_version=PYTHON_VERSION, - ) + return fw_utils.create_image_uri( + self.sagemaker_session.boto_region_name, + self._image_framework(), + self.train_instance_type, + self._image_version(), + py_version=PYTHON_VERSION, + ) @classmethod def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): @@ -406,7 +405,7 @@ def default_metric_definitions(cls, toolkit): {"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"}, {"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"}, ] - elif toolkit is RLToolkit.RAY: + if toolkit is RLToolkit.RAY: float_regex = "[-+]?[0-9]*[.]?[0-9]+([eE][-+]?[0-9]+)?" # noqa: W605, E501 return [ diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index d14454eb4a..efa56fc0f9 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1231,8 +1231,7 @@ def expand_role(self, role): """ if "/" in role: return role - else: - return self.boto_session.resource("iam").Role(role).arn + return self.boto_session.resource("iam").Role(role).arn def get_caller_identity_arn(self): """Returns the ARN user or role whose credentials are used to call the API. @@ -1791,5 +1790,4 @@ def _vpc_config_from_training_job( ): if vpc_config_override is vpc_utils.VPC_CONFIG_DEFAULT: return training_job_desc.get(vpc_utils.VPC_CONFIG_KEY) - else: - return vpc_utils.sanitize(vpc_config_override) + return vpc_utils.sanitize(vpc_config_override) diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index e066980b75..56cdcd1f51 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -176,7 +176,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na training_job_name ) ) - elif not framework: + if not framework: # If we were unable to parse the framework name from the image it is not one of our # officially supported images, in this case just add the image to the init params. init_params["image_name"] = image_name diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 2026dbfca8..a24bf75dec 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -569,12 +569,11 @@ def _default_s3_path(self, directory, mpi=False): local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) if self.sagemaker_session.local_mode and local_code: return "/opt/ml/shared/{}".format(directory) - elif mpi: + if mpi: return "/opt/ml/model" - elif self._current_job_name: + if self._current_job_name: return os.path.join(self.output_path, self._current_job_name, directory) - else: - return None + return None def _script_mode_enabled(self): return self.py_version == "py3" or self.script_mode diff --git a/src/sagemaker/tensorflow/predictor.py b/src/sagemaker/tensorflow/predictor.py index c56f72ddc9..f9dd30d014 100644 --- a/src/sagemaker/tensorflow/predictor.py +++ b/src/sagemaker/tensorflow/predictor.py @@ -94,8 +94,7 @@ def __init__(self): def __call__(self, data): if isinstance(data, tensor_pb2.TensorProto): return json_format.MessageToJson(data) - else: - return json_serializer(data) + return json_serializer(data) tf_json_serializer = _TFJsonSerializer() diff --git a/src/sagemaker/vpc_utils.py b/src/sagemaker/vpc_utils.py index 9e412db734..5bf1fd687b 100644 --- a/src/sagemaker/vpc_utils.py +++ b/src/sagemaker/vpc_utils.py @@ -83,9 +83,9 @@ def sanitize(vpc_config): """ if vpc_config is None: return vpc_config - elif not isinstance(vpc_config, dict): + if not isinstance(vpc_config, dict): raise ValueError("vpc_config is not a dict: {}".format(vpc_config)) - elif not vpc_config: + if not vpc_config: raise ValueError("vpc_config is empty") subnets = vpc_config.get(SUBNETS_KEY) @@ -93,7 +93,7 @@ def sanitize(vpc_config): raise ValueError("vpc_config is missing key: {}".format(SUBNETS_KEY)) if not isinstance(subnets, list): raise ValueError("vpc_config value for {} is not a list: {}".format(SUBNETS_KEY, subnets)) - elif not subnets: + if not subnets: raise ValueError("vpc_config value for {} is empty".format(SUBNETS_KEY)) security_group_ids = vpc_config.get(SECURITY_GROUP_IDS_KEY) @@ -105,7 +105,7 @@ def sanitize(vpc_config): SECURITY_GROUP_IDS_KEY, security_group_ids ) ) - elif not security_group_ids: + if not security_group_ids: raise ValueError("vpc_config value for {} is empty".format(SECURITY_GROUP_IDS_KEY)) return to_dict(subnets, security_group_ids)