Skip to content

Commit f0a05fa

Browse files
schloerkewch
andauthored
Update InjectAutoreloadMiddleware to be compatible with starlette >= 0.35.0 (#1013)
Co-authored-by: Winston Chang <[email protected]>
1 parent 7579f24 commit f0a05fa

File tree

1 file changed

+32
-8
lines changed

1 file changed

+32
-8
lines changed

shiny/_autoreload.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import secrets
99
import threading
1010
import webbrowser
11-
from typing import Callable, Optional
11+
from typing import Callable, Optional, cast
1212

13+
import starlette.types
1314
from asgiref.typing import (
1415
ASGI3Application,
1516
ASGIReceiveCallable,
@@ -90,8 +91,19 @@ class InjectAutoreloadMiddleware:
9091
because we want autoreload to be effective even when displaying an error page.
9192
"""
9293

93-
def __init__(self, app: ASGI3Application):
94-
self.app = app
94+
def __init__(
95+
self,
96+
app: starlette.types.ASGIApp | ASGI3Application,
97+
*args: object,
98+
**kwargs: object,
99+
):
100+
if len(args) > 0 or len(kwargs) > 0:
101+
raise TypeError(
102+
f"InjectAutoreloadMiddleware does not support positional or keyword arguments, received {args}, {kwargs}"
103+
)
104+
# The starlette types and the asgiref types are compatible, but we'll use the
105+
# latter internally. See the note in the __call__ method for more details.
106+
self.app = cast(ASGI3Application, app)
95107
ws_url = autoreload_url()
96108
self.script = (
97109
f""" <script src="__shared/shiny-autoreload.js" data-ws-url="{html.escape(ws_url)}"></script>
@@ -103,19 +115,31 @@ def __init__(self, app: ASGI3Application):
103115
)
104116

105117
async def __call__(
106-
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
118+
self,
119+
scope: starlette.types.Scope | Scope,
120+
receive: starlette.types.Receive | ASGIReceiveCallable,
121+
send: starlette.types.Send | ASGISendCallable,
107122
) -> None:
108-
if scope["type"] != "http" or scope["path"] != "/" or len(self.script) == 0:
109-
return await self.app(scope, receive, send)
123+
# The starlette types and the asgiref types are compatible, but the latter are
124+
# more rigorous. In the call interface, we accept both types for compatibility
125+
# with both. But internally we'll use the more rigorous types.
126+
# See https://github.com/encode/starlette/blob/39dccd9/docs/middleware.md#type-annotations
127+
scope = cast(Scope, scope)
128+
receive_casted = cast(ASGIReceiveCallable, receive)
129+
send_casted = cast(ASGISendCallable, send)
130+
if scope["type"] != "http":
131+
return await self.app(scope, receive_casted, send_casted)
132+
if scope["path"] != "/" or len(self.script) == 0:
133+
return await self.app(scope, receive_casted, send_casted)
110134

111135
def mangle_callback(body: bytes) -> tuple[bytes, bool]:
112136
if b"</head>" in body:
113137
return (body.replace(b"</head>", self.script, 1), True)
114138
else:
115139
return (body, False)
116140

117-
mangler = ResponseMangler(send, mangle_callback)
118-
await self.app(scope, receive, mangler.send)
141+
mangler = ResponseMangler(send_casted, mangle_callback)
142+
await self.app(scope, receive_casted, mangler.send)
119143

120144

121145
# PARENT PROCESS ------------------------------------------------------------

0 commit comments

Comments
 (0)