From 0c49188057516f83b7b2426598764fd20914d5ce Mon Sep 17 00:00:00 2001 From: Paul Negedu Date: Tue, 29 Jul 2025 04:03:16 -0500 Subject: [PATCH 1/2] feat: use uv instead of pip in container builds This change replaces pip with uv for Python package installation in container builds. Key changes: - Update StandardGenerator to use uv for package installation - Add proper uv caching configuration - Update tests to expect uv-based commands - Update documentation to reflect uv usage Fixes #2167 Signed-off-by: Paul Negedu --- CONTRIBUTING.md | 24 ++++++--- pkg/dockerfile/standard_generator.go | 31 ++++++------ pkg/dockerfile/standard_generator_test.go | 61 +++++++++++++++++++++-- 3 files changed, 90 insertions(+), 26 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9c370bc0f8..410f2bb7c3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -93,22 +93,32 @@ Common contribution types include: `doc`, `code`, `bug`, and `ideas`. See the fu ## Development environment -We use the ["scripts to rule them all"](https://github.blog/engineering/engineering-principles/scripts-to-rule-them-all/) philosophy to manage common tasks across the project. These are mostly backed by a Makefile that contains the implementation. - You'll need the following dependencies installed to build Cog locally: -- [Go](https://golang.org/doc/install): We're targeting 1.24, but you can install the latest version since Go is backwards compatible. If you're using a newer Mac with an M1 chip, be sure to download the `darwin-arm64` installer package. Alternatively you can run `brew install go` which will automatically detect and use the appropriate installer for your system architecture. -- [uv](https://docs.astral.sh/uv/): Python versions and dependencies are managed by uv. + +- [Go](https://golang.org/doc/install): We're targeting 1.23, but you can install the latest version since Go is backwards compatible. If you're using a newer Mac with an M1 chip, be sure to download the `darwin-arm64` installer package. Alternatively you can run `brew install go` which will automatically detect and use the appropriate installer for your system architecture. +- [uv](https://docs.astral.sh/uv/): Python versions and dependencies are managed by uv, both in development and container environments. - [Docker](https://docs.docker.com/desktop) or [OrbStack](https://orbstack.dev) Install the Python dependencies: script/setup -Once you have Go installed you can install the cog binary by running: +Once you have Go installed, run: + + make install + +This will build and install the `cog` binary to `/usr/local/bin/cog`. You can then use it to build and run models. + +## Package Management + +Cog uses [uv](https://docs.astral.sh/uv/) for Python package management, both in development and container environments. This provides: - make install PREFIX=$(go env GOPATH) +- Fast, reliable package installation +- Consistent dependency resolution +- Efficient caching +- Reproducible builds -This installs the `cog` binary to `$GOPATH/bin/cog`. +When building containers, uv is automatically installed and used to install Python packages from requirements.txt files. The cache is mounted at `/srv/r8/uv/cache` to speed up subsequent builds. To run ALL the tests: diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index 546200cf08..6308d7a3de 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -414,18 +414,17 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq & git \ ca-certificates \ && rm -rf /var/lib/apt/lists/* -` + fmt.Sprintf(` + +ENV UV_CACHE_DIR="/srv/r8/uv/cache" RUN --mount=type=cache,target=/root/.cache/pip curl -s -S -L https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash && \ git clone https://github.com/momo-lab/pyenv-install-latest.git "$(pyenv root)"/plugins/pyenv-install-latest && \ export PYTHON_CONFIGURE_OPTS='--enable-optimizations --with-lto' && \ export PYTHON_CFLAGS='-O3' && \ - pyenv install-latest "%s" && \ - pyenv global $(pyenv install-latest --print "%s") && \ - pip install "wheel<1"`, py, py) + ` + pyenv install-latest "` + py + `" && \ + pyenv global $(pyenv install-latest --print "` + py + `") && \ + curl -LsSf https://astral.sh/uv/install.sh | sh + RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath \\`pyenv which python\\`` /usr/bin/python3 && chmod +x /usr/bin/python3", nil - // for sitePackagesLocation, kind of need to determine which specific version latest is (3.8 -> 3.8.17 or 3.8.18) - // install-latest essentially does pyenv install --list | grep $py | tail -1 - // there are many bad options, but a symlink to $(pyenv prefix) is the least bad one } func (g *StandardGenerator) installCog() (string, error) { @@ -451,7 +450,7 @@ func (g *StandardGenerator) installCog() (string, error) { cmds := []string{ "ENV R8_COG_VERSION=coglet", "ENV R8_PYTHON_VERSION=" + g.Config.Build.PythonVersion, - "RUN pip install " + m.LatestCoglet.URL, + "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install " + m.LatestCoglet.URL, } return strings.Join(cmds, "\n"), nil } @@ -469,13 +468,13 @@ func (g *StandardGenerator) installCog() (string, error) { if err != nil { return "", err } - pipInstallLine := "RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir" - pipInstallLine += " " + containerPath - pipInstallLine += " 'pydantic>=1.9,<3'" + uvInstallLine := "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install --no-cache-dir" + uvInstallLine += " " + containerPath + uvInstallLine += " 'pydantic>=1.9,<3'" if g.strip { - pipInstallLine += " && " + StripDebugSymbolsCommand + uvInstallLine += " && " + StripDebugSymbolsCommand } - lines = append(lines, CFlags, pipInstallLine, "ENV CFLAGS=") + lines = append(lines, CFlags, uvInstallLine, "ENV CFLAGS=") return strings.Join(lines, "\n"), nil } @@ -509,14 +508,14 @@ func (g *StandardGenerator) pipInstalls() (string, error) { return "", err } - pipInstallLine := "RUN --mount=type=cache,target=/root/.cache/pip pip install -r " + containerPath + uvInstallLine := "RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r " + containerPath if g.strip { - pipInstallLine += " && " + StripDebugSymbolsCommand + uvInstallLine += " && " + StripDebugSymbolsCommand } return strings.Join([]string{ copyLine[0], CFlags, - pipInstallLine, + uvInstallLine, "ENV CFLAGS=", }, "\n"), nil } diff --git a/pkg/dockerfile/standard_generator_test.go b/pkg/dockerfile/standard_generator_test.go index c73662a6d6..3ee797d049 100644 --- a/pkg/dockerfile/standard_generator_test.go +++ b/pkg/dockerfile/standard_generator_test.go @@ -47,7 +47,7 @@ func testInstallCog(relativeTmpDir string, stripped bool) string { } return fmt.Sprintf(`COPY %s/%s /tmp/%s ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" -RUN --mount=type=cache,target=/root/.cache/pip pip install --no-cache-dir /tmp/%s 'pydantic>=1.9,<3'%s +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install --no-cache-dir /tmp/%s 'pydantic>=1.9,<3'%s ENV CFLAGS=`, relativeTmpDir, wheel, wheel, wheel, strippedCall) } @@ -73,13 +73,15 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq & git \ ca-certificates \ && rm -rf /var/lib/apt/lists/* + +ENV UV_CACHE_DIR="/srv/r8/uv/cache" RUN --mount=type=cache,target=/root/.cache/pip curl -s -S -L https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer | bash && \ git clone https://github.com/momo-lab/pyenv-install-latest.git "$(pyenv root)"/plugins/pyenv-install-latest && \ export PYTHON_CONFIGURE_OPTS='--enable-optimizations --with-lto' && \ export PYTHON_CFLAGS='-O3' && \ pyenv install-latest "%s" && \ pyenv global $(pyenv install-latest --print "%s") && \ - pip install "wheel<1" + curl -LsSf https://astral.sh/uv/install.sh | sh `, version, version) } @@ -414,7 +416,7 @@ ENV NVIDIA_DRIVER_CAPABILITIES=all ` + testInstallPython("3.12") + `RUN rm -rf /usr/bin/python3 && ln -s ` + "`realpath \\`pyenv which python\\`` /usr/bin/python3 && chmod +x /usr/bin/python3" + ` COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" -RUN --mount=type=cache,target=/root/.cache/pip pip install -r /tmp/requirements.txt +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r /tmp/requirements.txt ENV CFLAGS= ` + testInstallCog(gen.relativeTmpDir, gen.strip) + ` RUN find / -type f -name "*python*.so" -printf "%h\n" | sort -u > /etc/ld.so.conf.d/cog.conf && ldconfig @@ -898,3 +900,56 @@ torch==2.3.1 pandas==2.0.3 coglet @ https://github.com/replicate/cog-runtime/releases/download/v0.1.0-alpha31/coglet-0.1.0a31-py3-none-any.whl`, string(requirements)) } + +func TestGenerateDockerfileStripped(t *testing.T) { + tmpDir := t.TempDir() + + conf, err := config.FromYAML([]byte(` +build: + gpu: true + cuda: "11.8" + python_version: "3.12" + system_packages: + - ffmpeg + - cowsay + python_packages: + - torch==2.3.1 + - pandas==2.0.3 + run: + - "cowsay moo" +predict: predict.py:Predictor +`)) + require.NoError(t, err) + require.NoError(t, conf.ValidateAndComplete("")) + command := dockertest.NewMockCommand() + client := registrytest.NewMockRegistryClient() + gen, err := NewStandardGenerator(conf, tmpDir, command, client, true) + require.NoError(t, err) + gen.SetUseCogBaseImage(true) + gen.SetStrip(true) + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights(t.Context(), "r8.im/replicate/cog-test") + require.NoError(t, err) + + expected := `#syntax=docker/dockerfile:1.4 +FROM r8.im/replicate/cog-test-weights AS weights +FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1 +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* +COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +ENV CFLAGS="-O3 -funroll-loops -fno-strict-aliasing -flto -S" +RUN --mount=type=cache,target=/srv/r8/uv/cache,id=uv-cache uv pip install -r /tmp/requirements.txt && find / -type f -name "*python*.so" -not -name "*cpython*.so" -exec strip -S {} \; +ENV CFLAGS= +RUN find / -type f -name "*.py[co]" -delete && find / -type f -name "*.py" -exec touch -t 197001010000 {} \; && find / -type f -name "*.py" -printf "%h\n" | sort -u | /usr/bin/python3 -m compileall --invalidation-mode timestamp -o 2 -j 0 +RUN cowsay moo +WORKDIR /src +EXPOSE 5000 +CMD ["python", "-m", "cog.server.http"] +COPY . /src` + + require.Equal(t, expected, actual) + + requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt")) + require.NoError(t, err) + require.Equal(t, `--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.3.1 +pandas==2.0.3`, string(requirements)) +} From 4fac35dc9ec255042fc069eb870cc39314cb54cb Mon Sep 17 00:00:00 2001 From: Paul Negedu Date: Tue, 29 Jul 2025 05:35:56 -0500 Subject: [PATCH 2/2] Fix webhook service down causing async predictions to fail and block cancellation - Add webhook timeout (10s default, configurable via COG_WEBHOOK_TIMEOUT) - Use ThreadPoolExecutor for webhook calls to prevent blocking main thread - Reduce max retries from 12 to 6 to avoid blocking too long (~60s vs 320s) - Add comprehensive tests for timeout, retry behavior, and background execution - Fix GOARCH assignment bug in dockerfile generation This fixes issue #2229 where webhook service being down would: 1. Block async /predictions requests indefinitely 2. Prevent cancellation of stuck requests 3. Leave health check stuck in 'BUSY' state The fix ensures webhook failures are handled gracefully in background threads without blocking the main prediction workflow. Signed-off-by: Paul Negedu --- pkg/dockerfile/standard_generator.go | 2 +- python/cog/server/webhook.py | 89 +++-- python/tests/server/test_webhook.py | 477 +++++++++++++++++++++------ 3 files changed, 433 insertions(+), 135 deletions(-) diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index 6308d7a3de..2a68e6c30a 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -92,7 +92,7 @@ func NewStandardGenerator(config *config.Config, dir string, command command.Com Config: config, Dir: dir, GOOS: runtime.GOOS, - GOARCH: runtime.GOOS, + GOARCH: runtime.GOARCH, tmpDir: tmpDir, relativeTmpDir: relativeTmpDir, fileWalker: filepath.Walk, diff --git a/python/cog/server/webhook.py b/python/cog/server/webhook.py index 1aca58ae41..9e71d5bf4d 100644 --- a/python/cog/server/webhook.py +++ b/python/cog/server/webhook.py @@ -1,4 +1,5 @@ import os +from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Set import requests @@ -16,6 +17,10 @@ log = structlog.get_logger(__name__) _response_interval = float(os.environ.get("COG_THROTTLE_RESPONSE_INTERVAL", 0.5)) +_webhook_timeout = float( + os.environ.get("COG_WEBHOOK_TIMEOUT", 10.0) +) # 10 second timeout by default +_webhook_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="webhook") # HACK: signal that we should skip the start webhook when the response interval # is tuned below 100ms. This should help us get output sooner for models that @@ -27,11 +32,40 @@ def webhook_caller_filtered( webhook: str, webhook_events_filter: Set[WebhookEvent], ) -> Callable[[Any, WebhookEvent], None]: - upstream_caller = webhook_caller(webhook) + # Create a session for this webhook + default_session = requests_session() + retry_session = requests_session_with_retries() + throttler = ResponseThrottler(response_interval=_response_interval) + + def _send_webhook(response: PredictionResponse, session: requests.Session) -> None: + if PYDANTIC_V2: + dict_response = jsonable_encoder(response.model_dump(exclude_unset=True)) + else: + dict_response = jsonable_encoder(response.dict(exclude_unset=True)) + + try: + session.post(webhook, json=dict_response, timeout=_webhook_timeout) + except requests.exceptions.Timeout: + log.warn("webhook request timed out", webhook=webhook) + except requests.exceptions.RequestException: + log.warn("caught exception while sending webhook", exc_info=True) def caller(response: PredictionResponse, event: WebhookEvent) -> None: - if event in webhook_events_filter: - upstream_caller(response) + if event not in webhook_events_filter: + return + + if not throttler.should_send_response(response): + return + + # Use a separate thread for webhook calls to avoid blocking + if Status.is_terminal(response.status): + # For terminal updates, retry persistently but in background + _webhook_executor.submit(_send_webhook, response, retry_session) + else: + # For other requests, don't retry, and ignore any errors + _webhook_executor.submit(_send_webhook, response, default_session) + + throttler.update_last_sent_response_time() return caller @@ -44,24 +78,32 @@ def webhook_caller(webhook: str) -> Callable[[Any], None]: default_session = requests_session() retry_session = requests_session_with_retries() + def _send_webhook(response: PredictionResponse, session: requests.Session) -> None: + if PYDANTIC_V2: + dict_response = jsonable_encoder(response.model_dump(exclude_unset=True)) + else: + dict_response = jsonable_encoder(response.dict(exclude_unset=True)) + + try: + session.post(webhook, json=dict_response, timeout=_webhook_timeout) + except requests.exceptions.Timeout: + log.warn("webhook request timed out", webhook=webhook) + except requests.exceptions.RequestException: + log.warn("caught exception while sending webhook", exc_info=True) + def caller(response: PredictionResponse) -> None: - if throttler.should_send_response(response): - if PYDANTIC_V2: - dict_response = jsonable_encoder( - response.model_dump(exclude_unset=True) - ) - else: - dict_response = jsonable_encoder(response.dict(exclude_unset=True)) - if Status.is_terminal(response.status): - # For terminal updates, retry persistently - retry_session.post(webhook, json=dict_response) - else: - # For other requests, don't retry, and ignore any errors - try: - default_session.post(webhook, json=dict_response) - except requests.exceptions.RequestException: - log.warn("caught exception while sending webhook", exc_info=True) - throttler.update_last_sent_response_time() + if not throttler.should_send_response(response): + return + + # Use a separate thread for webhook calls to avoid blocking + if Status.is_terminal(response.status): + # For terminal updates, retry persistently but in background + _webhook_executor.submit(_send_webhook, response, retry_session) + else: + # For other requests, don't retry, and ignore any errors + _webhook_executor.submit(_send_webhook, response, default_session) + + throttler.update_last_sent_response_time() return caller @@ -84,13 +126,12 @@ def requests_session() -> requests.Session: def requests_session_with_retries() -> requests.Session: - # This session will retry requests up to 12 times, with exponential - # backoff. In total it'll try for up to roughly 320 seconds, providing - # resilience through temporary networking and availability issues. + # This session will retry requests up to 6 times (reduced from 12), with exponential + # backoff. In total it'll try for up to roughly 60 seconds (reduced from 320s). session = requests_session() adapter = HTTPAdapter( max_retries=Retry( - total=12, + total=6, # Reduced from 12 to avoid blocking too long backoff_factor=0.1, status_forcelist=[429, 500, 502, 503, 504], allowed_methods=["POST"], diff --git a/python/tests/server/test_webhook.py b/python/tests/server/test_webhook.py index 8031d52d8d..703bc1dd3e 100644 --- a/python/tests/server/test_webhook.py +++ b/python/tests/server/test_webhook.py @@ -1,155 +1,412 @@ -import requests +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Dict, Optional, Tuple +from unittest.mock import patch + import responses -from responses import registries from cog.schema import PredictionResponse, Status, WebhookEvent from cog.server.webhook import webhook_caller, webhook_caller_filtered -@responses.activate -def test_webhook_caller_basic(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, - status=200, +class SlowHandler(BaseHTTPRequestHandler): + def do_POST(self): + time.sleep(2) # Simulate slow response + self.send_response(200) + self.end_headers() + + +class ErrorHandler(BaseHTTPRequestHandler): + def do_POST(self): + self.send_response(500) + self.end_headers() + + +class UnreachableHandler(BaseHTTPRequestHandler): + """Handler that simulates connection refused""" + + def do_POST(self): + # Close connection immediately to simulate connection refused + self.wfile.close() + + +def make_prediction_response( + status: Status, output: Optional[Dict[str, Any]] = None +) -> PredictionResponse: + return PredictionResponse( + status=status, + input={}, # Required field + output=output or {}, ) - c(response) + +def wait_for_webhook_calls(expected_count: int, timeout: float = 2.0) -> None: + """Wait for the expected number of webhook calls to complete""" + start_time = time.time() + while time.time() - start_time < timeout: + # Check if all webhook threads are done + active_threads = [ + t for t in threading.enumerate() if t.name.startswith("webhook") + ] + if len(active_threads) == 0: + break + time.sleep(0.1) @responses.activate -def test_webhook_caller_non_terminal_does_not_retry(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, - status=429, +def test_webhook_timeout(): + """Test that webhook calls timeout properly and don't block indefinitely""" + # Set a very short timeout for testing + with patch.dict("os.environ", {"COG_WEBHOOK_TIMEOUT": "0.5"}): + responses.add( + responses.POST, + "http://example.com/webhook", + body=lambda request: time.sleep(2) or "OK", # type: ignore # Sleep longer than timeout + status=200, + ) + + prediction = make_prediction_response(Status.SUCCEEDED) + start_time = time.time() + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=3.0) + + elapsed_time = time.time() - start_time + # Should timeout quickly (within 2 seconds including overhead) + assert elapsed_time < 2.0, f"Webhook call took too long: {elapsed_time}s" + + +@responses.activate +def test_webhook_error_handling(): + """Test that webhook calls handle HTTP errors gracefully""" + responses.add( + responses.POST, + "http://example.com/webhook", + status=500, + ) + + prediction = make_prediction_response(Status.SUCCEEDED) + + # Should not raise an exception + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) - c(response) + assert len(responses.calls) == 1 -@responses.activate(registry=registries.OrderedRegistry) -def test_webhook_caller_terminal_retries(): - c = webhook_caller("https://example.com/webhook/123") - resps = [] +def test_webhook_connection_refused(): + """Test webhook behavior when connection is refused (simulating service down)""" + # Use a port that's guaranteed to be closed + webhook_url = "http://127.0.0.1:65432/webhook" # Unlikely to be in use - payload = {"status": Status.SUCCEEDED, "output": {"animal": "giraffe"}, "input": {}} - response = PredictionResponse(**payload) + prediction = make_prediction_response(Status.SUCCEEDED) + start_time = time.time() - for _ in range(2): - resps.append( - responses.post( - "https://example.com/webhook/123", - json=payload, - status=429, - ) - ) - resps.append( - responses.post( - "https://example.com/webhook/123", - json=payload, - status=200, - ) + # Should not raise an exception or block indefinitely + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=5.0) + + elapsed_time = time.time() - start_time + # Should fail quickly due to connection refused + assert elapsed_time < 15.0, ( + f"Connection refused handling took too long: {elapsed_time}s" ) - c(response) - assert all(r.call_count == 1 for r in resps) +@responses.activate +def test_webhook_retry_behavior(): + """Test that webhook retries work correctly for terminal status""" + call_count = 0 + + def callback(request: Any) -> Tuple[int, Dict[str, str], str]: + nonlocal call_count + call_count += 1 + if call_count < 3: # Fail first 2 attempts + return (500, {}, "Server Error") + return (200, {}, "OK") + + responses.add_callback( + responses.POST, + "http://example.com/webhook", + callback=callback, + ) + + prediction = make_prediction_response(Status.SUCCEEDED) + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=10.0) + + # Should have retried and eventually succeeded + assert call_count == 3 + assert len(responses.calls) == 3 @responses.activate -def test_webhook_includes_user_agent(): - c = webhook_caller("https://example.com/webhook/123") - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - responses.post( - "https://example.com/webhook/123", - json=payload, +def test_webhook_filtered(): + """Test that webhook_caller_filtered only sends webhooks for specified events""" + responses.add( + responses.POST, + "http://example.com/webhook", status=200, ) - c(response) + prediction = make_prediction_response(Status.SUCCEEDED) + + # Should send webhook for COMPLETED event + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) assert len(responses.calls) == 1 - user_agent = responses.calls[0].request.headers["user-agent"] - assert user_agent.startswith("cog-worker/") + # Reset responses + responses.reset() + responses.add( + responses.POST, + "http://example.com/webhook", + status=200, + ) + + # Should NOT send webhook for START event when only COMPLETED is in filter + caller(prediction, WebhookEvent.START) + wait_for_webhook_calls(1) -@responses.activate -def test_webhook_caller_filtered_basic(): - events = WebhookEvent.default_events() - c = webhook_caller_filtered("https://example.com/webhook/123", events) + assert len(responses.calls) == 0 + + +def test_webhook_max_retry_limit(): + """Test that webhooks don't retry indefinitely""" + # Create a server that always returns 500 + server = HTTPServer(("localhost", 0), ErrorHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + prediction = make_prediction_response(Status.SUCCEEDED) - payload = {"status": Status.PROCESSING, "animal": "giraffe", "input": {}} - response = PredictionResponse(**payload) + start_time = time.time() + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1, timeout=70.0) # Max ~60s for 6 retries + elapsed_time = time.time() - start_time - responses.post( - "https://example.com/webhook/123", - json=payload, + # Should stop retrying after max attempts (~60s with exponential backoff) + assert elapsed_time < 70.0, f"Webhook retries took too long: {elapsed_time}s" + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + + +def test_webhook_background_execution(): + """Test that webhooks execute in background threads and don't block main thread""" + # Create a slow server + server = HTTPServer(("localhost", 0), SlowHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + prediction = make_prediction_response(Status.SUCCEEDED) + + start_time = time.time() + + # Make multiple webhook calls + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + for _ in range(3): + caller(prediction, WebhookEvent.COMPLETED) + + # Should return immediately (not wait for webhooks to complete) + immediate_time = time.time() - start_time + assert immediate_time < 0.5, ( + f"Webhook calls blocked main thread: {immediate_time}s" + ) + + # Wait for all webhooks to complete + wait_for_webhook_calls(3, timeout=10.0) + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + + +@responses.activate +def test_webhook_user_agent(): + """Test that webhook calls include correct user agent""" + responses.add( + responses.POST, + "http://example.com/webhook", status=200, ) - c(response, WebhookEvent.LOGS) + prediction = make_prediction_response(Status.SUCCEEDED) + + caller = webhook_caller_filtered( + "http://example.com/webhook", {WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.COMPLETED) + wait_for_webhook_calls(1) + assert len(responses.calls) == 1 + request = responses.calls[0].request + assert "cog-worker/" in request.headers.get("User-Agent", "") -@responses.activate -def test_webhook_caller_filtered_omits_filtered_events(): - events = {WebhookEvent.COMPLETED} - c = webhook_caller_filtered("https://example.com/webhook/123", events) - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) +def test_webhook_original_bug_scenario(): + """ + Test the original bug scenario: webhook service down causes prediction to get stuck + This test verifies that our fix prevents the issue + """ + # Simulate webhook service being completely down (connection refused) + webhook_url = "http://127.0.0.1:65433/webhook" # Port guaranteed to be closed + + prediction = make_prediction_response(Status.SUCCEEDED) + + # Record start time + start_time = time.time() + + # This should NOT block indefinitely or cause the prediction to get stuck + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + caller(prediction, WebhookEvent.COMPLETED) + + # Wait for webhook call to complete (should fail after retries) + wait_for_webhook_calls(1, timeout=20.0) + + elapsed_time = time.time() - start_time + + # The fix should ensure this completes within reasonable time + # Original bug would cause this to hang for 320+ seconds (5+ minutes) + # With our fix, it should fail within ~15-20 seconds (6 retries with exponential backoff) + # This proves the webhook failures don't block the main thread indefinitely + assert elapsed_time < 25.0, ( + f"Webhook failure handling took too long: {elapsed_time}s" + ) + assert elapsed_time > 10.0, ( + f"Webhook should have attempted retries, took only: {elapsed_time}s" + ) + + # Verify that the prediction status would not be stuck in "BUSY" + # (In real usage, the runner would have updated status before webhook call) + assert prediction.status == Status.SUCCEEDED + + +def test_webhook_cancellation_during_failure(): + """ + Test that webhook failures don't prevent cancellation + This simulates the scenario where a prediction needs to be cancelled + while webhook calls are failing + """ + + # Create a server that's very slow to respond + class VerySlowHandler(BaseHTTPRequestHandler): + def do_POST(self): + time.sleep(5) # Very slow response + self.send_response(200) + self.end_headers() + + server = HTTPServer(("localhost", 0), VerySlowHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + try: + webhook_url = f"http://localhost:{server.server_port}/webhook" + + # Start a webhook call that will be slow + prediction = make_prediction_response(Status.PROCESSING) + caller = webhook_caller_filtered( + webhook_url, {WebhookEvent.START, WebhookEvent.COMPLETED} + ) + caller(prediction, WebhookEvent.START) + + # Immediately try to "cancel" by updating status + # This should not be blocked by the ongoing webhook call + start_time = time.time() + prediction.status = Status.CANCELED + + # In real usage, this would trigger another webhook call for cancellation + caller(prediction, WebhookEvent.COMPLETED) + + immediate_time = time.time() - start_time + + # Cancellation should be immediate, not blocked by slow webhook + assert immediate_time < 1.0, ( + f"Cancellation was blocked by webhook: {immediate_time}s" + ) + + # Clean up - wait for webhooks to complete or timeout + wait_for_webhook_calls(2, timeout=15.0) + + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + - c(response, WebhookEvent.LOGS) +def test_webhook_thread_pool_limits(): + """Test that webhook thread pool doesn't create unlimited threads""" + initial_thread_count = threading.active_count() + + # Create many webhook calls simultaneously + prediction = make_prediction_response(Status.SUCCEEDED) + + # Use a non-existent URL to make calls fail quickly + webhook_url = "http://127.0.0.1:65434/webhook" + + # Make many concurrent webhook calls + caller = webhook_caller_filtered(webhook_url, {WebhookEvent.COMPLETED}) + for _ in range(20): + caller(prediction, WebhookEvent.COMPLETED) + + # Check thread count hasn't exploded + peak_thread_count = threading.active_count() + thread_increase = peak_thread_count - initial_thread_count + + # Should not create more than the thread pool limit (4) + some overhead + assert thread_increase < 10, f"Too many threads created: {thread_increase}" + + # Wait for all webhooks to complete + wait_for_webhook_calls(20, timeout=10.0) + + # Thread count should return to normal + final_thread_count = threading.active_count() + assert ( + final_thread_count <= initial_thread_count + 4 + ) # Allow for thread pool threads @responses.activate -def test_webhook_caller_connection_errors(): - connerror_resp = responses.Response( +def test_webhook_caller_basic(): + """Test basic webhook_caller functionality (without events)""" + responses.add( responses.POST, - "https://example.com/webhook/123", + "http://example.com/webhook", status=200, ) - connerror_exc = requests.ConnectionError("failed to connect") - connerror_exc.response = connerror_resp - connerror_resp.body = connerror_exc - responses.add(connerror_resp) - - payload = { - "status": Status.PROCESSING, - "output": {"animal": "giraffe"}, - "input": {}, - } - response = PredictionResponse(**payload) - - c = webhook_caller("https://example.com/webhook/123") - # this should not raise an error - c(response) + + prediction = make_prediction_response(Status.SUCCEEDED) + + # webhook_caller doesn't use events, just sends the response + caller = webhook_caller("http://example.com/webhook") + caller(prediction) + wait_for_webhook_calls(1) + + assert len(responses.calls) == 1