From 8a289359b2c052108a93bfed687ffc5c1e805ecc Mon Sep 17 00:00:00 2001 From: wmsnp Date: Sat, 15 Nov 2025 16:26:41 +0800 Subject: [PATCH 1/4] Add S3ArtifactService with unit tests --- pyproject.toml | 1 + src/google/adk/artifacts/__init__.py | 2 + .../adk/artifacts/s3_artifact_service.py | 310 ++++++++++++++++++ .../artifacts/test_artifact_service.py | 178 +++++++++- 4 files changed, 487 insertions(+), 4 deletions(-) create mode 100644 src/google/adk/artifacts/s3_artifact_service.py diff --git a/pyproject.toml b/pyproject.toml index e33d650f6f..3f9e5e76cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,6 +116,7 @@ test = [ # go/keep-sorted start "a2a-sdk>=0.3.0,<0.4.0;python_version>='3.10'", "anthropic>=0.43.0", # For anthropic model tests + "aioboto3>=15.5.0", "crewai[tools];python_version>='3.10' and python_version<'3.14'", # For CrewaiTool tests "kubernetes>=29.0.0", # For GkeCodeExecutor "langchain-community>=0.3.17", diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index 90a8063fae..88dd05dd7c 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -16,10 +16,12 @@ from .file_artifact_service import FileArtifactService from .gcs_artifact_service import GcsArtifactService from .in_memory_artifact_service import InMemoryArtifactService +from .s3_artifact_service import S3ArtifactService __all__ = [ 'BaseArtifactService', 'FileArtifactService', 'GcsArtifactService', 'InMemoryArtifactService', + 'S3ArtifactService', ] diff --git a/src/google/adk/artifacts/s3_artifact_service.py b/src/google/adk/artifacts/s3_artifact_service.py new file mode 100644 index 0000000000..2b1aa59803 --- /dev/null +++ b/src/google/adk/artifacts/s3_artifact_service.py @@ -0,0 +1,310 @@ +"""An artifact service implementation using Amazon S3 or other S3-compatible services. + +The blob/key name format depends on whether the filename has a user namespace: + - For files with user namespace (starting with "user:"): + {app_name}/{user_id}/user/{filename}/{version} + - For regular session-scoped files: + {app_name}/{user_id}/{session_id}/{filename}/{version} + +This service supports storing and retrieving artifacts with inline data or text. +Artifacts can also have optional custom metadata, which is serialized as JSON +when stored in S3. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any +from typing import override + +from google.genai import types +from pydantic import BaseModel + +from .base_artifact_service import ArtifactVersion +from .base_artifact_service import BaseArtifactService + +logger = logging.getLogger("google_adk." 
+ __name__) + + +class S3ArtifactService(BaseArtifactService, BaseModel): + """An artifact service implementation using Amazon S3 or other S3-compatible services.""" + + bucket_name: str + aws_configs: dict[str, Any] = {} + _s3_client: Any = None + + async def _client(self): + """Creates or returns the aioboto3 S3 client.""" + import aioboto3 + + if self._s3_client is None: + self._s3_client = ( + await aioboto3.Session() + .client(service_name="s3", **self.aws_configs) + .__aenter__() + ) + return self._s3_client + + async def close(self): + """Closes the underlying S3 client session.""" + if self._s3_client: + await self._s3_client.__aexit__(None, None, None) + self._s3_client = None + + def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]: + return {k: json.dumps(v) for k, v in (metadata or {}).items()} + + def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + return {k: json.loads(v) for k, v in (metadata or {}).items()} + + def _file_has_user_namespace(self, filename: str) -> bool: + return filename.startswith("user:") + + def _get_blob_prefix( + self, app_name: str, user_id: str, session_id: str | None, filename: str + ) -> str: + if self._file_has_user_namespace(filename): + return f"{app_name}/{user_id}/user/{filename}" + if session_id: + return f"{app_name}/{user_id}/{session_id}/{filename}" + raise ValueError("session_id is required for session-scoped artifacts.") + + def _get_blob_name( + self, + app_name: str, + user_id: str, + session_id: str | None, + filename: str, + version: int, + ) -> str: + return ( + f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}/{version}" + ) + + @override + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: str | None = None, + custom_metadata: dict[str, Any] | None = None, + ) -> int: + """Saves an artifact to S3 and returns its assigned version number. + + Args: + app_name: Application name. + user_id: User ID. + filename: Artifact filename. + artifact: The artifact data (inline_data or text). + session_id: Session ID for session-scoped artifacts. + custom_metadata: Optional metadata to store with the artifact. + + Returns: + The version number of the saved artifact. + """ + s3 = await self._client() + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + version = 0 if not versions else max(versions) + 1 + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + + if artifact.inline_data: + body = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + elif artifact.text: + body = artifact.text + mime_type = "text/plain" + elif artifact.file_data: + raise NotImplementedError( + "Saving artifact with file_data is not supported yet in" + " S3ArtifactService." + ) + else: + raise ValueError("Artifact must have either inline_data or text.") + await s3.put_object( + Bucket=self.bucket_name, + Key=key, + Body=body, + ContentType=mime_type, + Metadata=self._flatten_metadata(custom_metadata), + ) + return version + + @override + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> types.Part | None: + """Loads a specific version of an artifact from S3. + + If version is not provided, the latest version is loaded. 
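+
+    Args:
+      app_name: Application name.
+      user_id: User ID.
+      filename: Artifact filename.
+      session_id: Session ID for session-scoped artifacts.
+      version: The version to load. If None, the latest version is used.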
+ + Returns: + A types.Part instance (always with inline_data), or None if the artifact does not exist. + """ + from botocore.exceptions import ClientError + + s3 = await self._client() + if version is None: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + version = max(versions) + + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + try: + response = await s3.get_object(Bucket=self.bucket_name, Key=key) + async with response["Body"] as stream: + data = await stream.read() + mime_type = response["ContentType"] + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return None + raise + return types.Part.from_bytes(data=data, mime_type=mime_type) + + @override + async def list_artifact_keys( + self, *, app_name: str, user_id: str, session_id: str | None = None + ) -> list[str]: + """Lists all artifact keys for a user, optionally filtered by session.""" + s3 = await self._client() + keys = set() + prefixes = [ + f"{app_name}/{user_id}/{session_id}/" if session_id else None, + f"{app_name}/{user_id}/user/", + ] + + for prefix in filter(None, prefixes): + response = await s3.list_objects_v2( + Bucket=self.bucket_name, Prefix=prefix + ) + for obj in response.get("Contents", []): + relative = obj["Key"][len(prefix) :] + filename = "/".join(relative.split("/")[:-1]) + keys.add(filename) + return sorted(keys) + + @override + async def delete_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> None: + """Deletes all versions of a specified artifact.""" + s3 = await self._client() + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + for v in versions: + key = self._get_blob_name(app_name, user_id, session_id, filename, v) + await s3.delete_object(Bucket=self.bucket_name, Key=key) + + @override + async def list_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[int]: + """Lists all available versions of a specified artifact.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + versions = [] + response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) + for obj in response.get("Contents", []): + try: + versions.append(int(obj["Key"].split("/")[-1])) + except ValueError: + continue + return sorted(versions) + + @override + async def list_artifact_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[ArtifactVersion]: + """Lists all artifact versions with their metadata.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + results: list[ArtifactVersion] = [] + + response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) + for obj in response.get("Contents", []): + try: + version = int(obj["Key"].split("/")[-1]) + except ValueError: + continue + head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) + mime_type = head["ContentType"] + metadata = head.get("Metadata", {}) + + canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" + + results.append( + ArtifactVersion( + version=version, + canonical_uri=canonical_uri, + custom_metadata=self._unflatten_metadata(metadata), + 
create_time=obj["LastModified"].timestamp(), + mime_type=mime_type, + ) + ) + return sorted(results, key=lambda a: a.version) + + @override + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> ArtifactVersion | None: + """Retrieves a specific artifact version, or the latest if version is None.""" + versions = await self.list_artifact_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + if version is None: + return max(versions, key=lambda v: v.version) + return next(filter(lambda av: av.version == version, versions), None) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 007b18ecf7..8b55d318aa 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -20,6 +20,7 @@ import enum import json from pathlib import Path +import sys from typing import Any from typing import Optional from typing import Union @@ -28,10 +29,12 @@ from urllib.parse import unquote from urllib.parse import urlparse +from botocore.exceptions import ClientError from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.artifacts.s3_artifact_service import S3ArtifactService from google.genai import types import pytest @@ -45,6 +48,7 @@ class ArtifactServiceType(Enum): FILE = "FILE" IN_MEMORY = "IN_MEMORY" GCS = "GCS" + S3 = "S3" class MockBlob: @@ -167,6 +171,139 @@ def mock_gcs_artifact_service(): return GcsArtifactService(bucket_name="test_bucket") +class MockBody: + + def __init__(self, data: bytes): + self._data = data + + async def read(self) -> bytes: + return self._data + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + +class MockAsyncS3Object: + + def __init__(self, key): + self.key = key + self.data = None + self.content_type = None + self.metadata = {} + self.last_modified = FIXED_DATETIME + + async def put(self, Body, ContentType=None, Metadata=None): + self.data = Body if isinstance(Body, bytes) else Body.encode("utf-8") + self.content_type = ContentType + self.metadata = Metadata or {} + + async def get(self): + if self.data is None: + raise ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, + operation_name="GetObject", + ) + return { + "Body": MockBody(self.data), + "ContentType": self.content_type, + "Metadata": self.metadata, + "LastModified": self.last_modified, + } + + +class MockAsyncS3Bucket: + + def __init__(self, name): + self.name = name + self.objects = {} + + def object(self, key): + if key not in self.objects: + self.objects[key] = MockAsyncS3Object(key) + return self.objects[key] + + async def listed_keys(self, prefix=None): + return [ + k + for k, obj in self.objects.items() + if obj.data is not None and (prefix is None or k.startswith(prefix)) + ] + + +class MockAsyncS3Client: + + def __init__(self): + self.buckets = {} + + def get_bucket(self, bucket_name): + if bucket_name not in self.buckets: + self.buckets[bucket_name] = MockAsyncS3Bucket(bucket_name) + return self.buckets[bucket_name] + + async def put_object( + self, Bucket, 
Key, Body, ContentType=None, Metadata=None + ): + bucket = self.get_bucket(Bucket) + await bucket.object(Key).put( + Body=Body, ContentType=ContentType, Metadata=Metadata + ) + + async def get_object(self, Bucket, Key): + bucket = self.get_bucket(Bucket) + obj = bucket.object(Key) + return await obj.get() + + async def delete_object(self, Bucket, Key): + bucket = self.get_bucket(Bucket) + bucket.objects.pop(Key, None) + + async def list_objects_v2(self, Bucket, Prefix=None): + bucket = self.get_bucket(Bucket) + keys = await bucket.listed_keys(Prefix) + return { + "KeyCount": len(keys), + "Contents": [ + {"Key": k, "LastModified": bucket.objects[k].last_modified} + for k in keys + ], + } + + async def head_object(self, Bucket, Key): + obj = await self.get_object(Bucket, Key) + return { + "ContentType": obj["ContentType"], + "Metadata": obj.get("Metadata", {}), + "LastModified": obj.get("LastModified"), + } + + +def mock_s3_artifact_service(): + mock_s3_client = MockAsyncS3Client() + + class MockSession: + + def client(self, *args, **kwargs): + class MockClientCtx: + + async def __aenter__(self_inner): + return mock_s3_client + + async def __aexit__(self_inner, exc_type, exc, tb): + pass + + return MockClientCtx() + + class MockAioboto3: + Session = MockSession + + sys.modules["aioboto3"] = MockAioboto3 + artifact_service = S3ArtifactService(bucket_name="test_bucket") + return artifact_service + + @pytest.fixture def artifact_service_factory(tmp_path: Path): """Provides an artifact service constructor bound to the test tmp path.""" @@ -178,6 +315,8 @@ def factory( return mock_gcs_artifact_service() if service_type == ArtifactServiceType.FILE: return FileArtifactService(root_dir=tmp_path / "artifacts") + if service_type == ArtifactServiceType.S3: + return mock_s3_artifact_service() return InMemoryArtifactService() return factory @@ -190,6 +329,7 @@ def factory( ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_load_empty(service_type, artifact_service_factory): @@ -210,6 +350,7 @@ async def test_load_empty(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_save_load_delete(service_type, artifact_service_factory): @@ -268,6 +409,7 @@ async def test_save_load_delete(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_keys(service_type, artifact_service_factory): @@ -304,6 +446,7 @@ async def test_list_keys(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_versions(service_type, artifact_service_factory): @@ -348,6 +491,7 @@ async def test_list_versions(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_keys_preserves_user_prefix( @@ -398,7 +542,12 @@ async def test_list_keys_preserves_user_prefix( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_list_artifact_versions_and_get_artifact_version( service_type, artifact_service_factory @@ -446,6 
+595,10 @@ async def test_list_artifact_versions_and_get_artifact_version( uri = ( f"gs://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" ) + elif service_type == ArtifactServiceType.S3: + uri = ( + f"s3://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" + ) else: uri = f"memory://apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{filename}/versions/{i}" expected_artifact_versions.append( @@ -485,7 +638,12 @@ async def test_list_artifact_versions_and_get_artifact_version( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_list_artifact_versions_with_user_prefix( service_type, artifact_service_factory @@ -532,6 +690,8 @@ async def test_list_artifact_versions_with_user_prefix( metadata = {"key": "value" + str(i)} if service_type == ArtifactServiceType.GCS: uri = f"gs://test_bucket/{app_name}/{user_id}/user/{user_scoped_filename}/{i}" + elif service_type == ArtifactServiceType.S3: + uri = f"s3://test_bucket/{app_name}/{user_id}/user/{user_scoped_filename}/{i}" else: uri = f"memory://apps/{app_name}/users/{user_id}/artifacts/{user_scoped_filename}/versions/{i}" expected_artifact_versions.append( @@ -548,7 +708,12 @@ async def test_list_artifact_versions_with_user_prefix( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_get_artifact_version_artifact_does_not_exist( service_type, artifact_service_factory @@ -565,7 +730,12 @@ async def test_get_artifact_version_artifact_does_not_exist( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_get_artifact_version_out_of_index( service_type, artifact_service_factory From a9b93268356e4a76912f69867e10e7a506a45e02 Mon Sep 17 00:00:00 2001 From: wmsnp Date: Sat, 15 Nov 2025 17:30:17 +0800 Subject: [PATCH 2/4] Fix Windows path handling in FileArtifactService and its tests --- .../adk/artifacts/file_artifact_service.py | 2 +- .../adk/artifacts/s3_artifact_service.py | 620 +++++++++--------- .../artifacts/test_artifact_service.py | 5 +- 3 files changed, 314 insertions(+), 313 deletions(-) diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 97b2fb147d..203e2eac41 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -106,7 +106,7 @@ def _resolve_scoped_artifact_path( pure_path = _to_posix_path(stripped) scope_root_resolved = scope_root.resolve(strict=False) - if pure_path.is_absolute(): + if Path(stripped).is_absolute(): raise ValueError( f"Absolute artifact filename {filename!r} is not permitted; " "provide a path relative to the storage scope." 
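
Why the FileArtifactService check above now tests the raw string rather than
the POSIX-normalized path — a minimal illustration, assuming _to_posix_path
returns a PurePosixPath as its name suggests:

    from pathlib import Path, PurePosixPath

    # A drive-qualified Windows path stops looking absolute once it is
    # reinterpreted as a POSIX path, so the old check let it through:
    PurePosixPath("C:/secret/creds.txt").is_absolute()  # False on any OS
    Path("C:/secret/creds.txt").is_absolute()           # True on Windows

    # Checking the unconverted string with Path() rejects absolute Windows
    # paths before they can escape the storage scope.
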
diff --git a/src/google/adk/artifacts/s3_artifact_service.py b/src/google/adk/artifacts/s3_artifact_service.py index 2b1aa59803..00377795cb 100644 --- a/src/google/adk/artifacts/s3_artifact_service.py +++ b/src/google/adk/artifacts/s3_artifact_service.py @@ -1,310 +1,310 @@ -"""An artifact service implementation using Amazon S3 or other S3-compatible services. - -The blob/key name format depends on whether the filename has a user namespace: - - For files with user namespace (starting with "user:"): - {app_name}/{user_id}/user/{filename}/{version} - - For regular session-scoped files: - {app_name}/{user_id}/{session_id}/{filename}/{version} - -This service supports storing and retrieving artifacts with inline data or text. -Artifacts can also have optional custom metadata, which is serialized as JSON -when stored in S3. -""" - -from __future__ import annotations - -import json -import logging -from typing import Any -from typing import override - -from google.genai import types -from pydantic import BaseModel - -from .base_artifact_service import ArtifactVersion -from .base_artifact_service import BaseArtifactService - -logger = logging.getLogger("google_adk." + __name__) - - -class S3ArtifactService(BaseArtifactService, BaseModel): - """An artifact service implementation using Amazon S3 or other S3-compatible services.""" - - bucket_name: str - aws_configs: dict[str, Any] = {} - _s3_client: Any = None - - async def _client(self): - """Creates or returns the aioboto3 S3 client.""" - import aioboto3 - - if self._s3_client is None: - self._s3_client = ( - await aioboto3.Session() - .client(service_name="s3", **self.aws_configs) - .__aenter__() - ) - return self._s3_client - - async def close(self): - """Closes the underlying S3 client session.""" - if self._s3_client: - await self._s3_client.__aexit__(None, None, None) - self._s3_client = None - - def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]: - return {k: json.dumps(v) for k, v in (metadata or {}).items()} - - def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: - return {k: json.loads(v) for k, v in (metadata or {}).items()} - - def _file_has_user_namespace(self, filename: str) -> bool: - return filename.startswith("user:") - - def _get_blob_prefix( - self, app_name: str, user_id: str, session_id: str | None, filename: str - ) -> str: - if self._file_has_user_namespace(filename): - return f"{app_name}/{user_id}/user/{filename}" - if session_id: - return f"{app_name}/{user_id}/{session_id}/{filename}" - raise ValueError("session_id is required for session-scoped artifacts.") - - def _get_blob_name( - self, - app_name: str, - user_id: str, - session_id: str | None, - filename: str, - version: int, - ) -> str: - return ( - f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}/{version}" - ) - - @override - async def save_artifact( - self, - *, - app_name: str, - user_id: str, - filename: str, - artifact: types.Part, - session_id: str | None = None, - custom_metadata: dict[str, Any] | None = None, - ) -> int: - """Saves an artifact to S3 and returns its assigned version number. - - Args: - app_name: Application name. - user_id: User ID. - filename: Artifact filename. - artifact: The artifact data (inline_data or text). - session_id: Session ID for session-scoped artifacts. - custom_metadata: Optional metadata to store with the artifact. - - Returns: - The version number of the saved artifact. 
- """ - s3 = await self._client() - versions = await self.list_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - version = 0 if not versions else max(versions) + 1 - key = self._get_blob_name(app_name, user_id, session_id, filename, version) - - if artifact.inline_data: - body = artifact.inline_data.data - mime_type = artifact.inline_data.mime_type - elif artifact.text: - body = artifact.text - mime_type = "text/plain" - elif artifact.file_data: - raise NotImplementedError( - "Saving artifact with file_data is not supported yet in" - " S3ArtifactService." - ) - else: - raise ValueError("Artifact must have either inline_data or text.") - await s3.put_object( - Bucket=self.bucket_name, - Key=key, - Body=body, - ContentType=mime_type, - Metadata=self._flatten_metadata(custom_metadata), - ) - return version - - @override - async def load_artifact( - self, - *, - app_name: str, - user_id: str, - filename: str, - session_id: str | None = None, - version: int | None = None, - ) -> types.Part | None: - """Loads a specific version of an artifact from S3. - - If version is not provided, the latest version is loaded. - - Returns: - A types.Part instance (always with inline_data), or None if the artifact does not exist. - """ - from botocore.exceptions import ClientError - - s3 = await self._client() - if version is None: - versions = await self.list_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - if not versions: - return None - version = max(versions) - - key = self._get_blob_name(app_name, user_id, session_id, filename, version) - try: - response = await s3.get_object(Bucket=self.bucket_name, Key=key) - async with response["Body"] as stream: - data = await stream.read() - mime_type = response["ContentType"] - except ClientError as e: - if e.response["Error"]["Code"] == "NoSuchKey": - return None - raise - return types.Part.from_bytes(data=data, mime_type=mime_type) - - @override - async def list_artifact_keys( - self, *, app_name: str, user_id: str, session_id: str | None = None - ) -> list[str]: - """Lists all artifact keys for a user, optionally filtered by session.""" - s3 = await self._client() - keys = set() - prefixes = [ - f"{app_name}/{user_id}/{session_id}/" if session_id else None, - f"{app_name}/{user_id}/user/", - ] - - for prefix in filter(None, prefixes): - response = await s3.list_objects_v2( - Bucket=self.bucket_name, Prefix=prefix - ) - for obj in response.get("Contents", []): - relative = obj["Key"][len(prefix) :] - filename = "/".join(relative.split("/")[:-1]) - keys.add(filename) - return sorted(keys) - - @override - async def delete_artifact( - self, - *, - app_name: str, - user_id: str, - filename: str, - session_id: str | None = None, - ) -> None: - """Deletes all versions of a specified artifact.""" - s3 = await self._client() - versions = await self.list_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - for v in versions: - key = self._get_blob_name(app_name, user_id, session_id, filename, v) - await s3.delete_object(Bucket=self.bucket_name, Key=key) - - @override - async def list_versions( - self, - *, - app_name: str, - user_id: str, - filename: str, - session_id: str | None = None, - ) -> list[int]: - """Lists all available versions of a specified artifact.""" - s3 = await self._client() - prefix = ( - self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" - ) - versions = [] - response = await 
s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) - for obj in response.get("Contents", []): - try: - versions.append(int(obj["Key"].split("/")[-1])) - except ValueError: - continue - return sorted(versions) - - @override - async def list_artifact_versions( - self, - *, - app_name: str, - user_id: str, - filename: str, - session_id: str | None = None, - ) -> list[ArtifactVersion]: - """Lists all artifact versions with their metadata.""" - s3 = await self._client() - prefix = ( - self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" - ) - results: list[ArtifactVersion] = [] - - response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) - for obj in response.get("Contents", []): - try: - version = int(obj["Key"].split("/")[-1]) - except ValueError: - continue - head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) - mime_type = head["ContentType"] - metadata = head.get("Metadata", {}) - - canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" - - results.append( - ArtifactVersion( - version=version, - canonical_uri=canonical_uri, - custom_metadata=self._unflatten_metadata(metadata), - create_time=obj["LastModified"].timestamp(), - mime_type=mime_type, - ) - ) - return sorted(results, key=lambda a: a.version) - - @override - async def get_artifact_version( - self, - *, - app_name: str, - user_id: str, - filename: str, - session_id: str | None = None, - version: int | None = None, - ) -> ArtifactVersion | None: - """Retrieves a specific artifact version, or the latest if version is None.""" - versions = await self.list_artifact_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - if not versions: - return None - if version is None: - return max(versions, key=lambda v: v.version) - return next(filter(lambda av: av.version == version, versions), None) +"""An artifact service implementation using Amazon S3 or other S3-compatible services. + +The blob/key name format depends on whether the filename has a user namespace: + - For files with user namespace (starting with "user:"): + {app_name}/{user_id}/user/{filename}/{version} + - For regular session-scoped files: + {app_name}/{user_id}/{session_id}/{filename}/{version} + +This service supports storing and retrieving artifacts with inline data or text. +Artifacts can also have optional custom metadata, which is serialized as JSON +when stored in S3. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any +from typing import override + +from google.genai import types +from pydantic import BaseModel + +from .base_artifact_service import ArtifactVersion +from .base_artifact_service import BaseArtifactService + +logger = logging.getLogger("google_adk." 
+ __name__) + + +class S3ArtifactService(BaseArtifactService, BaseModel): + """An artifact service implementation using Amazon S3 or other S3-compatible services.""" + + bucket_name: str + aws_configs: dict[str, Any] = {} + _s3_client: Any = None + + async def _client(self): + """Creates or returns the aioboto3 S3 client.""" + import aioboto3 + + if self._s3_client is None: + self._s3_client = ( + await aioboto3.Session() + .client(service_name="s3", **self.aws_configs) + .__aenter__() + ) + return self._s3_client + + async def close(self): + """Closes the underlying S3 client session.""" + if self._s3_client: + await self._s3_client.__aexit__(None, None, None) + self._s3_client = None + + def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]: + return {k: json.dumps(v) for k, v in (metadata or {}).items()} + + def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + return {k: json.loads(v) for k, v in (metadata or {}).items()} + + def _file_has_user_namespace(self, filename: str) -> bool: + return filename.startswith("user:") + + def _get_blob_prefix( + self, app_name: str, user_id: str, session_id: str | None, filename: str + ) -> str: + if self._file_has_user_namespace(filename): + return f"{app_name}/{user_id}/user/{filename}" + if session_id: + return f"{app_name}/{user_id}/{session_id}/{filename}" + raise ValueError("session_id is required for session-scoped artifacts.") + + def _get_blob_name( + self, + app_name: str, + user_id: str, + session_id: str | None, + filename: str, + version: int, + ) -> str: + return ( + f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}/{version}" + ) + + @override + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: str | None = None, + custom_metadata: dict[str, Any] | None = None, + ) -> int: + """Saves an artifact to S3 and returns its assigned version number. + + Args: + app_name: Application name. + user_id: User ID. + filename: Artifact filename. + artifact: The artifact data (inline_data or text). + session_id: Session ID for session-scoped artifacts. + custom_metadata: Optional metadata to store with the artifact. + + Returns: + The version number of the saved artifact. + """ + s3 = await self._client() + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + version = 0 if not versions else max(versions) + 1 + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + + if artifact.inline_data: + body = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + elif artifact.text: + body = artifact.text + mime_type = "text/plain" + elif artifact.file_data: + raise NotImplementedError( + "Saving artifact with file_data is not supported yet in" + " S3ArtifactService." + ) + else: + raise ValueError("Artifact must have either inline_data or text.") + await s3.put_object( + Bucket=self.bucket_name, + Key=key, + Body=body, + ContentType=mime_type, + Metadata=self._flatten_metadata(custom_metadata), + ) + return version + + @override + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> types.Part | None: + """Loads a specific version of an artifact from S3. + + If version is not provided, the latest version is loaded. 
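+
+    Args:
+      app_name: Application name.
+      user_id: User ID.
+      filename: Artifact filename.
+      session_id: Session ID for session-scoped artifacts.
+      version: The version to load. If None, the latest version is used.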
+ + Returns: + A types.Part instance (always with inline_data), or None if the artifact does not exist. + """ + from botocore.exceptions import ClientError + + s3 = await self._client() + if version is None: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + version = max(versions) + + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + try: + response = await s3.get_object(Bucket=self.bucket_name, Key=key) + async with response["Body"] as stream: + data = await stream.read() + mime_type = response["ContentType"] + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + return None + raise + return types.Part.from_bytes(data=data, mime_type=mime_type) + + @override + async def list_artifact_keys( + self, *, app_name: str, user_id: str, session_id: str | None = None + ) -> list[str]: + """Lists all artifact keys for a user, optionally filtered by session.""" + s3 = await self._client() + keys = set() + prefixes = [ + f"{app_name}/{user_id}/{session_id}/" if session_id else None, + f"{app_name}/{user_id}/user/", + ] + + for prefix in filter(None, prefixes): + response = await s3.list_objects_v2( + Bucket=self.bucket_name, Prefix=prefix + ) + for obj in response.get("Contents", []): + relative = obj["Key"][len(prefix) :] + filename = "/".join(relative.split("/")[:-1]) + keys.add(filename) + return sorted(keys) + + @override + async def delete_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> None: + """Deletes all versions of a specified artifact.""" + s3 = await self._client() + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + for v in versions: + key = self._get_blob_name(app_name, user_id, session_id, filename, v) + await s3.delete_object(Bucket=self.bucket_name, Key=key) + + @override + async def list_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[int]: + """Lists all available versions of a specified artifact.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + versions = [] + response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) + for obj in response.get("Contents", []): + try: + versions.append(int(obj["Key"].split("/")[-1])) + except ValueError: + continue + return sorted(versions) + + @override + async def list_artifact_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[ArtifactVersion]: + """Lists all artifact versions with their metadata.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + results: list[ArtifactVersion] = [] + + response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) + for obj in response.get("Contents", []): + try: + version = int(obj["Key"].split("/")[-1]) + except ValueError: + continue + head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) + mime_type = head["ContentType"] + metadata = head.get("Metadata", {}) + + canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" + + results.append( + ArtifactVersion( + version=version, + canonical_uri=canonical_uri, + custom_metadata=self._unflatten_metadata(metadata), + 
create_time=obj["LastModified"].timestamp(), + mime_type=mime_type, + ) + ) + return sorted(results, key=lambda a: a.version) + + @override + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> ArtifactVersion | None: + """Retrieves a specific artifact version, or the latest if version is None.""" + versions = await self.list_artifact_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + if version is None: + return max(versions, key=lambda v: v.version) + return next(filter(lambda av: av.version == version, versions), None) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 8b55d318aa..b3621e5103 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -28,6 +28,7 @@ from unittest.mock import patch from urllib.parse import unquote from urllib.parse import urlparse +from urllib.request import url2pathname from botocore.exceptions import ClientError from google.adk.artifacts.base_artifact_service import ArtifactVersion @@ -813,7 +814,7 @@ async def test_file_metadata_camelcase(tmp_path, artifact_service_factory): "customMetadata": {}, } parsed_canonical = urlparse(metadata["canonicalUri"]) - canonical_path = Path(unquote(parsed_canonical.path)) + canonical_path = Path(url2pathname(unquote(parsed_canonical.path))) assert canonical_path.name == "report.txt" assert canonical_path.read_bytes() == b"binary-content" @@ -863,7 +864,7 @@ async def test_file_list_artifact_versions(tmp_path, artifact_service_factory): assert version_meta.canonical_uri == version_payload_path.as_uri() assert version_meta.custom_metadata == custom_metadata parsed_version_uri = urlparse(version_meta.canonical_uri) - version_uri_path = Path(unquote(parsed_version_uri.path)) + version_uri_path = Path(url2pathname(unquote(parsed_version_uri.path))) assert version_uri_path.read_bytes() == b"binary-content" fetched = await artifact_service.get_artifact_version( From ecbff9e2f3d4859af289e1023132f1680fe96df1 Mon Sep 17 00:00:00 2001 From: wmsnp Date: Sat, 15 Nov 2025 19:17:20 +0800 Subject: [PATCH 3/4] Fix multiple issues in S3ArtifactService and improve tests - Fixed data race during save_artifact - Added pagination to multiple list_objects_v2 calls - Deleted all versions when removing an artifact - Introduced potential data race in mock_s3_artifact_service to better simulate real concurrency - Correctly injected aioboto3 mock in unit tests --- .../adk/artifacts/s3_artifact_service.py | 231 +++++++++++------- .../artifacts/test_artifact_service.py | 79 ++++-- 2 files changed, 207 insertions(+), 103 deletions(-) diff --git a/src/google/adk/artifacts/s3_artifact_service.py b/src/google/adk/artifacts/s3_artifact_service.py index 00377795cb..6e98ccee8b 100644 --- a/src/google/adk/artifacts/s3_artifact_service.py +++ b/src/google/adk/artifacts/s3_artifact_service.py @@ -28,10 +28,18 @@ class S3ArtifactService(BaseArtifactService, BaseModel): - """An artifact service implementation using Amazon S3 or other S3-compatible services.""" + """An artifact service implementation using Amazon S3 or other S3-compatible services. + + Attributes: + bucket_name: The name of the S3 bucket to use for storing and retrieving artifacts. + aws_configs: A dictionary of AWS configuration options to pass to the boto3 client. 
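+      Useful for targeting S3-compatible stores, e.g. by passing endpoint_url.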
+ save_artifact_max_retries: The maximum number of retries to attempt when saving an artifact with version conflicts. + If set to -1, the service will retry indefinitely. + """ bucket_name: str aws_configs: dict[str, Any] = {} + save_artifact_max_retries: int = -1 _s3_client: Any = None async def _client(self): @@ -93,50 +101,60 @@ async def save_artifact( session_id: str | None = None, custom_metadata: dict[str, Any] | None = None, ) -> int: - """Saves an artifact to S3 and returns its assigned version number. - - Args: - app_name: Application name. - user_id: User ID. - filename: Artifact filename. - artifact: The artifact data (inline_data or text). - session_id: Session ID for session-scoped artifacts. - custom_metadata: Optional metadata to store with the artifact. + """Saves an artifact to S3 with atomic versioning using If-None-Match.""" + from botocore.exceptions import ClientError - Returns: - The version number of the saved artifact. - """ s3 = await self._client() - versions = await self.list_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - version = 0 if not versions else max(versions) + 1 - key = self._get_blob_name(app_name, user_id, session_id, filename, version) - if artifact.inline_data: - body = artifact.inline_data.data - mime_type = artifact.inline_data.mime_type - elif artifact.text: - body = artifact.text - mime_type = "text/plain" - elif artifact.file_data: - raise NotImplementedError( - "Saving artifact with file_data is not supported yet in" - " S3ArtifactService." - ) + if self.save_artifact_max_retries < 0: + retry_iter = iter(int, 1) else: - raise ValueError("Artifact must have either inline_data or text.") - await s3.put_object( - Bucket=self.bucket_name, - Key=key, - Body=body, - ContentType=mime_type, - Metadata=self._flatten_metadata(custom_metadata), + retry_iter = range(self.save_artifact_max_retries + 1) + for _ in retry_iter: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + version = 0 if not versions else max(versions) + 1 + key = self._get_blob_name( + app_name, user_id, session_id, filename, version + ) + if artifact.inline_data: + body = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + elif artifact.text: + body = artifact.text + mime_type = "text/plain" + elif artifact.file_data: + raise NotImplementedError( + "Saving artifact with file_data is not supported yet in" + " S3ArtifactService." 
+ ) + else: + raise ValueError("Artifact must have either inline_data or text.") + + try: + await s3.put_object( + Bucket=self.bucket_name, + Key=key, + Body=body, + ContentType=mime_type, + Metadata=self._flatten_metadata(custom_metadata), + IfNoneMatch="*", + ) + return version + except ClientError as e: + if e.response["Error"]["Code"] in ( + "PreconditionFailed", + "ObjectAlreadyExists", + ): + continue + raise e + raise RuntimeError( + "Failed to save artifact due to version conflicts after retries" ) - return version @override async def load_artifact( @@ -176,7 +194,7 @@ async def load_artifact( data = await stream.read() mime_type = response["ContentType"] except ClientError as e: - if e.response["Error"]["Code"] == "NoSuchKey": + if e.response["Error"]["Code"] in ("NoSuchKey", "404"): return None raise return types.Part.from_bytes(data=data, mime_type=mime_type) @@ -194,13 +212,14 @@ async def list_artifact_keys( ] for prefix in filter(None, prefixes): - response = await s3.list_objects_v2( + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( Bucket=self.bucket_name, Prefix=prefix - ) - for obj in response.get("Contents", []): - relative = obj["Key"][len(prefix) :] - filename = "/".join(relative.split("/")[:-1]) - keys.add(filename) + ): + for obj in page.get("Contents", []): + relative = obj["Key"][len(prefix) :] + filename = "/".join(relative.split("/")[:-1]) + keys.add(filename) return sorted(keys) @override @@ -212,7 +231,7 @@ async def delete_artifact( filename: str, session_id: str | None = None, ) -> None: - """Deletes all versions of a specified artifact.""" + """Deletes all versions of a specified artifact efficiently using batch delete.""" s3 = await self._client() versions = await self.list_versions( app_name=app_name, @@ -220,9 +239,18 @@ async def delete_artifact( filename=filename, session_id=session_id, ) - for v in versions: - key = self._get_blob_name(app_name, user_id, session_id, filename, v) - await s3.delete_object(Bucket=self.bucket_name, Key=key) + if not versions: + return + + keys_to_delete = [ + {"Key": self._get_blob_name(app_name, user_id, session_id, filename, v)} + for v in versions + ] + for i in range(0, len(keys_to_delete), 1000): + batch = keys_to_delete[i : i + 1000] + await s3.delete_objects( + Bucket=self.bucket_name, Delete={"Objects": batch} + ) @override async def list_versions( @@ -239,12 +267,15 @@ async def list_versions( self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" ) versions = [] - response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) - for obj in response.get("Contents", []): - try: - versions.append(int(obj["Key"].split("/")[-1])) - except ValueError: - continue + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + try: + versions.append(int(obj["Key"].split("/")[-1])) + except ValueError: + continue return sorted(versions) @override @@ -263,27 +294,32 @@ async def list_artifact_versions( ) results: list[ArtifactVersion] = [] - response = await s3.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix) - for obj in response.get("Contents", []): - try: - version = int(obj["Key"].split("/")[-1]) - except ValueError: - continue - head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) - mime_type = head["ContentType"] - metadata = head.get("Metadata", {}) - - canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" - - 
results.append( - ArtifactVersion( - version=version, - canonical_uri=canonical_uri, - custom_metadata=self._unflatten_metadata(metadata), - create_time=obj["LastModified"].timestamp(), - mime_type=mime_type, - ) - ) + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + try: + version = int(obj["Key"].split("/")[-1]) + except ValueError: + continue + + head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) + mime_type = head["ContentType"] + metadata = head.get("Metadata", {}) + + canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" + + results.append( + ArtifactVersion( + version=version, + canonical_uri=canonical_uri, + custom_metadata=self._unflatten_metadata(metadata), + create_time=obj["LastModified"].timestamp(), + mime_type=mime_type, + ) + ) + return sorted(results, key=lambda a: a.version) @override @@ -297,14 +333,33 @@ async def get_artifact_version( version: int | None = None, ) -> ArtifactVersion | None: """Retrieves a specific artifact version, or the latest if version is None.""" - versions = await self.list_artifact_versions( - app_name=app_name, - user_id=user_id, - filename=filename, - session_id=session_id, - ) - if not versions: - return None + s3 = await self._client() if version is None: - return max(versions, key=lambda v: v.version) - return next(filter(lambda av: av.version == version, versions), None) + all_versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not all_versions: + return None + version = max(all_versions) + + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + + from botocore.exceptions import ClientError + + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + except ClientError as e: + if e.response["Error"]["Code"] in ("NoSuchKey", "404"): + return None + raise + + return ArtifactVersion( + version=version, + canonical_uri=f"s3://{self.bucket_name}/{key}", + custom_metadata=self._unflatten_metadata(head.get("Metadata", {})), + create_time=head["LastModified"].timestamp(), + mime_type=head["ContentType"], + ) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index b3621e5103..a0468809de 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -16,10 +16,12 @@ """Tests for the artifact service.""" +import asyncio from datetime import datetime import enum import json from pathlib import Path +import random import sys from typing import Any from typing import Optional @@ -245,9 +247,18 @@ def get_bucket(self, bucket_name): return self.buckets[bucket_name] async def put_object( - self, Bucket, Key, Body, ContentType=None, Metadata=None + self, Bucket, Key, Body, ContentType=None, Metadata=None, IfNoneMatch=None ): + await asyncio.sleep(random.uniform(0, 0.05)) bucket = self.get_bucket(Bucket) + obj_exists = Key in bucket.objects and bucket.objects[Key].data is not None + + if IfNoneMatch == "*" and obj_exists: + raise ClientError( + {"Error": {"Code": "PreconditionFailed", "Message": "Object exists"}}, + operation_name="PutObject", + ) + await bucket.object(Key).put( Body=Body, ContentType=ContentType, Metadata=Metadata ) @@ -261,6 +272,13 @@ async def delete_object(self, Bucket, Key): bucket = self.get_bucket(Bucket) bucket.objects.pop(Key, None) + async def 
delete_objects(self, Bucket, Delete): + bucket = self.get_bucket(Bucket) + for item in Delete.get("Objects", []): + key = item.get("Key") + if key in bucket.objects: + bucket.objects.pop(key) + async def list_objects_v2(self, Bucket, Prefix=None): bucket = self.get_bucket(Bucket) keys = await bucket.listed_keys(Prefix) @@ -280,33 +298,64 @@ async def head_object(self, Bucket, Key): "LastModified": obj.get("LastModified"), } + def get_paginator(self, operation_name): + if operation_name != "list_objects_v2": + raise NotImplementedError( + f"Paginator for {operation_name} not implemented" + ) -def mock_s3_artifact_service(): - mock_s3_client = MockAsyncS3Client() + class MockAsyncPaginator: - class MockSession: + def __init__(self, client, Bucket, Prefix=None): + self.client = client + self.Bucket = Bucket + self.Prefix = Prefix - def client(self, *args, **kwargs): - class MockClientCtx: + async def __aiter__(self): + response = await self.client.list_objects_v2( + Bucket=self.Bucket, Prefix=self.Prefix + ) + contents = response.get("Contents", []) + page_size = 2 + for i in range(0, len(contents), page_size): + yield { + "KeyCount": len(contents[i : i + page_size]), + "Contents": contents[i : i + page_size], + } - async def __aenter__(self_inner): - return mock_s3_client + class MockPaginator: - async def __aexit__(self_inner, exc_type, exc, tb): - pass + def paginate(inner_self, Bucket, Prefix=None): + return MockAsyncPaginator(self, Bucket, Prefix) - return MockClientCtx() + return MockPaginator() + + +def mock_s3_artifact_service(monkeypatch): + mock_s3_client = MockAsyncS3Client() class MockAioboto3: - Session = MockSession - sys.modules["aioboto3"] = MockAioboto3 + class Session: + + def client(self, *args, **kwargs): + class MockClientCtx: + + async def __aenter__(self_inner): + return mock_s3_client + + async def __aexit__(self_inner, exc_type, exc, tb): + pass + + return MockClientCtx() + + monkeypatch.setitem(sys.modules, "aioboto3", MockAioboto3) artifact_service = S3ArtifactService(bucket_name="test_bucket") return artifact_service @pytest.fixture -def artifact_service_factory(tmp_path: Path): +def artifact_service_factory(tmp_path: Path, monkeypatch): """Provides an artifact service constructor bound to the test tmp path.""" def factory( @@ -317,7 +366,7 @@ def factory( if service_type == ArtifactServiceType.FILE: return FileArtifactService(root_dir=tmp_path / "artifacts") if service_type == ArtifactServiceType.S3: - return mock_s3_artifact_service() + return mock_s3_artifact_service(monkeypatch) return InMemoryArtifactService() return factory From 458b8c8fb8ad96c3d0f5e02c352d8462335436f7 Mon Sep 17 00:00:00 2001 From: wmsnp Date: Wed, 19 Nov 2025 12:33:29 +0800 Subject: [PATCH 4/4] fix: update override import and improve metadata handling - Use `override` from `typing_extensions` instead of `typing` for better compatibility. - Add error handling to `_unflatten_metadata` to improve robustness. - Update `pyproject.toml` to resolve previous merge conflicts. 
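
An illustration of the new fallback (hypothetical values; a string that was
stored without JSON encoding now loads as-is instead of raising):

    service._unflatten_metadata({"a": "1", "b": "not json"})
    # -> {"a": 1, "b": "not json"}, with a warning logged for key "b"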
---
 pyproject.toml                                  |  3 ++-
 src/google/adk/artifacts/s3_artifact_service.py | 13 +++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 083924afda..452502bdd8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -115,7 +115,7 @@ eval = [
 test = [
   # go/keep-sorted start
   "a2a-sdk>=0.3.0,<0.4.0",
-  "aioboto3>=15.5.0",
+  "aioboto3>=15.5.0",  # For S3 Artifact Service tests
   "anthropic>=0.43.0",  # For anthropic model tests
   "crewai[tools];python_version>='3.10' and python_version<'3.12'",  # For CrewaiTool tests; chromadb/pypika fail on 3.12+
   "kubernetes>=29.0.0",  # For GkeCodeExecutor
@@ -145,6 +145,7 @@ docs = [
 
 # Optional extensions
 extensions = [
+  "aioboto3>=15.5.0",  # For S3 Artifact Service
   "anthropic>=0.43.0",  # For anthropic model support
   "beautifulsoup4>=3.2.2",  # For load_web_page tool.
   "crewai[tools];python_version>='3.10' and python_version<'3.12'",  # For CrewaiTool; chromadb/pypika fail on 3.12+
diff --git a/src/google/adk/artifacts/s3_artifact_service.py b/src/google/adk/artifacts/s3_artifact_service.py
index 6e98ccee8b..fa2cad65c4 100644
--- a/src/google/adk/artifacts/s3_artifact_service.py
+++ b/src/google/adk/artifacts/s3_artifact_service.py
@@ -16,10 +16,10 @@
 import json
 import logging
 from typing import Any
-from typing import override
 
 from google.genai import types
 from pydantic import BaseModel
+from typing_extensions import override
 
 from .base_artifact_service import ArtifactVersion
 from .base_artifact_service import BaseArtifactService
@@ -64,7 +64,16 @@ def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]:
     return {k: json.dumps(v) for k, v in (metadata or {}).items()}
 
   def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]:
-    return {k: json.loads(v) for k, v in (metadata or {}).items()}
+    results = {}
+    for k, v in (metadata or {}).items():
+      try:
+        results[k] = json.loads(v)
+      except json.JSONDecodeError:
+        logger.warning(
+            "Failed to decode metadata value for key %s. Using raw string.", k
+        )
+        results[k] = v
+    return results
 
   def _file_has_user_namespace(self, filename: str) -> bool:
     return filename.startswith("user:")
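
A minimal end-to-end usage sketch for the series (bucket name and endpoint are
placeholders; aws_configs is forwarded verbatim to aioboto3's client(), so an
S3-compatible store such as MinIO can be targeted through endpoint_url):

    import asyncio

    from google.adk.artifacts import S3ArtifactService
    from google.genai import types


    async def main():
      # Placeholder bucket; it must already exist in the target store.
      service = S3ArtifactService(
          bucket_name="my-artifacts",
          aws_configs={"endpoint_url": "http://localhost:9000"},
      )
      try:
        # The first save of a filename gets version 0; later saves increment it.
        version = await service.save_artifact(
            app_name="demo",
            user_id="u1",
            session_id="s1",
            filename="report.txt",
            artifact=types.Part.from_bytes(data=b"hello", mime_type="text/plain"),
        )
        part = await service.load_artifact(
            app_name="demo",
            user_id="u1",
            session_id="s1",
            filename="report.txt",
            version=version,
        )
        assert part.inline_data.data == b"hello"
      finally:
        await service.close()


    asyncio.run(main())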