diff --git a/api/kotlin-sdk.api b/api/kotlin-sdk.api index cf30f17d..ed82d3b6 100644 --- a/api/kotlin-sdk.api +++ b/api/kotlin-sdk.api @@ -2872,6 +2872,9 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities; public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation; + public final fun getPrompts ()Ljava/util/Map; + public final fun getResources ()Ljava/util/Map; + public final fun getTools ()Ljava/util/Map; public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public fun onClose ()V diff --git a/build.gradle.kts b/build.gradle.kts index 6a586511..1f826bd3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -14,6 +14,7 @@ import org.jreleaser.model.Active plugins { alias(libs.plugins.kotlin.multiplatform) alias(libs.plugins.kotlin.serialization) + alias(libs.plugins.kotlin.atomicfu) alias(libs.plugins.dokka) alias(libs.plugins.jreleaser) `maven-publish` @@ -246,6 +247,7 @@ kotlin { kotlin.srcDir(generateLibVersionTask.map { it.sourcesDir }) dependencies { api(libs.kotlinx.serialization.json) + api(libs.kotlinx.collections.immutable) api(libs.ktor.client.cio) api(libs.ktor.server.cio) api(libs.ktor.server.sse) diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 5261c60a..40dc628d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -2,9 +2,11 @@ # plugins version kotlin = "2.2.0" dokka = "2.0.0" +atomicfu = "0.29.0" # libraries version serialization = "1.9.0" +collections-immutable = "0.4.0" coroutines = "1.10.2" ktor = "3.2.1" mockk = "1.14.4" @@ -17,6 +19,7 @@ kotest = "5.9.1" [libraries] # Kotlinx libraries kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" } +kotlinx-collections-immutable = { group = "org.jetbrains.kotlinx", name = "kotlinx-collections-immutable", version.ref = "collections-immutable" } kotlin-logging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "logging" } # Ktor @@ -36,6 +39,7 @@ kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json", [plugins] kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" } kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } +kotlin-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" } dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" } jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"} kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 71b61ba7..880c8847 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -22,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult +import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest import io.modelcontextprotocol.kotlin.sdk.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest import io.modelcontextprotocol.kotlin.sdk.ListToolsResult @@ -41,6 +42,12 @@ import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions import io.modelcontextprotocol.kotlin.sdk.shared.Transport +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.getAndUpdate +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.minus +import kotlinx.collections.immutable.persistentMapOf +import kotlinx.collections.immutable.toPersistentSet import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull import kotlinx.serialization.json.JsonObject @@ -94,14 +101,14 @@ public open class Client( private val capabilities: ClientCapabilities = options.capabilities - private val roots = mutableMapOf() + private val roots = atomic(persistentMapOf()) init { logger.debug { "Initializing MCP client with capabilities: $capabilities" } // Internal handlers for roots if (capabilities.roots != null) { - setRequestHandler(Method.Defined.RootsList) { _, _ -> + setRequestHandler(Method.Defined.RootsList) { _, _ -> handleListRoots() } } @@ -483,7 +490,7 @@ public open class Client( throw IllegalStateException("Client does not support roots capability.") } logger.info { "Adding root: $name ($uri)" } - roots[uri] = Root(uri, name) + roots.update { current -> current.put(uri, Root(uri, name)) } } /** @@ -498,10 +505,7 @@ public open class Client( throw IllegalStateException("Client does not support roots capability.") } logger.info { "Adding ${rootsToAdd.size} roots" } - for (r in rootsToAdd) { - logger.info { "Adding root: ${r.name} (${r.uri})" } - roots[r.uri] = r - } + roots.update { current -> current.putAll(rootsToAdd.associateBy { it.uri }) } } /** @@ -517,7 +521,8 @@ public open class Client( throw IllegalStateException("Client does not support roots capability.") } logger.info { "Removing root: $uri" } - val removed = roots.remove(uri) != null + val oldMap = roots.getAndUpdate { current -> current.remove(uri) } + val removed = uri in oldMap logger.debug { if (removed) { "Root removed: $uri" @@ -541,13 +546,11 @@ public open class Client( throw IllegalStateException("Client does not support roots capability.") } logger.info { "Removing ${uris.size} roots" } - var removedCount = 0 - for (uri in uris) { - logger.debug { "Removing root: $uri" } - if (roots.remove(uri) != null) { - removedCount++ - } - } + + val oldMap = roots.getAndUpdate { current -> current - uris.toPersistentSet() } + + val removedCount = uris.count { it in oldMap } + logger.info { if (removedCount > 0) { "Removed $removedCount roots" @@ -571,7 +574,7 @@ public open class Client( // --- Internal Handlers --- private suspend fun handleListRoots(): ListRootsResult { - val rootList = roots.values.toList() + val rootList = roots.value.values.toList() return ListRootsResult(rootList) } } diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt index 67791e10..e98c4ced 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/Server.kt @@ -44,6 +44,12 @@ import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.getAndUpdate +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.minus +import kotlinx.collections.immutable.persistentMapOf +import kotlinx.collections.immutable.toPersistentSet import kotlinx.coroutines.CompletableDeferred import kotlinx.serialization.json.JsonObject @@ -91,9 +97,15 @@ public open class Server( private val capabilities: ServerCapabilities = options.capabilities - private val tools = mutableMapOf() - private val prompts = mutableMapOf() - private val resources = mutableMapOf() + private val _tools = atomic(persistentMapOf()) + private val _prompts = atomic(persistentMapOf()) + private val _resources = atomic(persistentMapOf()) + public val tools: Map + get() = _tools.value + public val prompts: Map + get() = _prompts.value + public val resources: Map + get() = _resources.value init { logger.debug { "Initializing MCP server with capabilities: $capabilities" } @@ -192,7 +204,9 @@ public open class Server( throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.") } logger.info { "Registering tool: $name" } - tools[name] = RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler) + _tools.update { current -> + current.put(name, RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler)) + } } /** @@ -207,10 +221,7 @@ public open class Server( throw IllegalStateException("Server does not support tools capability.") } logger.info { "Registering ${toolsToAdd.size} tools" } - for (rt in toolsToAdd) { - logger.debug { "Registering tool: ${rt.tool.name}" } - tools[rt.tool.name] = rt - } + _tools.update { current -> current.putAll(toolsToAdd.associateBy { it.tool.name }) } } /** @@ -226,7 +237,10 @@ public open class Server( throw IllegalStateException("Server does not support tools capability.") } logger.info { "Removing tool: $name" } - val removed = tools.remove(name) != null + + val oldMap = _tools.getAndUpdate { current -> current.remove(name) } + + val removed = name in oldMap logger.debug { if (removed) { "Tool removed: $name" @@ -250,18 +264,15 @@ public open class Server( throw IllegalStateException("Server does not support tools capability.") } logger.info { "Removing ${toolNames.size} tools" } - var removedCount = 0 - for (name in toolNames) { - logger.debug { "Removing tool: $name" } - if (tools.remove(name) != null) { - removedCount++ - } - } + + val oldMap = _tools.getAndUpdate { current -> current - toolNames.toPersistentSet() } + + val removedCount = toolNames.count { it in oldMap } logger.info { if (removedCount > 0) { - "Removed $removedCount tools" + "Removed $removedCount tools" } else { - "No tools were removed" + "No tools were removed" } } return removedCount @@ -280,7 +291,7 @@ public open class Server( throw IllegalStateException("Server does not support prompts capability.") } logger.info { "Registering prompt: ${prompt.name}" } - prompts[prompt.name] = RegisteredPrompt(prompt, promptProvider) + _prompts.update { current -> current.put(prompt.name, RegisteredPrompt(prompt, promptProvider)) } } /** @@ -314,10 +325,7 @@ public open class Server( throw IllegalStateException("Server does not support prompts capability.") } logger.info { "Registering ${promptsToAdd.size} prompts" } - for (rp in promptsToAdd) { - logger.debug { "Registering prompt: ${rp.prompt.name}" } - prompts[rp.prompt.name] = rp - } + _prompts.update { current -> current.putAll(promptsToAdd.associateBy { it.prompt.name }) } } /** @@ -333,7 +341,10 @@ public open class Server( throw IllegalStateException("Server does not support prompts capability.") } logger.info { "Removing prompt: $name" } - val removed = prompts.remove(name) != null + + val oldMap = _prompts.getAndUpdate { current -> current.remove(name) } + + val removed = name in oldMap logger.debug { if (removed) { "Prompt removed: $name" @@ -357,13 +368,11 @@ public open class Server( throw IllegalStateException("Server does not support prompts capability.") } logger.info { "Removing ${promptNames.size} prompts" } - var removedCount = 0 - for (name in promptNames) { - logger.debug { "Removing prompt: $name" } - if (prompts.remove(name) != null) { - removedCount++ - } - } + + val oldMap = _prompts.getAndUpdate { current -> current - promptNames.toPersistentSet() } + + val removedCount = promptNames.count { it in oldMap } + logger.info { if (removedCount > 0) { "Removed $removedCount prompts" @@ -396,7 +405,12 @@ public open class Server( throw IllegalStateException("Server does not support resources capability.") } logger.info { "Registering resource: $name ($uri)" } - resources[uri] = RegisteredResource(Resource(uri, name, description, mimeType), readHandler) + _resources.update { current -> + current.put( + uri, + RegisteredResource(Resource(uri, name, description, mimeType), readHandler) + ) + } } /** @@ -411,10 +425,7 @@ public open class Server( throw IllegalStateException("Server does not support resources capability.") } logger.info { "Registering ${resourcesToAdd.size} resources" } - for (r in resourcesToAdd) { - logger.debug { "Registering resource: ${r.resource.name} (${r.resource.uri})" } - resources[r.resource.uri] = r - } + _resources.update { current -> current.putAll(resourcesToAdd.associateBy { it.resource.uri }) } } /** @@ -430,7 +441,10 @@ public open class Server( throw IllegalStateException("Server does not support resources capability.") } logger.info { "Removing resource: $uri" } - val removed = resources.remove(uri) != null + + val oldMap = _resources.getAndUpdate { current -> current.remove(uri) } + + val removed = uri in oldMap logger.debug { if (removed) { "Resource removed: $uri" @@ -454,13 +468,11 @@ public open class Server( throw IllegalStateException("Server does not support resources capability.") } logger.info { "Removing ${uris.size} resources" } - var removedCount = 0 - for (uri in uris) { - logger.debug { "Removing resource: $uri" } - if (resources.remove(uri) != null) { - removedCount++ - } - } + + val oldMap = _resources.getAndUpdate { current -> current - uris.toPersistentSet() } + + val removedCount = uris.count { it in oldMap } + logger.info { if (removedCount > 0) { "Removed $removedCount resources" @@ -586,7 +598,7 @@ public open class Server( private suspend fun handleCallTool(request: CallToolRequest): CallToolResult { logger.debug { "Handling tool call request for tool: ${request.name}" } - val tool = tools[request.name] + val tool = _tools.value[request.name] ?: run { logger.error { "Tool not found: ${request.name}" } throw IllegalArgumentException("Tool not found: ${request.name}") diff --git a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 4cad7561..8ad733b1 100644 --- a/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -19,12 +19,17 @@ import io.modelcontextprotocol.kotlin.sdk.RequestId import io.modelcontextprotocol.kotlin.sdk.RequestResult import io.modelcontextprotocol.kotlin.sdk.fromJSON import io.modelcontextprotocol.kotlin.sdk.toJSON +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.getAndUpdate +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.PersistentMap +import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Deferred import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.withTimeout import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.encodeToString import kotlinx.serialization.json.ClassDiscriminatorMode import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject @@ -117,18 +122,25 @@ public abstract class Protocol( public var transport: Transport? = null private set - @PublishedApi - internal val requestHandlers: MutableMap RequestResult?> = - mutableMapOf() - public val notificationHandlers: MutableMap Unit> = - mutableMapOf() + private val _requestHandlers: AtomicRef RequestResult?>> = + atomic(persistentMapOf()) + public val requestHandlers: Map RequestResult?> + get() = _requestHandlers.value - @PublishedApi - internal val responseHandlers: MutableMap Unit> = - mutableMapOf() + private val _notificationHandlers = + atomic(persistentMapOf Unit>()) + public val notificationHandlers: Map Unit> + get() = _notificationHandlers.value - @PublishedApi - internal val progressHandlers: MutableMap = mutableMapOf() + private val _responseHandlers: AtomicRef Unit>> = + atomic(persistentMapOf()) + public val responseHandlers: Map Unit> + get() = _responseHandlers.value + + private val _progressHandlers: AtomicRef> = + atomic(persistentMapOf()) + public val progressHandlers: Map + get() = _progressHandlers.value /** * Callback for when the connection is closed for any reason. @@ -162,7 +174,7 @@ public abstract class Protocol( COMPLETED } - setRequestHandler(Method.Defined.Ping) { request, _ -> + setRequestHandler(Method.Defined.Ping) { _, _ -> EmptyRequestResult() } } @@ -195,22 +207,22 @@ public abstract class Protocol( } private fun doClose() { - responseHandlers.clear() - progressHandlers.clear() + val handlersToNotify = _responseHandlers.value.values.toList() + _responseHandlers.getAndSet(persistentMapOf()) + _progressHandlers.getAndSet(persistentMapOf()) transport = null onClose() val error = McpError(ErrorCode.Defined.ConnectionClosed.code, "Connection closed") - for (handler in responseHandlers.values) { + for (handler in handlersToNotify) { handler(null, error) } } private suspend fun onNotification(notification: JSONRPCNotification) { LOGGER.trace { "Received notification: ${notification.method}" } - val function = notificationHandlers[notification.method] - val property = fallbackNotificationHandler - val handler = function ?: property + + val handler = notificationHandlers[notification.method] ?: fallbackNotificationHandler if (handler == null) { LOGGER.trace { "No handler found for notification: ${notification.method}" } @@ -226,6 +238,7 @@ public abstract class Protocol( private suspend fun onRequest(request: JSONRPCRequest) { LOGGER.trace { "Received request: ${request.method} (id: ${request.id})" } + val handler = requestHandlers[request.method] ?: fallbackRequestHandler if (handler === null) { @@ -285,7 +298,7 @@ public abstract class Protocol( val message = notification.message val progressToken = notification.progressToken - val handler = progressHandlers[progressToken] + val handler = _progressHandlers.value[progressToken] if (handler == null) { val error = Error( "Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}", @@ -300,14 +313,24 @@ public abstract class Protocol( private fun onResponse(response: JSONRPCResponse?, error: JSONRPCError?) { val messageId = response?.id - val handler = responseHandlers[messageId] - if (handler == null) { + + val oldResponseHandlers = _responseHandlers.getAndUpdate { current -> + if (messageId != null && messageId in current) { + current.remove(messageId) + } else { + current + } + } + + val handler = oldResponseHandlers[messageId] + + if (handler != null) { + messageId?.let { msg -> _progressHandlers.update { it.remove(msg) } } + } else { onError(Error("Received a response for an unknown message ID: ${McpJson.encodeToString(response)}")) return } - responseHandlers.remove(messageId) - progressHandlers.remove(messageId) if (response != null) { handler(response, null) } else { @@ -317,7 +340,6 @@ public abstract class Protocol( error.message, error.data, ) - handler(null, error) } } @@ -372,31 +394,35 @@ public abstract class Protocol( if (options?.onProgress != null) { LOGGER.trace { "Registering progress handler for request id: $messageId" } - progressHandlers[messageId] = options.onProgress - } - - responseHandlers[messageId] = set@{ response, error -> - if (error != null) { - result.completeExceptionally(error) - return@set + _progressHandlers.update { current -> + current.put(messageId, options.onProgress) } + } - if (response?.error != null) { - result.completeExceptionally(IllegalStateException(response.error.toString())) - return@set - } - - try { - @Suppress("UNCHECKED_CAST") - result.complete(response!!.result as T) - } catch (error: Throwable) { - result.completeExceptionally(error) + _responseHandlers.update { current -> + current.put(messageId) { response, error -> + if (error != null) { + result.completeExceptionally(error) + return@put + } + + if (response?.error != null) { + result.completeExceptionally(IllegalStateException(response.error.toString())) + return@put + } + + try { + @Suppress("UNCHECKED_CAST") + result.complete(response!!.result as T) + } catch (error: Throwable) { + result.completeExceptionally(error) + } } } val cancel: suspend (Throwable) -> Unit = { reason: Throwable -> - responseHandlers.remove(messageId) - progressHandlers.remove(messageId) + _responseHandlers.update { current -> current.remove(messageId) } + _progressHandlers.update { current -> current.remove(messageId) } val notification = CancelledNotification(requestId = messageId, reason = reason.message ?: "Unknown") @@ -468,16 +494,17 @@ public abstract class Protocol( val serializer = McpJson.serializersModule.serializer(requestType) - requestHandlers[method.value] = { request, extraHandler -> - val result = McpJson.decodeFromJsonElement(serializer, request.params) - val response = if (result != null) { - @Suppress("UNCHECKED_CAST") - block(result as T, extraHandler) - } else { - EmptyRequestResult() + _requestHandlers.update { current -> + current.put(method.value) { request, extraHandler -> + val result = McpJson.decodeFromJsonElement(serializer, request.params) + val response = if (result != null) { + @Suppress("UNCHECKED_CAST") + block(result as T, extraHandler) + } else { + EmptyRequestResult() + } + response } - - response } } @@ -485,7 +512,7 @@ public abstract class Protocol( * Removes the request handler for the given method. */ public fun removeRequestHandler(method: Method) { - requestHandlers.remove(method.value) + _requestHandlers.update { current -> current.remove(method.value) } } /** @@ -494,9 +521,11 @@ public abstract class Protocol( * Note that this will replace any previous notification handler for the same method. */ public fun setNotificationHandler(method: Method, handler: (notification: T) -> Deferred) { - notificationHandlers[method.value] = { - @Suppress("UNCHECKED_CAST") - handler(it.fromJSON() as T) + _notificationHandlers.update { current -> + current.put(method.value) { + @Suppress("UNCHECKED_CAST") + handler(it.fromJSON() as T) + } } } @@ -504,6 +533,6 @@ public abstract class Protocol( * Removes the notification handler for the given method. */ public fun removeNotificationHandler(method: Method) { - notificationHandlers.remove(method.value) + _notificationHandlers.update { current -> current.remove(method.value) } } }