Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions shiny/_autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,11 @@ async def process_request(
) -> Optional[tuple[http.HTTPStatus, websockets.datastructures.HeadersLike, bytes]]:
# If there's no Upgrade header, it's not a WebSocket request.
if request_headers.get("Upgrade") is None:
# For some unknown reason, this fixes a tendency on GitHub Codespaces to
# correctly proxy through this request, but give a 404 when the redirect is
# followed and app_url is requested. With the sleep, both requests tend to
# succeed reliably.
await asyncio.sleep(1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to check: it doesn't take this code path when developing locally, right?

return (http.HTTPStatus.MOVED_PERMANENTLY, [("Location", app_url)], b"")

async with websockets.serve(
Expand Down
71 changes: 54 additions & 17 deletions shiny/_hostenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,34 @@
from ipaddress import ip_address
from subprocess import run
from typing import Pattern
from urllib.parse import urlparse
from urllib.parse import ParseResult, urlparse


def is_workbench() -> bool:
return bool(os.getenv("RS_SERVER_URL") and os.getenv("RS_SESSION_URL"))


def is_codespaces() -> bool:
# See https://docs.github.com/en/codespaces/developing-in-a-codespace/default-environment-variables-for-your-codespace
return bool(
os.getenv("CODESPACES")
and os.getenv("CODESPACE_NAME")
and os.getenv("GITHUB_CODESPACES_PORT_FORWARDING_DOMAIN")
)


def is_proxy_env() -> bool:
return is_workbench()
return is_workbench() or is_codespaces()


port_cache: dict[int, str] = {}


def get_proxy_url(url: str) -> str:
if not is_workbench():
if not is_proxy_env():
return url

# Regardless of proxying strategy, we don't want to proxy URLs that are not loopback
parts = urlparse(url)
is_loopback = parts.hostname == "localhost"
if not is_loopback:
Expand All @@ -35,6 +45,45 @@ def get_proxy_url(url: str) -> str:
if not is_loopback:
return url

# Regardless of proxying strategy, we need to know the port, whether explicit or
# implicit (from the scheme)
if parts.port is not None:
if parts.port == 0:
# Not sure if this is even legal but we're definitely not going to succeed
# in proxying it
return url
port = parts.port
elif parts.scheme.lower() in ["ws", "http"]:
port = 80
elif parts.scheme.lower() in ["wss", "https"]:
port = 443
else:
# No idea what kind of URL this is
return url

if is_workbench():
return _get_proxy_url_workbench(parts, port) or url
if is_codespaces():
return _get_proxy_url_codespaces(parts, port) or url
return url


def _get_proxy_url_codespaces(parts: ParseResult, port: int) -> str | None:
# See https://docs.github.com/en/codespaces/developing-in-a-codespace/default-environment-variables-for-your-codespace
codespace_name = os.getenv("CODESPACE_NAME")
port_forwarding_domain = os.getenv("GITHUB_CODESPACES_PORT_FORWARDING_DOMAIN")
netloc = f"{codespace_name}-{port}.{port_forwarding_domain}"
if parts.scheme.lower() in ["ws", "wss"]:
scheme = "wss"
elif parts.scheme.lower() in ["http", "https"]:
scheme = "https"
else:
return None

return parts._replace(scheme=scheme, netloc=netloc).geturl()


def _get_proxy_url_workbench(parts: ParseResult, port: int) -> str | None:
path = parts.path or "/"

server_url = os.getenv("RS_SERVER_URL", "")
Expand All @@ -45,18 +94,6 @@ def get_proxy_url(url: str) -> str:
server_url = re.sub("/$", "", server_url)
session_url = re.sub("^/", "", session_url)

port = (
parts.port
if parts.port
else (
80
if parts.scheme in ["ws", "http"]
else 443 if parts.scheme in ["wss", "https"] else 0
)
)
if port == 0:
return url

if port in port_cache:
ptoken = port_cache[port]
else:
Expand All @@ -67,9 +104,9 @@ def get_proxy_url(url: str) -> str:
encoding="ascii",
)
except FileNotFoundError:
return url
return None
if res.returncode != 0:
return url
return None
ptoken = res.stdout
port_cache[port] = ptoken

Expand Down