Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions app/filters/asset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import Annotated

from app.db.model import Asset
from app.dependencies.filter import FilterDepends
from app.filters.base import CustomFilter


class AssetFilter(CustomFilter):
order_by: list[str] = ["-creation_date"] # noqa: RUF012

class Constants(CustomFilter.Constants):
model = Asset
ordering_model_fields = ["creation_date"] # noqa: RUF012


AssetFilterDep = Annotated[AssetFilter, FilterDepends(AssetFilter)]
18 changes: 0 additions & 18 deletions app/repository/asset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Asset repository module."""

import uuid
from collections.abc import Sequence

import sqlalchemy as sa

Expand All @@ -14,23 +13,6 @@
class AssetRepository(BaseRepository):
"""AssetRepository."""

def get_entity_assets(
self,
entity_type: EntityType,
entity_id: uuid.UUID,
) -> Sequence[Asset]:
"""Return a sequence of assets, potentially empty."""
query = (
sa.select(Asset)
.join(Entity, Entity.id == Asset.entity_id)
.where(
Asset.entity_id == entity_id,
Asset.status != AssetStatus.DELETED,
Entity.type == entity_type.name,
)
)
return self.db.execute(query).scalars().all()

def get_entity_asset(
self,
entity_type: EntityType,
Expand Down
18 changes: 10 additions & 8 deletions app/routers/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@
from app.config import settings, storages
from app.db.types import AssetLabel, ContentType, StorageType
from app.dependencies.auth import UserContextDep, UserContextWithProjectIdDep
from app.dependencies.common import PaginationQuery
from app.dependencies.db import RepoGroupDep
from app.dependencies.s3 import StorageClientFactoryDep
from app.errors import ApiError, ApiErrorCode
from app.filters.asset import AssetFilterDep
from app.schemas.asset import (
AssetAndPresignedURLS,
AssetRead,
AssetRegister,
DetailedFileList,
DirectoryUpload,
)
from app.schemas.types import ListResponse, PaginationResponse
from app.schemas.types import ListResponse
from app.service import asset as asset_service
from app.utils.files import calculate_sha256_digest, get_content_type
from app.utils.routers import EntityRoute, entity_route_to_type
Expand All @@ -46,17 +48,17 @@ def get_entity_assets(
user_context: UserContextDep,
entity_route: EntityRoute,
entity_id: uuid.UUID,
pagination_request: PaginationQuery,
filter_model: AssetFilterDep,
) -> ListResponse[AssetRead]:
"""Return the list of assets associated with a specific entity."""
assets = asset_service.get_entity_assets(
repos,
return asset_service.get_entity_assets(
repos=repos,
user_context=user_context,
entity_type=entity_route_to_type(entity_route),
entity_route=entity_route,
entity_id=entity_id,
pagination_request=pagination_request,
filter_model=filter_model,
)
# TODO: proper pagination
pagination = PaginationResponse(page=1, page_size=len(assets), total_items=len(assets))
return ListResponse[AssetRead](data=assets, pagination=pagination)


@router.get("/{entity_route}/{entity_id}/assets/{asset_id}")
Expand Down
41 changes: 34 additions & 7 deletions app/service/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
from types_boto3_s3 import S3Client

from app.config import StorageUnion, storages
from app.db.model import Asset, Entity
from app.db.types import AssetLabel, AssetStatus, ContentType, EntityType, StorageType
from app.dependencies.common import PaginationQuery
from app.errors import ApiError, ApiErrorCode, ensure_result, ensure_uniqueness, ensure_valid_schema
from app.queries.common import get_or_create_user_agent
from app.filters.asset import AssetFilterDep
from app.queries.common import get_or_create_user_agent, router_read_many
from app.repository.group import RepositoryGroup
from app.schemas.asset import (
AssetCreate,
Expand All @@ -19,7 +22,9 @@
DirectoryUpload,
)
from app.schemas.auth import UserContext, UserContextWithProjectId
from app.schemas.types import ListResponse
from app.service import entity as entity_service
from app.utils.routers import EntityRoute, entity_route_to_type
from app.utils.s3 import (
StorageClientFactory,
build_s3_path,
Expand All @@ -33,20 +38,42 @@
def get_entity_assets(
repos: RepositoryGroup,
user_context: UserContext,
entity_type: EntityType,
entity_route: EntityRoute,
entity_id: uuid.UUID,
) -> list[AssetRead]:
pagination_request: PaginationQuery,
filter_model: AssetFilterDep,
) -> ListResponse[AssetRead]:
"""Return the list of assets associated with a specific entity."""
db_model_class = Asset
entity_type = entity_route_to_type(entity_route)
_ = entity_service.get_readable_entity(
repos,
user_context=user_context,
entity_type=entity_type,
entity_id=entity_id,
)
return [
AssetRead.model_validate(row)
for row in repos.asset.get_entity_assets(entity_type=entity_type, entity_id=entity_id)
]
apply_filter_query_operations = lambda q: q.join(Entity, Entity.id == Asset.entity_id).where(
Asset.entity_id == entity_id,
Asset.status != AssetStatus.DELETED,
Entity.type == entity_type.name,
)
name_to_facet_query_params = filter_joins = None
return router_read_many(
db=repos.db,
db_model_class=db_model_class,
authorized_project_id=user_context.project_id,
with_search=None,
with_in_brain_region=None,
facets=None,
aliases={},
apply_filter_query_operations=apply_filter_query_operations,
apply_data_query_operations=None,
pagination_request=pagination_request,
response_schema_class=AssetRead,
name_to_facet_query_params=name_to_facet_query_params,
filter_model=filter_model,
filter_joins=filter_joins,
)


def get_entity_asset(
Expand Down
106 changes: 106 additions & 0 deletions tests/routers/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,33 @@ def test_upload_entity_asset(client, entity):
assert error.error_code == ApiErrorCode.ASSET_INVALID_CONTENT_TYPE


@pytest.mark.parametrize(
("client_fixture", "expected_status", "expected_error"),
[
("client_user_2", 404, ApiErrorCode.ENTITY_NOT_FOUND),
("client_no_project", 403, ApiErrorCode.NOT_AUTHORIZED),
],
)
def test_upload_entity_asset_non_authorized(
request, client_fixture, expected_status, expected_error, entity
):
client = request.getfixturevalue(client_fixture)

response = _upload_entity_asset(
client,
entity_type=entity.type,
entity_id=entity.id,
label="morphology",
file_upload_name="morph.asc",
content_type="application/asc",
)
assert response.status_code == expected_status, (
f"Asset creation didn't fail as expected: {response.text}"
)
error = ErrorResponse.model_validate(response.json())
assert error.error_code == expected_error


def test_upload_entity_asset__label(monkeypatch, client, entity):
response = _upload_entity_asset(
client,
Expand Down Expand Up @@ -451,6 +478,16 @@ def test_get_entity_asset(client, entity, asset):
assert error.error_code == ApiErrorCode.ASSET_NOT_FOUND


@pytest.mark.parametrize("client_fixture", ["client_user_2", "client_no_project"])
def test_get_entity_asset_non_authorized(request, client_fixture, entity, asset):
client = request.getfixturevalue(client_fixture)

response = client.get(f"{route(entity.type)}/{entity.id}/assets/{asset.id}")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND


def test_get_entity_assets(client, entity, asset):
response = client.get(f"{route(entity.type)}/{entity.id}/assets")

Expand Down Expand Up @@ -480,6 +517,16 @@ def test_get_entity_assets(client, entity, asset):
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND


@pytest.mark.parametrize("client_fixture", ["client_user_2", "client_no_project"])
def test_get_entity_assets_non_authorized(request, client_fixture, entity):
client = request.getfixturevalue(client_fixture)

response = client.get(f"{route(entity.type)}/{entity.id}/assets")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND


def test_download_entity_asset(client, entity, asset):
response = client.get(
f"{route(entity.type)}/{entity.id}/assets/{asset.id}/download",
Expand Down Expand Up @@ -515,6 +562,16 @@ def test_download_entity_asset(client, entity, asset):
)


@pytest.mark.parametrize("client_fixture", ["client_user_2", "client_no_project"])
def test_download_entity_asset_non_authorized(request, client_fixture, entity, asset):
client = request.getfixturevalue(client_fixture)

response = client.get(f"{route(entity.type)}/{entity.id}/assets/{asset.id}/download")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND


def test_delete_entity_asset(client, entity, asset):
response = client.delete(f"{route(entity.type)}/{entity.id}/assets/{asset.id}")
assert response.status_code == 200, f"Failed to delete asset: {response.text}"
Expand All @@ -524,14 +581,38 @@ def test_delete_entity_asset(client, entity, asset):
# try to delete again the same asset
response = client.delete(f"{route(entity.type)}/{entity.id}/assets/{asset.id}")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ASSET_NOT_FOUND

# try to delete an asset with non-existent entity id
response = client.delete(f"{route(entity.type)}/{MISSING_ID}/assets/{asset.id}")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND

# try to delete an asset with non-existent asset id
response = client.delete(f"{route(entity.type)}/{entity.id}/assets/{MISSING_ID}")
assert response.status_code == 404, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ASSET_NOT_FOUND


@pytest.mark.parametrize(
("client_fixture", "expected_status", "expected_error"),
[
("client_user_2", 404, ApiErrorCode.ENTITY_NOT_FOUND),
("client_no_project", 403, ApiErrorCode.NOT_AUTHORIZED),
],
)
def test_delete_entity_asset_non_authorized(
request, client_fixture, expected_status, expected_error, entity, asset
):
client = request.getfixturevalue(client_fixture)

response = client.delete(f"{route(entity.type)}/{entity.id}/assets/{asset.id}")
assert response.status_code == expected_status, f"Unexpected result: {response.text}"
error = ErrorResponse.model_validate(response.json())
assert error.error_code == expected_error


def test_upload_delete_upload_entity_asset(client, entity):
Expand Down Expand Up @@ -750,11 +831,36 @@ def test_list_entity_asset_directory_failures(client, entity, asset):
# non-directory asset
response = client.get(f"{entity_type}/{entity.id}/assets/{asset.id}/list")
assert response.status_code == 422
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ASSET_NOT_A_DIRECTORY

# non-existent entity
response = client.get(f"{entity_type}/{MISSING_ID}/assets/{MISSING_ID}/list")
assert response.status_code == 404
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ENTITY_NOT_FOUND

# non-existent asset
response = client.get(f"{entity_type}/{entity.id}/assets/{MISSING_ID}/list")
assert response.status_code == 404
error = ErrorResponse.model_validate(response.json())
assert error.error_code == ApiErrorCode.ASSET_NOT_FOUND


@pytest.mark.parametrize(
("client_fixture", "expected_status", "expected_error"),
[
("client_user_2", 404, ApiErrorCode.ENTITY_NOT_FOUND),
("client_no_project", 403, ApiErrorCode.NOT_AUTHORIZED),
],
)
def test_list_entity_asset_directory_non_authorized(
request, client_fixture, expected_status, expected_error, entity, asset
):
client = request.getfixturevalue(client_fixture)
entity_type = route(entity.type)

response = client.get(f"{entity_type}/{entity.id}/assets/{asset.id}/list")
assert response.status_code == expected_status
error = ErrorResponse.model_validate(response.json())
assert error.error_code == expected_error