1- """Middleware to remove BASE_PATH from incoming requests and update links in responses."""
1+ """Middleware to remove the application root path from incoming requests and update links in responses."""
22
33import logging
44import re
55from dataclasses import dataclass
6- from typing import Any
6+ from typing import Any , Optional
77from urllib .parse import urlparse , urlunparse
88
99from starlette .datastructures import Headers
1010from starlette .requests import Request
11- from starlette .types import ASGIApp , Receive , Scope , Send
11+ from starlette .types import ASGIApp , Scope
1212
1313from ..utils .middleware import JsonResponseMiddleware
1414from ..utils .stac import get_links
1717
1818
1919@dataclass
20- class BasePathMiddleware (JsonResponseMiddleware ):
20+ class ProcessLinksMiddleware (JsonResponseMiddleware ):
2121 """
22- Middleware to handle the base path of the request and update links in responses.
23-
24- IMPORTANT: This middleware must be the first middleware in the chain (ie last in the
25- order of declaration) so that it trims the base_path from the request path before
26- other middleware review the request.
22+ Middleware to update links in responses, removing the upstream_url path and adding
23+ the root_path if it exists.
2724 """
2825
2926 app : ASGIApp
30- base_path : str
31- transform_links : bool = True
27+ upstream_url : str
28+ root_path : Optional [ str ] = None
3229
3330 json_content_type_expr : str = r"application/(geo\+)?json"
3431
35- async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
36- """Remove BASE_PATH from the request path if it exists."""
37- if scope ["type" ] != "http" :
38- return await self .app (scope , receive , send )
39-
40- path = scope ["path" ]
41-
42- # Remove base_path if it exists at the start of the path
43- if path .startswith (self .base_path ):
44- scope ["raw_path" ] = scope ["path" ].encode ()
45- scope ["path" ] = path [len (self .base_path ) :] or "/"
46-
47- return await super ().__call__ (scope , receive , send )
48-
4932 def should_transform_response (self , request : Request , scope : Scope ) -> bool :
5033 """Only transform responses with JSON content type."""
5134 return bool (
@@ -69,11 +52,22 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
6952 if parsed_link .netloc != request .headers .get ("host" ):
7053 continue
7154
72- parsed_link = parsed_link ._replace (
73- path = f"{ self .base_path } { parsed_link .path } "
74- )
55+ # Remove the upstream_url path from the link if it exists
56+ if urlparse (self .upstream_url ).path != "/" :
57+ parsed_link = parsed_link ._replace (
58+ path = parsed_link .path [len (urlparse (self .upstream_url ).path ) :]
59+ )
60+
61+ # Add the root_path to the link if it exists
62+ if self .root_path :
63+ parsed_link = parsed_link ._replace (
64+ path = f"{ self .root_path } { parsed_link .path } "
65+ )
66+
7567 link ["href" ] = urlunparse (parsed_link )
7668 except Exception as e :
77- logger .warning ("Failed to parse link href %s: %s" , href , str (e ))
69+ logger .error (
70+ "Failed to parse link href %r, (ignoring): %s" , href , str (e )
71+ )
7872
7973 return data
0 commit comments