|
76 | 76 | "role = sagemaker.get_execution_role()\n", |
77 | 77 | "region = sess.boto_region_name\n", |
78 | 78 | "\n", |
79 | | - "iam_client = boto3.client('iam')\n", |
| 79 | + "iam_client = boto3.client(\"iam\")\n", |
80 | 80 | "sts_client = boto3.client(\"sts\")\n", |
81 | | - "sm_client = boto3.client('sagemaker')\n", |
| 81 | + "sm_client = boto3.client(\"sagemaker\")\n", |
82 | 82 | "account_id = sts_client.get_caller_identity()[\"Account\"]\n", |
83 | | - "tracking_server_name = 'my-setup-test3'\n", |
84 | | - "mlflow_role_name = 'mlflow-test3'" |
| 83 | + "tracking_server_name = \"my-setup-test3\"\n", |
| 84 | + "mlflow_role_name = \"mlflow-test3\"" |
85 | 85 | ] |
86 | 86 | }, |
87 | 87 | { |
|
142 | 142 | "outputs": [], |
143 | 143 | "source": [ |
144 | 144 | "mlflow_trust_policy = {\n", |
145 | | - " \"Version\": \"2012-10-17\",\n", |
146 | | - " \"Statement\": [\n", |
147 | | - " {\n", |
148 | | - " \"Effect\": \"Allow\",\n", |
149 | | - " \"Principal\": {\n", |
150 | | - " \"Service\": [\n", |
151 | | - " \"sagemaker.amazonaws.com\"\n", |
152 | | - " ]\n", |
153 | | - " },\n", |
154 | | - " \"Action\": \"sts:AssumeRole\"\n", |
155 | | - " }\n", |
156 | | - " ]\n", |
| 145 | + " \"Version\": \"2012-10-17\",\n", |
| 146 | + " \"Statement\": [\n", |
| 147 | + " {\n", |
| 148 | + " \"Effect\": \"Allow\",\n", |
| 149 | + " \"Principal\": {\"Service\": [\"sagemaker.amazonaws.com\"]},\n", |
| 150 | + " \"Action\": \"sts:AssumeRole\",\n", |
| 151 | + " }\n", |
| 152 | + " ],\n", |
157 | 153 | "}\n", |
158 | 154 | "\n", |
159 | 155 | "# Create role for MLflow\n", |
160 | 156 | "mlflow_role = iam_client.create_role(\n", |
161 | | - " RoleName=mlflow_role_name,\n", |
162 | | - " AssumeRolePolicyDocument=json.dumps(mlflow_trust_policy)\n", |
| 157 | + " RoleName=mlflow_role_name, AssumeRolePolicyDocument=json.dumps(mlflow_trust_policy)\n", |
163 | 158 | ")\n", |
164 | | - "mlflow_role_arn = mlflow_role['Role']['Arn']\n", |
| 159 | + "mlflow_role_arn = mlflow_role[\"Role\"][\"Arn\"]\n", |
165 | 160 | "\n", |
166 | 161 | "# Create policy for S3 and SageMaker Model Registry\n", |
167 | 162 | "sm_s3_model_registry_policy = {\n", |
|
177 | 172 | " \"sagemaker:CreateModelPackageGroup\",\n", |
178 | 173 | " \"sagemaker:CreateModelPackage\",\n", |
179 | 174 | " \"sagemaker:UpdateModelPackage\",\n", |
180 | | - " \"sagemaker:DescribeModelPackageGroup\"\n", |
| 175 | + " \"sagemaker:DescribeModelPackageGroup\",\n", |
181 | 176 | " ],\n", |
182 | | - " \"Resource\": \"*\"\n", |
| 177 | + " \"Resource\": \"*\",\n", |
183 | 178 | " }\n", |
184 | | - " ]\n", |
| 179 | + " ],\n", |
185 | 180 | "}\n", |
186 | 181 | "\n", |
187 | | - "mlflow_s3_sm_model_registry_iam_policy = iam_client.create_policy(PolicyName='mlflow-s3-sm-model-registry',\n", |
188 | | - " PolicyDocument=json.dumps(sm_s3_model_registry_policy))\n", |
189 | | - "mlflow_s3_sm_model_registry_iam_policy_arn = mlflow_s3_sm_model_registry_iam_policy['Policy']['Arn']\n", |
| 182 | + "mlflow_s3_sm_model_registry_iam_policy = iam_client.create_policy(\n", |
| 183 | + " PolicyName=\"mlflow-s3-sm-model-registry\", PolicyDocument=json.dumps(sm_s3_model_registry_policy)\n", |
| 184 | + ")\n", |
| 185 | + "mlflow_s3_sm_model_registry_iam_policy_arn = mlflow_s3_sm_model_registry_iam_policy[\"Policy\"][\"Arn\"]\n", |
190 | 186 | "\n", |
191 | 187 | "# Attach the policy to the MLflow role\n", |
192 | 188 | "iam_client.attach_role_policy(\n", |
193 | | - " RoleName=mlflow_role_name,\n", |
194 | | - " PolicyArn=mlflow_s3_sm_model_registry_iam_policy_arn\n", |
| 189 | + " RoleName=mlflow_role_name, PolicyArn=mlflow_s3_sm_model_registry_iam_policy_arn\n", |
195 | 190 | ")" |
196 | 191 | ] |
197 | 192 | }, |
|
241 | 236 | "source": [ |
242 | 237 | "sm_client.create_mlflow_tracking_server(\n", |
243 | 238 | " TrackingServerName=tracking_server_name,\n", |
244 | | - " ArtifactStoreUri=f's3://{bucket_name}/{tracking_server_name}',\n", |
245 | | - " TrackingServerSize='Small',\n", |
246 | | - " MlflowVersion='2.13.2',\n", |
| 239 | + " ArtifactStoreUri=f\"s3://{bucket_name}/{tracking_server_name}\",\n", |
| 240 | + " TrackingServerSize=\"Small\",\n", |
| 241 | + " MlflowVersion=\"2.13.2\",\n", |
247 | 242 | " RoleArn=mlflow_role_arn,\n", |
248 | | - " AutomaticModelRegistration=False\n", |
| 243 | + " AutomaticModelRegistration=False,\n", |
249 | 244 | ")" |
250 | 245 | ] |
251 | 246 | }, |
|
256 | 251 | "metadata": {}, |
257 | 252 | "outputs": [], |
258 | 253 | "source": [ |
259 | | - "tracking_server_arn = f\"arn:aws:sagemaker:{region}:{account_id}:mlflow-tracking-server/{tracking_server_name}\"" |
| 254 | + "tracking_server_arn = (\n", |
| 255 | + " f\"arn:aws:sagemaker:{region}:{account_id}:mlflow-tracking-server/{tracking_server_name}\"\n", |
| 256 | + ")" |
260 | 257 | ] |
261 | 258 | }, |
262 | 259 | { |
|
311 | 308 | "outputs": [], |
312 | 309 | "source": [ |
313 | 310 | "import mlflow\n", |
| 311 | + "\n", |
314 | 312 | "mlflow.set_tracking_uri(tracking_server_arn)" |
315 | 313 | ] |
316 | 314 | }, |
|
348 | 346 | "metadata": {}, |
349 | 347 | "outputs": [], |
350 | 348 | "source": [ |
351 | | - "sm_client.create_presigned_mlflow_tracking_server_url(\n", |
352 | | - " TrackingServerName=tracking_server_name\n", |
353 | | - ")" |
| 349 | + "sm_client.create_presigned_mlflow_tracking_server_url(TrackingServerName=tracking_server_name)" |
354 | 350 | ] |
355 | 351 | }, |
356 | 352 | { |
|
0 commit comments