feat: Add S3 support for artifacts #29
Status: Open

wmsnp wants to merge 3 commits into google:main from wmsnp:feature/s3-artifact-service
+1,049 −3
The first new file in the diff (5 additions; its path is not shown in the capture, but the relative import indicates the package `__init__.py` alongside `s3_artifact_service.py`):

```python
"""Community session services for ADK."""

from .s3_artifact_service import S3ArtifactService

__all__ = ["S3ArtifactService"]
```
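Assuming that path, downstream code would import the service through the package. A hypothetical consumer snippet, not part of the diff:

```python
# Hypothetical consumer import; assumes the __init__.py above is the
# google.adk_community.artifacts package initializer.
from google.adk_community.artifacts import S3ArtifactService
```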
src/google/adk_community/artifacts/s3_artifact_service.py (new file, 364 additions, 0 deletions)

The module docstring, imports, and key/metadata helpers:
```python
"""An artifact service implementation using Amazon S3 or other S3-compatible services.

The blob/key name format depends on whether the filename has a user namespace:
- For files with a user namespace (starting with "user:"):
  {app_name}/{user_id}/user/{filename}/{version}
- For regular session-scoped files:
  {app_name}/{user_id}/{session_id}/{filename}/{version}

This service supports storing and retrieving artifacts with inline data or text.
Artifacts can also have optional custom metadata, which is serialized as JSON
when stored in S3.
"""

from __future__ import annotations

from contextlib import asynccontextmanager
import json
import logging
from typing import Any

from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.artifacts.base_artifact_service import BaseArtifactService
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override

logger = logging.getLogger("google_adk." + __name__)


class S3ArtifactService(BaseArtifactService, BaseModel):
  """An artifact service implementation using Amazon S3 or other S3-compatible services.

  Attributes:
    bucket_name: The name of the S3 bucket to use for storing and retrieving
      artifacts.
    aws_configs: A dictionary of AWS configuration options to pass to the
      boto3 client.
    save_artifact_max_retries: The maximum number of retries to attempt when
      saving an artifact with version conflicts. If set to -1, the service
      will retry indefinitely.
  """

  bucket_name: str
  aws_configs: dict[str, Any] = {}
  save_artifact_max_retries: int = -1
  _s3_session: Any = None

  async def _session(self):
    import aioboto3

    if self._s3_session is None:
      self._s3_session = aioboto3.Session()
    return self._s3_session

  @asynccontextmanager
  async def _client(self):
    session = await self._session()
    async with session.client(service_name="s3", **self.aws_configs) as s3:
      yield s3

  def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]:
    return {k: json.dumps(v) for k, v in (metadata or {}).items()}

  def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
    results = {}
    for k, v in (metadata or {}).items():
      try:
        results[k] = json.loads(v)
      except json.JSONDecodeError:
        logger.warning(
            "Failed to decode metadata value for key %r. Using raw string.", k
        )
        results[k] = v
    return results

  def _file_has_user_namespace(self, filename: str) -> bool:
    return filename.startswith("user:")

  def _get_blob_prefix(
      self, app_name: str, user_id: str, session_id: str | None, filename: str
  ) -> str:
    if self._file_has_user_namespace(filename):
      return f"{app_name}/{user_id}/user/{filename}"
    if session_id:
      return f"{app_name}/{user_id}/{session_id}/{filename}"
    raise ValueError("session_id is required for session-scoped artifacts.")

  def _get_blob_name(
      self,
      app_name: str,
      user_id: str,
      session_id: str | None,
      filename: str,
      version: int,
  ) -> str:
    return (
        f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}/{version}"
    )
```
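The metadata helpers above JSON-encode every value because S3 object metadata is a flat string-to-string map. A quick round-trip illustration (the values are made up, and it calls the private helpers directly just to show the encoding):

```python
# Illustrative round trip through the metadata helpers (hypothetical values).
svc = S3ArtifactService(bucket_name="my-adk-artifacts")  # hypothetical bucket

flat = svc._flatten_metadata({"tags": ["a", "b"], "priority": 3})
# {'tags': '["a", "b"]', 'priority': '3'} -- every value is now a string

restored = svc._unflatten_metadata(flat)
assert restored == {"tags": ["a", "b"], "priority": 3}
```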
The save path (continuing the same class body):

```python
  @override
  async def save_artifact(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      artifact: types.Part,
      session_id: str | None = None,
      custom_metadata: dict[str, Any] | None = None,
  ) -> int:
    """Saves an artifact to S3 with atomic versioning using If-None-Match."""
    from botocore.exceptions import ClientError

    if self.save_artifact_max_retries < 0:
      # iter(int, 1) never reaches its sentinel (int() is always 0), so this
      # iterator is infinite: retry until the conditional write succeeds.
      retry_iter = iter(int, 1)
    else:
      retry_iter = range(self.save_artifact_max_retries + 1)
    for _ in retry_iter:
      versions = await self.list_versions(
          app_name=app_name,
          user_id=user_id,
          filename=filename,
          session_id=session_id,
      )
      version = 0 if not versions else max(versions) + 1
      key = self._get_blob_name(app_name, user_id, session_id, filename, version)
      if artifact.inline_data:
        body = artifact.inline_data.data
        mime_type = artifact.inline_data.mime_type
      elif artifact.text:
        body = artifact.text
        mime_type = "text/plain"
      elif artifact.file_data:
        raise NotImplementedError(
            "Saving artifact with file_data is not supported yet in"
            " S3ArtifactService."
        )
      else:
        raise ValueError("Artifact must have either inline_data or text.")
      async with self._client() as s3:
        try:
          await s3.put_object(
              Bucket=self.bucket_name,
              Key=key,
              Body=body,
              ContentType=mime_type,
              Metadata=self._flatten_metadata(custom_metadata),
              IfNoneMatch="*",
          )
          return version
        except ClientError as e:
          if e.response["Error"]["Code"] in (
              "PreconditionFailed",
              "ObjectAlreadyExists",
          ):
            continue
          raise e
    raise RuntimeError(
        "Failed to save artifact due to version conflicts after retries"
    )
```
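`save_artifact` relies on S3 conditional writes: `IfNoneMatch="*"` makes the put fail if the key already exists, so two writers that compute the same next version number cannot both succeed. A minimal sketch of that primitive with plain boto3 (bucket and key are hypothetical, and the store must support S3 conditional writes, which some S3-compatible services may not):

```python
# Sketch of the conditional-write primitive save_artifact builds on.
import boto3
from botocore.exceptions import ClientError

s3 = boto3.client("s3")
key = "my-app/user-1/session-1/report.txt/0"  # hypothetical version-0 key

s3.put_object(Bucket="my-adk-artifacts", Key=key, Body=b"v0", IfNoneMatch="*")
try:
  # A second attempt to claim the same version key is rejected by S3.
  s3.put_object(Bucket="my-adk-artifacts", Key=key, Body=b"again", IfNoneMatch="*")
except ClientError as e:
  # The service treats this as a version conflict and retries with a
  # freshly computed version number.
  assert e.response["Error"]["Code"] == "PreconditionFailed"
```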
The read, list, and delete paths (continuing the same class body):

```python
  @override
  async def load_artifact(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      session_id: str | None = None,
      version: int | None = None,
  ) -> types.Part | None:
    """Loads a specific version of an artifact from S3.

    If version is not provided, the latest version is loaded.

    Returns:
      A types.Part instance (always with inline_data), or None if the
      artifact does not exist.
    """
    from botocore.exceptions import ClientError

    if version is None:
      versions = await self.list_versions(
          app_name=app_name,
          user_id=user_id,
          filename=filename,
          session_id=session_id,
      )
      if not versions:
        return None
      version = max(versions)

    key = self._get_blob_name(app_name, user_id, session_id, filename, version)
    async with self._client() as s3:
      try:
        response = await s3.get_object(Bucket=self.bucket_name, Key=key)
        async with response["Body"] as stream:
          data = await stream.read()
        mime_type = response["ContentType"]
      except ClientError as e:
        if e.response["Error"]["Code"] in ("NoSuchKey", "404"):
          return None
        raise
    return types.Part.from_bytes(data=data, mime_type=mime_type)

  @override
  async def list_artifact_keys(
      self, *, app_name: str, user_id: str, session_id: str | None = None
  ) -> list[str]:
    """Lists all artifact keys for a user, optionally filtered by session."""
    keys = set()
    prefixes = [
        f"{app_name}/{user_id}/{session_id}/" if session_id else None,
        f"{app_name}/{user_id}/user/",
    ]
    async with self._client() as s3:
      for prefix in filter(None, prefixes):
        paginator = s3.get_paginator("list_objects_v2")
        async for page in paginator.paginate(
            Bucket=self.bucket_name, Prefix=prefix
        ):
          for obj in page.get("Contents", []):
            # Strip the prefix and the trailing /{version} component.
            relative = obj["Key"][len(prefix):]
            filename = "/".join(relative.split("/")[:-1])
            keys.add(filename)
    return sorted(keys)

  @override
  async def delete_artifact(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      session_id: str | None = None,
  ) -> None:
    """Deletes all versions of a specified artifact efficiently using batch delete."""
    versions = await self.list_versions(
        app_name=app_name,
        user_id=user_id,
        filename=filename,
        session_id=session_id,
    )
    if not versions:
      return

    keys_to_delete = [
        {"Key": self._get_blob_name(app_name, user_id, session_id, filename, v)}
        for v in versions
    ]
    async with self._client() as s3:
      # delete_objects accepts at most 1000 keys per request.
      for i in range(0, len(keys_to_delete), 1000):
        batch = keys_to_delete[i : i + 1000]
        await s3.delete_objects(
            Bucket=self.bucket_name, Delete={"Objects": batch}
        )

  @override
  async def list_versions(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      session_id: str | None = None,
  ) -> list[int]:
    """Lists all available versions of a specified artifact."""
    prefix = (
        self._get_blob_prefix(app_name, user_id, session_id, filename) + "/"
    )
    versions = []
    async with self._client() as s3:
      paginator = s3.get_paginator("list_objects_v2")
      async for page in paginator.paginate(
          Bucket=self.bucket_name, Prefix=prefix
      ):
        for obj in page.get("Contents", []):
          try:
            versions.append(int(obj["Key"].split("/")[-1]))
          except ValueError:
            continue
    return sorted(versions)

  @override
  async def list_artifact_versions(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      session_id: str | None = None,
  ) -> list[ArtifactVersion]:
    """Lists all artifact versions with their metadata."""
    prefix = (
        self._get_blob_prefix(app_name, user_id, session_id, filename) + "/"
    )
    results: list[ArtifactVersion] = []
    async with self._client() as s3:
      paginator = s3.get_paginator("list_objects_v2")
      async for page in paginator.paginate(
          Bucket=self.bucket_name, Prefix=prefix
      ):
        for obj in page.get("Contents", []):
          try:
            version = int(obj["Key"].split("/")[-1])
          except ValueError:
            continue

          head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"])
          mime_type = head["ContentType"]
          metadata = head.get("Metadata", {})

          canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}"

          results.append(
              ArtifactVersion(
                  version=version,
                  canonical_uri=canonical_uri,
                  custom_metadata=self._unflatten_metadata(metadata),
                  create_time=obj["LastModified"].timestamp(),
                  mime_type=mime_type,
              )
          )
    return sorted(results, key=lambda a: a.version)

  @override
  async def get_artifact_version(
      self,
      *,
      app_name: str,
      user_id: str,
      filename: str,
      session_id: str | None = None,
      version: int | None = None,
  ) -> ArtifactVersion | None:
    """Retrieves a specific artifact version, or the latest if version is None."""
    if version is None:
      all_versions = await self.list_versions(
          app_name=app_name,
          user_id=user_id,
          filename=filename,
          session_id=session_id,
      )
      if not all_versions:
        return None
      version = max(all_versions)

    key = self._get_blob_name(app_name, user_id, session_id, filename, version)

    from botocore.exceptions import ClientError

    async with self._client() as s3:
      try:
        head = await s3.head_object(Bucket=self.bucket_name, Key=key)
      except ClientError as e:
        if e.response["Error"]["Code"] in ("NoSuchKey", "404"):
          return None
        raise

    return ArtifactVersion(
        version=version,
        canonical_uri=f"s3://{self.bucket_name}/{key}",
        custom_metadata=self._unflatten_metadata(head.get("Metadata", {})),
        create_time=head["LastModified"].timestamp(),
        mime_type=head["ContentType"],
    )
```
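For review context, a minimal end-to-end sketch of the service (not part of the diff). The bucket name, endpoint, and credentials are illustrative placeholders; since `aws_configs` is forwarded verbatim to `session.client(...)`, any boto3 client keyword such as `endpoint_url` should let this run against MinIO or another S3-compatible store:

```python
# Hypothetical usage sketch; all names and credentials below are placeholders.
import asyncio

from google.adk_community.artifacts import S3ArtifactService
from google.genai import types


async def main() -> None:
  service = S3ArtifactService(
      bucket_name="my-adk-artifacts",  # hypothetical bucket
      aws_configs={
          # Only needed for non-AWS, S3-compatible services such as MinIO.
          "endpoint_url": "http://localhost:9000",
          "aws_access_key_id": "minioadmin",
          "aws_secret_access_key": "minioadmin",
      },
      save_artifact_max_retries=5,
  )

  # Save a session-scoped artifact; the first write for this filename
  # lands at key my-app/user-1/session-1/report.txt/0.
  version = await service.save_artifact(
      app_name="my-app",
      user_id="user-1",
      session_id="session-1",
      filename="report.txt",
      artifact=types.Part.from_bytes(data=b"hello", mime_type="text/plain"),
      custom_metadata={"source": "example"},
  )

  # Load it back; omitting version would return the latest one.
  part = await service.load_artifact(
      app_name="my-app",
      user_id="user-1",
      session_id="session-1",
      filename="report.txt",
      version=version,
  )
  print(part.inline_data.mime_type if part else "not found")


asyncio.run(main())
```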
One further file in the diff is empty.