diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt index 9b178466..d30f5288 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SSEClientTransport.kt @@ -16,6 +16,7 @@ import io.ktor.http.protocolWithAuthority import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope @@ -24,10 +25,11 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.ensureActive import kotlinx.coroutines.launch +import kotlinx.serialization.SerializationException import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi -import kotlin.properties.Delegates import kotlin.time.Duration @Deprecated("Use SseClientTransport instead", ReplaceWith("SseClientTransport"), DeprecationLevel.WARNING) @@ -44,97 +46,59 @@ public class SseClientTransport( private val reconnectionTime: Duration? = null, private val requestBuilder: HttpRequestBuilder.() -> Unit = {}, ) : AbstractTransport() { - private val scope by lazy { - CoroutineScope(session.coroutineContext + SupervisorJob()) - } - private val initialized: AtomicBoolean = AtomicBoolean(false) - private var session: ClientSSESession by Delegates.notNull() private val endpoint = CompletableDeferred() + private lateinit var session: ClientSSESession + private lateinit var scope: CoroutineScope private var job: Job? = null - private val baseUrl by lazy { - val requestUrl = session.call.request.url.toString() - val url = Url(requestUrl) - var path = url.encodedPath - if (path.isEmpty()) { - url.protocolWithAuthority - } else if (path.endsWith("/")) { - url.protocolWithAuthority + path.removeSuffix("/") - } else { - // the last item is not a directory, so will not be taken into account - path = path.substring(0, path.lastIndexOf("/")) - url.protocolWithAuthority + path + private val baseUrl: String by lazy { + session.call.request.url.let { url -> + val path = url.encodedPath + when { + path.isEmpty() -> url.protocolWithAuthority + path.endsWith("/") -> url.protocolWithAuthority + path.removeSuffix("/") + else -> url.protocolWithAuthority + path.take(path.lastIndexOf("/")) + } } } override suspend fun start() { - if (!initialized.compareAndSet(expectedValue = false, newValue = true)) { - error( - "SSEClientTransport already started! " + - "If using Client class, note that connect() calls start() automatically.", - ) + check(initialized.compareAndSet(expectedValue = false, newValue = true)) { + "SSEClientTransport already started! If using Client class, note that connect() calls start() automatically." } - session = urlString?.let { - client.sseSession( - urlString = it, + try { + session = urlString?.let { + client.sseSession( + urlString = it, + reconnectionTime = reconnectionTime, + block = requestBuilder, + ) + } ?: client.sseSession( reconnectionTime = reconnectionTime, block = requestBuilder, ) - } ?: client.sseSession( - reconnectionTime = reconnectionTime, - block = requestBuilder, - ) - - job = scope.launch(CoroutineName("SseMcpClientTransport.collect#${hashCode()}")) { - session.incoming.collect { event -> - when (event.event) { - "error" -> { - val e = IllegalStateException("SSE error: ${event.data}") - _onError(e) - throw e - } - - "open" -> { - // The connection is open, but we need to wait for the endpoint to be received. - } - - "endpoint" -> { - try { - val eventData = event.data ?: "" - - // check url correctness - val maybeEndpoint = Url("$baseUrl/${if (eventData.startsWith("/")) eventData.substring(1) else eventData}") - endpoint.complete(maybeEndpoint.toString()) - } catch (e: Exception) { - _onError(e) - close() - error(e) - } - } + scope = CoroutineScope(session.coroutineContext + SupervisorJob()) - else -> { - try { - val message = McpJson.decodeFromString(event.data ?: "") - _onMessage(message) - } catch (e: Exception) { - _onError(e) - } - } - } + job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) { + collectMessages() } - } - endpoint.await() + endpoint.await() + } catch (e: Exception) { + closeResources() + initialized.store(false) + throw e + } } @OptIn(ExperimentalCoroutinesApi::class) override suspend fun send(message: JSONRPCMessage) { - if (!endpoint.isCompleted) { - error("Not connected") - } + check(initialized.load()) { "SseClientTransport is not initialized!" } + check(job?.isActive == true) { "SseClientTransport is closed!" } + check(endpoint.isCompleted) { "Not connected!" } try { val response = client.post(endpoint.getCompleted()) { @@ -147,19 +111,80 @@ public class SseClientTransport( val text = response.bodyAsText() error("Error POSTing to endpoint (HTTP ${response.status}): $text") } - } catch (e: Exception) { + } catch (e: Throwable) { _onError(e) throw e } } override suspend fun close() { - if (!initialized.load()) { - error("SSEClientTransport is not initialized!") + check(initialized.load()) { "SseClientTransport is not initialized!" } + closeResources() + } + + private suspend fun CoroutineScope.collectMessages() { + try { + session.incoming.collect { event -> + ensureActive() + + when (event.event) { + "error" -> { + val error = IllegalStateException("SSE error: ${event.data}") + _onError(error) + throw error + } + + "open" -> { + // The connection is open, but we need to wait for the endpoint to be received. + } + + "endpoint" -> handleEndpoint(event.data.orEmpty()) + else -> handleMessage(event.data.orEmpty()) + } + } + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + _onError(e) + throw e + } finally { + closeResources() } + } + + private fun handleEndpoint(eventData: String) { + try { + val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData + val endpointUrl = Url("$baseUrl/$path") + endpoint.complete(endpointUrl.toString()) + } catch (e: Throwable) { + _onError(e) + endpoint.completeExceptionally(e) + throw e + } + } + + private suspend fun handleMessage(data: String) { + try { + val message = McpJson.decodeFromString(data) + _onMessage(message) + } catch (e: SerializationException) { + _onError(e) + } + } + + private suspend fun closeResources() { + if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return - session.cancel() - _onClose() job?.cancelAndJoin() + try { + if (::session.isInitialized) session.cancel() + if (::scope.isInitialized) scope.cancel() + endpoint.cancel() + } catch (e: Throwable) { + _onError(e) + } + + _onClose() } } diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt index de61a079..1c63ff65 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseTransportTest.kt @@ -4,6 +4,7 @@ import io.ktor.client.HttpClient import io.ktor.client.plugins.sse.SSE import io.ktor.server.application.install import io.ktor.server.cio.CIO +import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.routing.post import io.ktor.server.routing.route @@ -17,10 +18,12 @@ import kotlinx.coroutines.test.runTest import kotlin.test.Test class SseTransportTest : BaseTransportTest() { + + private suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().first().port + @Test fun `should start then close cleanly`() = runTest { - val port = 8080 - val server = embeddedServer(CIO, port = port) { + val server = embeddedServer(CIO, port = 0) { install(io.ktor.server.sse.SSE) val transports = ConcurrentMap() routing { @@ -34,24 +37,27 @@ class SseTransportTest : BaseTransportTest() { } }.startSuspend(wait = false) + val actualPort = server.actualPort() + val client = HttpClient { install(SSE) }.mcpSseTransport { url { host = "localhost" - this.port = port + this.port = actualPort } } - testClientOpenClose(client) - - server.stopSuspend() + try { + testClientOpenClose(client) + } finally { + server.stopSuspend() + } } @Test fun `should read messages`() = runTest { - val port = 3003 - val server = embeddedServer(CIO, port = port) { + val server = embeddedServer(CIO, port = 0) { install(io.ktor.server.sse.SSE) val transports = ConcurrentMap() routing { @@ -71,23 +77,27 @@ class SseTransportTest : BaseTransportTest() { } }.startSuspend(wait = false) + val actualPort = server.actualPort() + val client = HttpClient { install(SSE) }.mcpSseTransport { url { host = "localhost" - this.port = port + this.port = actualPort } } - testClientRead(client) - server.stopSuspend() + try { + testClientRead(client) + } finally { + server.stopSuspend() + } } @Test fun `test sse path not root path`() = runTest { - val port = 3007 - val server = embeddedServer(CIO, port = port) { + val server = embeddedServer(CIO, port = 0) { install(io.ktor.server.sse.SSE) val transports = ConcurrentMap() routing { @@ -109,17 +119,22 @@ class SseTransportTest : BaseTransportTest() { } }.startSuspend(wait = false) + val actualPort = server.actualPort() + val client = HttpClient { install(SSE) }.mcpSseTransport { url { host = "localhost" - this.port = port + this.port = actualPort pathSegments = listOf("sse") } } - testClientRead(client) - server.stopSuspend() + try { + testClientRead(client) + } finally { + server.stopSuspend() + } } } diff --git a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt index 7b97b156..19d84589 100644 --- a/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt +++ b/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/SseIntegrationTest.kt @@ -26,11 +26,15 @@ class SseIntegrationTest { @Test fun `client should be able to connect to sse server`() = runTest { val serverEngine = initServer() + var client: Client? = null try { withContext(Dispatchers.Default) { - assertDoesNotThrow { initClient() } + assertDoesNotThrow { client = initClient() } } + } catch (e: Exception) { + fail("Failed to connect client: $e") } finally { + client?.close() // Make sure to stop the server serverEngine.stopSuspend(1000, 2000) } @@ -54,11 +58,11 @@ class SseIntegrationTest { ServerOptions(capabilities = ServerCapabilities()), ) - return embeddedServer(ServerCIO, host = URL, port = PORT) { + return embeddedServer(ServerCIO, host = URL, port = PORT) { install(io.ktor.server.sse.SSE) - routing { - mcp { server } - } + routing { + mcp { server } + } }.startSuspend(wait = false) }