Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import io.ktor.http.protocolWithAuthority
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
Expand All @@ -24,10 +25,11 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.launch
import kotlinx.serialization.SerializationException
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.properties.Delegates
import kotlin.time.Duration

@Deprecated("Use SseClientTransport instead", ReplaceWith("SseClientTransport"), DeprecationLevel.WARNING)
Expand All @@ -44,97 +46,59 @@ public class SseClientTransport(
private val reconnectionTime: Duration? = null,
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : AbstractTransport() {
private val scope by lazy {
CoroutineScope(session.coroutineContext + SupervisorJob())
}

private val initialized: AtomicBoolean = AtomicBoolean(false)
private var session: ClientSSESession by Delegates.notNull()
private val endpoint = CompletableDeferred<String>()

private lateinit var session: ClientSSESession
private lateinit var scope: CoroutineScope
private var job: Job? = null

private val baseUrl by lazy {
val requestUrl = session.call.request.url.toString()
val url = Url(requestUrl)
var path = url.encodedPath
if (path.isEmpty()) {
url.protocolWithAuthority
} else if (path.endsWith("/")) {
url.protocolWithAuthority + path.removeSuffix("/")
} else {
// the last item is not a directory, so will not be taken into account
path = path.substring(0, path.lastIndexOf("/"))
url.protocolWithAuthority + path
private val baseUrl: String by lazy {
session.call.request.url.let { url ->
val path = url.encodedPath
when {
path.isEmpty() -> url.protocolWithAuthority
path.endsWith("/") -> url.protocolWithAuthority + path.removeSuffix("/")
else -> url.protocolWithAuthority + path.take(path.lastIndexOf("/"))
}
}
}

override suspend fun start() {
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
error(
"SSEClientTransport already started! " +
"If using Client class, note that connect() calls start() automatically.",
)
check(initialized.compareAndSet(expectedValue = false, newValue = true)) {
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically."
}

session = urlString?.let {
client.sseSession(
urlString = it,
try {
session = urlString?.let {
client.sseSession(
urlString = it,
reconnectionTime = reconnectionTime,
block = requestBuilder,
)
} ?: client.sseSession(
reconnectionTime = reconnectionTime,
block = requestBuilder,
)
} ?: client.sseSession(
reconnectionTime = reconnectionTime,
block = requestBuilder,
)

job = scope.launch(CoroutineName("SseMcpClientTransport.collect#${hashCode()}")) {
session.incoming.collect { event ->
when (event.event) {
"error" -> {
val e = IllegalStateException("SSE error: ${event.data}")
_onError(e)
throw e
}

"open" -> {
// The connection is open, but we need to wait for the endpoint to be received.
}

"endpoint" -> {
try {
val eventData = event.data ?: ""

// check url correctness
val maybeEndpoint = Url("$baseUrl/${if (eventData.startsWith("/")) eventData.substring(1) else eventData}")
endpoint.complete(maybeEndpoint.toString())
} catch (e: Exception) {
_onError(e)
close()
error(e)
}
}
scope = CoroutineScope(session.coroutineContext + SupervisorJob())

else -> {
try {
val message = McpJson.decodeFromString<JSONRPCMessage>(event.data ?: "")
_onMessage(message)
} catch (e: Exception) {
_onError(e)
}
}
}
job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) {
collectMessages()
}
}

endpoint.await()
endpoint.await()
} catch (e: Exception) {
closeResources()
initialized.store(false)
throw e
}
}

@OptIn(ExperimentalCoroutinesApi::class)
override suspend fun send(message: JSONRPCMessage) {
if (!endpoint.isCompleted) {
error("Not connected")
}
check(initialized.load()) { "SseClientTransport is not initialized!" }
check(job?.isActive == true) { "SseClientTransport is closed!" }
check(endpoint.isCompleted) { "Not connected!" }

try {
val response = client.post(endpoint.getCompleted()) {
Expand All @@ -147,19 +111,80 @@ public class SseClientTransport(
val text = response.bodyAsText()
error("Error POSTing to endpoint (HTTP ${response.status}): $text")
}
} catch (e: Exception) {
} catch (e: Throwable) {
_onError(e)
throw e
}
}

override suspend fun close() {
if (!initialized.load()) {
error("SSEClientTransport is not initialized!")
check(initialized.load()) { "SseClientTransport is not initialized!" }
closeResources()
}

private suspend fun CoroutineScope.collectMessages() {
try {
session.incoming.collect { event ->
ensureActive()

when (event.event) {
"error" -> {
val error = IllegalStateException("SSE error: ${event.data}")
_onError(error)
throw error
}

"open" -> {
// The connection is open, but we need to wait for the endpoint to be received.
}

"endpoint" -> handleEndpoint(event.data.orEmpty())
else -> handleMessage(event.data.orEmpty())
}
}
} catch (e: CancellationException) {
throw e
} catch (e: Throwable) {
_onError(e)
throw e
} finally {
closeResources()
}
}

private fun handleEndpoint(eventData: String) {
try {
val path = if (eventData.startsWith("/")) eventData.substring(1) else eventData
val endpointUrl = Url("$baseUrl/$path")
endpoint.complete(endpointUrl.toString())
} catch (e: Throwable) {
_onError(e)
endpoint.completeExceptionally(e)
throw e
}
}

private suspend fun handleMessage(data: String) {
try {
val message = McpJson.decodeFromString<JSONRPCMessage>(data)
_onMessage(message)
} catch (e: SerializationException) {
_onError(e)
}
}

private suspend fun closeResources() {
if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return

session.cancel()
_onClose()
job?.cancelAndJoin()
try {
if (::session.isInitialized) session.cancel()
if (::scope.isInitialized) scope.cancel()
endpoint.cancel()
} catch (e: Throwable) {
_onError(e)
}

_onClose()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.ktor.client.HttpClient
import io.ktor.client.plugins.sse.SSE
import io.ktor.server.application.install
import io.ktor.server.cio.CIO
import io.ktor.server.engine.EmbeddedServer
import io.ktor.server.engine.embeddedServer
import io.ktor.server.routing.post
import io.ktor.server.routing.route
Expand All @@ -17,10 +18,12 @@ import kotlinx.coroutines.test.runTest
import kotlin.test.Test

class SseTransportTest : BaseTransportTest() {

private suspend fun EmbeddedServer<*, *>.actualPort() = engine.resolvedConnectors().first().port

@Test
fun `should start then close cleanly`() = runTest {
val port = 8080
val server = embeddedServer(CIO, port = port) {
val server = embeddedServer(CIO, port = 0) {
install(io.ktor.server.sse.SSE)
val transports = ConcurrentMap<String, SseServerTransport>()
routing {
Expand All @@ -34,24 +37,27 @@ class SseTransportTest : BaseTransportTest() {
}
}.startSuspend(wait = false)

val actualPort = server.actualPort()

val client = HttpClient {
install(SSE)
}.mcpSseTransport {
url {
host = "localhost"
this.port = port
this.port = actualPort
}
}

testClientOpenClose(client)

server.stopSuspend()
try {
testClientOpenClose(client)
} finally {
server.stopSuspend()
}
}

@Test
fun `should read messages`() = runTest {
val port = 3003
val server = embeddedServer(CIO, port = port) {
val server = embeddedServer(CIO, port = 0) {
install(io.ktor.server.sse.SSE)
val transports = ConcurrentMap<String, SseServerTransport>()
routing {
Expand All @@ -71,23 +77,27 @@ class SseTransportTest : BaseTransportTest() {
}
}.startSuspend(wait = false)

val actualPort = server.actualPort()

val client = HttpClient {
install(SSE)
}.mcpSseTransport {
url {
host = "localhost"
this.port = port
this.port = actualPort
}
}

testClientRead(client)
server.stopSuspend()
try {
testClientRead(client)
} finally {
server.stopSuspend()
}
}

@Test
fun `test sse path not root path`() = runTest {
val port = 3007
val server = embeddedServer(CIO, port = port) {
val server = embeddedServer(CIO, port = 0) {
install(io.ktor.server.sse.SSE)
val transports = ConcurrentMap<String, SseServerTransport>()
routing {
Expand All @@ -109,17 +119,22 @@ class SseTransportTest : BaseTransportTest() {
}
}.startSuspend(wait = false)

val actualPort = server.actualPort()

val client = HttpClient {
install(SSE)
}.mcpSseTransport {
url {
host = "localhost"
this.port = port
this.port = actualPort
pathSegments = listOf("sse")
}
}

testClientRead(client)
server.stopSuspend()
try {
testClientRead(client)
} finally {
server.stopSuspend()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ class SseIntegrationTest {
@Test
fun `client should be able to connect to sse server`() = runTest {
val serverEngine = initServer()
var client: Client? = null
try {
withContext(Dispatchers.Default) {
assertDoesNotThrow { initClient() }
assertDoesNotThrow { client = initClient() }
}
} catch (e: Exception) {
fail("Failed to connect client: $e")
} finally {
client?.close()
// Make sure to stop the server
serverEngine.stopSuspend(1000, 2000)
}
Expand All @@ -54,11 +58,11 @@ class SseIntegrationTest {
ServerOptions(capabilities = ServerCapabilities()),
)

return embeddedServer(ServerCIO, host = URL, port = PORT) {
return embeddedServer(ServerCIO, host = URL, port = PORT) {
install(io.ktor.server.sse.SSE)
routing {
mcp { server }
}
routing {
mcp { server }
}
}.startSuspend(wait = false)
}

Expand Down
Loading