1212# language governing permissions and limitations under the License.
1313"""This module stores exceptions related to SageMaker JumpStart."""
1414
15+ from __future__ import absolute_import
1516from typing import List , Optional
1617
18+ from sagemaker .jumpstart .constants import JumpStartScriptScope
19+
1720
1821class VulnerableJumpStartModelError (Exception ):
19- """Exception raised for errors with vulnerable JumpStart models."""
22+ """Exception raised when trying to access a JumpStart model specs flagged as vulnerable.
23+
24+ Raise this exception only if the scope of attributes accessed in the specifications have
25+ vulnerabilities. For example, a model training script may have vulnerabilities, but not
26+ the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError`` only when
27+ accessing the training specifications.
28+ """
2029
2130 def __init__ (
2231 self ,
2332 model_id : Optional [str ] = None ,
2433 version : Optional [str ] = None ,
2534 vulnerabilities : Optional [List [str ]] = None ,
26- inference : Optional [bool ] = None ,
35+ scope : Optional [JumpStartScriptScope ] = None ,
2736 message : Optional [str ] = None ,
2837 ):
38+ """Instantiates VulnerableJumpStartModelError exception.
39+
40+ Args:
41+ model_id (Optional[str]): model id of vulnerable JumpStart model.
42+ (Default: None).
43+ version (Optional[str]): version of vulnerable JumpStart model.
44+ (Default: None).
45+ vulnerabilities (Optional[List[str]]): vulnerabilities associated with
46+ model. (Default: None).
47+
48+ """
2949 if message :
3050 self .message = message
3151 else :
32- if None in [model_id , version , vulnerabilities , inference ]:
52+ if None in [model_id , version , vulnerabilities , scope ]:
3353 raise ValueError (
34- "Must specify `model_id`, `version`, `vulnerabilities`, "
35- "and inference arguments."
54+ "Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments."
3655 )
37- if inference is True :
56+ if scope == JumpStartScriptScope . INFERENCE :
3857 self .message = (
39- f"JumpStart model '{ model_id } ' and version '{ version } ' has at least 1 "
40- "vulnerable dependency in the inference scripts. "
41- f"List of vulnerabilities: { ', ' .join (vulnerabilities )} "
58+ f"Version '{ version } ' of JumpStart model '{ model_id } ' " # type: ignore
59+ "has at least 1 vulnerable dependency in the inference script. "
60+ "Please try targetting a higher version of the model. "
61+ f"List of vulnerabilities: { ', ' .join (vulnerabilities )} " # type: ignore
4262 )
43- else :
63+ elif scope == JumpStartScriptScope . TRAINING :
4464 self .message = (
45- f"JumpStart model '{ model_id } ' and version '{ version } ' has at least 1 "
46- "vulnerable dependency in the training scripts. "
47- f"List of vulnerabilities: { ', ' .join (vulnerabilities )} "
65+ f"Version '{ version } ' of JumpStart model '{ model_id } ' " # type: ignore
66+ "has at least 1 vulnerable dependency in the training script. "
67+ "Please try targetting a higher version of the model. "
68+ f"List of vulnerabilities: { ', ' .join (vulnerabilities )} " # type: ignore
69+ )
70+ else :
71+ raise NotImplementedError (
72+ "Unsupported scope for VulnerableJumpStartModelError: " # type: ignore
73+ f"'{ scope .value } '"
4874 )
4975
5076 super ().__init__ (self .message )
5177
5278
5379class DeprecatedJumpStartModelError (Exception ):
54- """Exception raised for errors with deprecated JumpStart models."""
80+ """Exception raised when trying to access a JumpStart model deprecated specifications.
81+
82+ A deprecated specification for a JumpStart model does not mean the whole model is
83+ deprecated. There may be more recent specifications available for this model. For
84+ example, all specification before version ``2.0.0`` may be deprecated, in such a
85+ case, the SDK would raise this exception only when specifications ``1.*`` are
86+ accessed.
87+ """
5588
5689 def __init__ (
5790 self ,
@@ -64,6 +97,9 @@ def __init__(
6497 else :
6598 if None in [model_id , version ]:
6699 raise ValueError ("Must specify `model_id` and `version` arguments." )
67- self .message = f"JumpStart model '{ model_id } ' and version '{ version } ' is deprecated."
100+ self .message = (
101+ f"Version '{ version } ' of JumpStart model '{ model_id } ' is deprecated. "
102+ "Please try targetting a higher version of the model."
103+ )
68104
69105 super ().__init__ (self .message )
0 commit comments