@@ -41,8 +41,13 @@ def __eq__(self, other: Any) -> bool:
4141 if self .__slots__ != other .__slots__ :
4242 return False
4343 for attribute in self .__slots__ :
44- if getattr (self , attribute ) != getattr (other , attribute ):
44+ if (hasattr (self , attribute ) and not hasattr (other , attribute )) or (
45+ hasattr (other , attribute ) and not hasattr (self , attribute )
46+ ):
4547 return False
48+ if hasattr (self , attribute ) and hasattr (other , attribute ):
49+ if getattr (self , attribute ) != getattr (other , attribute ):
50+ return False
4651 return True
4752
4853 def __hash__ (self ) -> int :
@@ -112,7 +117,7 @@ def __init__(self, header: Dict[str, str]):
112117
113118 def to_json (self ) -> Dict [str , str ]:
114119 """Returns json representation of JumpStartModelHeader object."""
115- json_obj = {att : getattr (self , att ) for att in self .__slots__ }
120+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr ( self , att ) }
116121 return json_obj
117122
118123 def from_json (self , json_obj : Dict [str , str ]) -> None :
@@ -134,6 +139,7 @@ class JumpStartECRSpecs(JumpStartDataHolderType):
134139 "framework" ,
135140 "framework_version" ,
136141 "py_version" ,
142+ "huggingface_transformers_version" ,
137143 }
138144
139145 def __init__ (self , spec : Dict [str , Any ]):
@@ -154,10 +160,13 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
154160 self .framework = json_obj ["framework" ]
155161 self .framework_version = json_obj ["framework_version" ]
156162 self .py_version = json_obj ["py_version" ]
163+ huggingface_transformers_version = json_obj .get ("huggingface_transformers_version" )
164+ if huggingface_transformers_version is not None :
165+ self .huggingface_transformers_version = huggingface_transformers_version
157166
158167 def to_json (self ) -> Dict [str , Any ]:
159168 """Returns json representation of JumpStartECRSpecs object."""
160- json_obj = {att : getattr (self , att ) for att in self .__slots__ }
169+ json_obj = {att : getattr (self , att ) for att in self .__slots__ if hasattr ( self , att ) }
161170 return json_obj
162171
163172
@@ -202,26 +211,23 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
202211 self .hosting_script_key : str = json_obj ["hosting_script_key" ]
203212 self .training_supported : bool = bool (json_obj ["training_supported" ])
204213 if self .training_supported :
205- self .training_ecr_specs : Optional [ JumpStartECRSpecs ] = JumpStartECRSpecs (
214+ self .training_ecr_specs : JumpStartECRSpecs = JumpStartECRSpecs (
206215 json_obj ["training_ecr_specs" ]
207216 )
208- self .training_artifact_key : Optional [str ] = json_obj ["training_artifact_key" ]
209- self .training_script_key : Optional [str ] = json_obj ["training_script_key" ]
210- self .hyperparameters : Optional [Dict [str , Any ]] = json_obj .get ("hyperparameters" )
211- else :
212- self .training_ecr_specs = (
213- self .training_artifact_key
214- ) = self .training_script_key = self .hyperparameters = None
217+ self .training_artifact_key : str = json_obj ["training_artifact_key" ]
218+ self .training_script_key : str = json_obj ["training_script_key" ]
219+ self .hyperparameters : Dict [str , Any ] = json_obj .get ("hyperparameters" , {})
215220
216221 def to_json (self ) -> Dict [str , Any ]:
217222 """Returns json representation of JumpStartModelSpecs object."""
218223 json_obj = {}
219224 for att in self .__slots__ :
220- cur_val = getattr (self , att )
221- if isinstance (cur_val , JumpStartECRSpecs ):
222- json_obj [att ] = cur_val .to_json ()
223- else :
224- json_obj [att ] = cur_val
225+ if hasattr (self , att ):
226+ cur_val = getattr (self , att )
227+ if isinstance (cur_val , JumpStartECRSpecs ):
228+ json_obj [att ] = cur_val .to_json ()
229+ else :
230+ json_obj [att ] = cur_val
225231 return json_obj
226232
227233
0 commit comments