Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
45aa0b4
Add on_socket_connect and on_socket_disconnect callbacks
francisjervis Aug 4, 2025
2d89a95
Add tests for on_socket_connect and on_socket_disconnect callbacks
francisjervis Aug 4, 2025
db94280
Update tests for on_socket_connect and on_socket_disconnect callbacks
francisjervis Aug 4, 2025
c20dfd6
Fix unwanted UI display of steps
francisjervis Aug 5, 2025
28eb155
Merge branch 'Chainlit:main' into feature/websocket-lifecycle-callbacks
francisjervis Aug 19, 2025
d00e6f8
Merge branch 'Chainlit:main' into feature/websocket-lifecycle-callbacks
francisjervis Aug 25, 2025
8116855
refactor: remove debug print statement from connect callback
francisjervis Aug 25, 2025
37661a5
Merge remote-tracking branch 'origin/feature/websocket-lifecycle-call…
francisjervis Aug 25, 2025
07f58d7
Merge branch 'Chainlit:main' into feature/websocket-lifecycle-callbacks
francisjervis Aug 26, 2025
d76ce93
test: update assertions for socket connect and disconnect callbacks (…
francisjervis Aug 26, 2025
c0625f3
Merge remote-tracking branch 'origin/feature/websocket-lifecycle-call…
francisjervis Aug 26, 2025
e2d2f5d
Merge branch 'Chainlit:main' into feature/websocket-lifecycle-callbacks
francisjervis Aug 28, 2025
c2512c4
Merge branch 'main' into feature/websocket-lifecycle-callbacks
francisjervis Sep 2, 2025
abe619d
Modify uv sync command to include additional flags
francisjervis Sep 4, 2025
e4a67a2
Update Mypy command to run without project
francisjervis Sep 4, 2025
61d0b38
Remove frontend build step and add check for assets
francisjervis Sep 4, 2025
fed6ca6
Update pytest command to include --no-project flag
francisjervis Sep 4, 2025
065409d
Refactor build process to use CustomBuildHook
francisjervis Sep 4, 2025
e807c65
Add custom build hook for Hatch
francisjervis Sep 4, 2025
2c7b278
feat: add lifecycle and socket callback handlers to initialization
francisjervis Sep 4, 2025
ada79d7
feat: add lifecycle and socket callback handlers to initialization
francisjervis Sep 4, 2025
25cb41a
Merge remote-tracking branch 'origin/feature/websocket-lifecycle-call…
francisjervis Sep 4, 2025
8ffb1e3
feat: update socket lifecycle callbacks to use new callback structure
francisjervis Sep 4, 2025
f35aa0f
feat: add socket connection and disconnection callbacks
francisjervis Sep 4, 2025
b951b81
feat: enhance socket connection and disconnection callbacks with task…
francisjervis Sep 4, 2025
5b0ee83
feat: remove socket connection and disconnection callbacks from initi…
francisjervis Sep 4, 2025
a39b152
revert wrapping to remove step display
francisjervis Sep 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/uv-python-install/action.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ runs:
with:
python-version: ${{ inputs.python-version }}
- name: Install Python dependencies
run: uv sync ${{ inputs.extra-dependencies }}
run: uv sync --no-install-project --no-editable ${{ inputs.extra-dependencies }}
shell: bash
working-directory: ${{ inputs.working-directory }}
2 changes: 1 addition & 1 deletion .github/workflows/lint-backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ jobs:
changed-files: "true"
args: "format --check"
- name: Run Mypy
run: uv run mypy chainlit/
run: uv run --no-project mypy chainlit/
working-directory: ${{ env.BACKEND_DIR }}
6 changes: 1 addition & 5 deletions .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,11 @@ jobs:
with:
working-directory: ${{ env.BACKEND_DIR }}

- name: Build frontend and prepare assets
run: uv run python build.py
working-directory: ${{ env.BACKEND_DIR }}

- name: Build Python distribution
run: uv build
working-directory: ${{ env.BACKEND_DIR }}

- name: List wheel contents
- name: Check frontend and copilot folder included
run: |
pip install wheel
python -m wheel unpack dist/chainlit-*.whl -d unpacked
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ jobs:
run: pnpm run buildUi
timeout-minutes: 5
- name: Run Pytest
run: uv run pytest --cov=chainlit/
run: uv run --no-project pytest --cov=chainlit/
working-directory: ${{ env.BACKEND_DIR }}
6 changes: 4 additions & 2 deletions backend/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys

from hatchling.builders.hooks.plugin.interface import BuildHookInterface

class BuildError(Exception):
"""Custom exception for build failures"""
Expand Down Expand Up @@ -98,5 +99,6 @@ def build():
sys.exit(1)


if __name__ == "__main__":
build()
class CustomBuildHook(BuildHookInterface):
def initialize(self, version, build_data):
build()
28 changes: 28 additions & 0 deletions backend/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,32 @@ def acall(self):
"LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks",
"instrument_openai": "chainlit.openai",
"instrument_mistralai": "chainlit.mistralai",
"HaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
"AsyncHaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
"HaystackCallbackHandler": "chainlit.haystack.callbacks",
"AsyncHaystackCallbackHandler": "chainlit.haystack.callbacks",
"on_chat_start": "chainlit.lifecycle",
"on_chat_resume": "chainlit.lifecycle",
"on_chat_end": "chainlit.lifecycle",
"on_stop": "chainlit.lifecycle",
"on_message": "chainlit.lifecycle",
"on_audio_chunk": "chainlit.lifecycle",
"on_audio_start": "chainlit.lifecycle",
"on_audio_end": "chainlit.lifecycle",
"on_settings_update": "chainlit.lifecycle",
"on_action": "chainlit.lifecycle",
"on_feedback": "chainlit.lifecycle",
"on_app_start": "chainlit.lifecycle",
"on_app_shutdown": "chainlit.lifecycle",
"on_connect": "chainlit.socket",
"on_disconnect": "chainlit.socket",
"on_socket_connect": "chainlit.callbacks",
"on_socket_disconnect": "chainlit.callbacks",
"action_callback": "chainlit.action",
"author_rename": "chainlit.message",
"set_chat_profiles": "chainlit.chat_profile",
"set_starters": "chainlit.starters",
"data_layer": "chainlit.data",
"SemanticKernelFilter": "chainlit.semantic_kernel",
"server": "chainlit.server",
}
Expand Down Expand Up @@ -190,6 +216,8 @@ def acall(self):
"on_mcp_disconnect",
"on_message",
"on_settings_update",
"on_socket_connect",
"on_socket_disconnect",
"on_stop",
"on_window_message",
"password_auth_callback",
Expand Down
62 changes: 61 additions & 1 deletion backend/chainlit/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,36 @@
from chainlit.user import User
from chainlit.utils import wrap_user_function

__all__ = [
"action_callback",
"author_rename",
"data_layer",
"header_auth_callback",
"oauth_callback",
"on_app_shutdown",
"on_app_startup",
"on_audio_chunk",
"on_audio_end",
"on_audio_start",
"on_chat_end",
"on_chat_resume",
"on_chat_start",
"on_feedback",
"on_logout",
"on_mcp_connect",
"on_mcp_disconnect",
"on_message",
"on_settings_update",
"on_socket_connect",
"on_socket_disconnect",
"on_stop",
"on_window_message",
"password_auth_callback",
"send_window_message",
"set_chat_profiles",
"set_starters",
]


def on_app_startup(func: Callable[[], Union[None, Awaitable[None]]]) -> Callable:
"""
Expand Down Expand Up @@ -191,6 +221,37 @@ def on_window_message(func: Callable[[str], Any]) -> Callable:
return func


def on_socket_connect(func: Callable) -> Callable:
"""
Hook to react to the socket connection event.
This is called when a new WebSocket connection is established.

Args:
func (Callable[[], Any]): The connection hook to execute.

Returns:
Callable[[], Any]: The decorated hook.
"""
config.code.on_socket_connect = wrap_user_function(func)
return func


def on_socket_disconnect(func: Callable) -> Callable:
"""
Hook to react to the socket disconnection event.
This is called when a WebSocket connection is closed.

Args:
func (Callable[[], Any]): The disconnection hook to execute.

Returns:
Callable[[], Any]: The decorated hook.
"""
config.code.on_socket_disconnect = wrap_user_function(func)

return func


def on_chat_start(func: Callable) -> Callable:
"""
Hook to react to the user websocket connection event.
Expand All @@ -201,7 +262,6 @@ def on_chat_start(func: Callable) -> Callable:
Returns:
Callable[], Any]: The decorated hook.
"""

config.code.on_chat_start = wrap_user_function(
step(func, name="on_chat_start", type="run"), with_task=True
)
Expand Down
2 changes: 2 additions & 0 deletions backend/chainlit/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,8 @@ class CodeSettings(BaseModel):
on_logout: Optional[Callable[["Request", "Response"], Any]] = None
on_stop: Optional[Callable[[], Any]] = None
on_chat_start: Optional[Callable[[], Any]] = None
on_socket_connect: Optional[Callable[[], Any]] = None
on_socket_disconnect: Optional[Callable[[], Any]] = None
on_chat_end: Optional[Callable[[], Any]] = None
on_chat_resume: Optional[Callable[["ThreadDict"], Any]] = None
on_message: Optional[Callable[["Message"], Any]] = None
Expand Down
83 changes: 58 additions & 25 deletions backend/chainlit/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,31 +129,51 @@ def emit_call_fn(event: Literal["ask", "call_fn"], data, timeout):
return sio.call(event, data, timeout=timeout, to=sid)

session_id = auth.get("sessionId")

# Try to restore existing session first
if restore_existing_session(sid, session_id, emit_fn, emit_call_fn):
return True
session = WebsocketSession.get(sid)
if not session:
logger.error("Failed to restore existing session")
return False

user_env_string = auth.get("userEnv")
user_env = load_user_env(user_env_string)

client_type = auth.get("clientType")
url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
)

WebsocketSession(
id=session_id,
socket_id=sid,
emit=emit_fn,
emit_call=emit_call_fn,
client_type=client_type,
user_env=user_env,
user=user,
token=token,
chat_profile=chat_profile,
thread_id=auth.get("threadId"),
environ=environ,
)
# Initialize WebSocket context for restored session
init_ws_context(session)
else:
# Create new session
user_env_string = auth.get("userEnv")
user_env = load_user_env(user_env_string)

client_type = auth.get("clientType")
url_encoded_chat_profile = auth.get("chatProfile")
chat_profile = (
unquote(url_encoded_chat_profile) if url_encoded_chat_profile else None
)

session = WebsocketSession(
id=session_id,
socket_id=sid,
emit=emit_fn,
emit_call=emit_call_fn,
client_type=client_type,
user_env=user_env,
user=user,
token=token,
chat_profile=chat_profile,
thread_id=auth.get("threadId"),
environ=environ,
)

# Initialize WebSocket context with the session object
init_ws_context(session)

# Call on_socket_connect if defined
if config.code.on_socket_connect:
task = asyncio.create_task(
config.code.on_socket_connect(),
name="on_socket_connect",
)
session.current_task = task

return True

Expand Down Expand Up @@ -207,10 +227,16 @@ async def disconnect(sid):
if not session:
return

# Re-initialize context after error
init_ws_context(session)

if config.code.on_chat_end:
await config.code.on_chat_end()
# Call on_socket_disconnect if defined
if config.code.on_socket_disconnect:
try:
# Call the disconnect callback
await config.code.on_socket_disconnect()
except Exception as e:
logger.error("Error in on_socket_disconnect: %s", str(e))

if session.thread_id and session.has_first_interaction:
await persist_user_session(session.thread_id, session.to_persistable())
Expand All @@ -233,6 +259,13 @@ async def clear_on_timeout(_sid):

asyncio.ensure_future(clear_on_timeout(sid))

# Call on_chat_end if defined
if config.code.on_chat_end:
try:
await config.code.on_chat_end()
except Exception as e:
logger.error("Error in on_chat_end: %s", str(e))


@sio.on("stop") # pyright: ignore [reportOptionalCall]
async def stop(sid):
Expand Down
3 changes: 3 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ exclude = [
"chainlit/copilot/**/**/"
]

[tool.hatch.build.hooks.custom]
path = "build.py"

[tool.hatch.build.targets.wheel]
packages = ["chainlit"]
artifacts = [
Expand Down
54 changes: 51 additions & 3 deletions backend/tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from unittest.mock import AsyncMock, Mock
from unittest.mock import AsyncMock, Mock, patch

from chainlit import config
from chainlit.callbacks import password_auth_callback
Expand Down Expand Up @@ -61,8 +61,6 @@ async def auth_func(headers: Headers) -> User | None:


async def test_oauth_callback(test_config: config.ChainlitConfig):
from unittest.mock import patch

from chainlit.callbacks import oauth_callback
from chainlit.user import User

Expand Down Expand Up @@ -541,6 +539,56 @@ async def handle_chat_end():
context.session.emit.assert_called()


async def test_on_socket_connect(
mock_chainlit_context, test_config: config.ChainlitConfig
):
from chainlit.callbacks import on_socket_connect

async with mock_chainlit_context:
# Setup test data
socket_connected = False

@on_socket_connect
async def handle_socket_connect():
nonlocal socket_connected
socket_connected = True

# Test that the callback is properly registered
assert test_config.code.on_socket_connect is not None

# Call the registered callback
await test_config.code.on_socket_connect()

# Check that the callback was executed
assert socket_connected
# Socket connect callbacks don't emit steps, so no emit call expected


async def test_on_socket_disconnect(
mock_chainlit_context, test_config: config.ChainlitConfig
):
from chainlit.callbacks import on_socket_disconnect

async with mock_chainlit_context:
# Setup test data
socket_disconnected = False

@on_socket_disconnect
async def handle_socket_disconnect():
nonlocal socket_disconnected
socket_disconnected = True

# Test that the callback is properly registered
assert test_config.code.on_socket_disconnect is not None

# Call the registered callback
await test_config.code.on_socket_disconnect()

# Check that the callback was executed
assert socket_disconnected
# Socket disconnect callbacks don't emit steps, so no emit call expected


async def test_data_layer_config(
mock_data_layer: AsyncMock,
test_config: config.ChainlitConfig,
Expand Down
Loading