Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/kotlin-sdk.api
Original file line number Diff line number Diff line change
Expand Up @@ -2872,6 +2872,9 @@ public class io/modelcontextprotocol/kotlin/sdk/server/Server : io/modelcontextp
public static synthetic fun createMessage$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lio/modelcontextprotocol/kotlin/sdk/CreateMessageRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public final fun getClientCapabilities ()Lio/modelcontextprotocol/kotlin/sdk/ClientCapabilities;
public final fun getClientVersion ()Lio/modelcontextprotocol/kotlin/sdk/Implementation;
public final fun getPrompts ()Ljava/util/Map;
public final fun getResources ()Ljava/util/Map;
public final fun getTools ()Ljava/util/Map;
public final fun listRoots (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun listRoots$default (Lio/modelcontextprotocol/kotlin/sdk/server/Server;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public fun onClose ()V
Expand Down
2 changes: 2 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.jreleaser.model.Active
plugins {
alias(libs.plugins.kotlin.multiplatform)
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.kotlin.atomicfu)
alias(libs.plugins.dokka)
alias(libs.plugins.jreleaser)
`maven-publish`
Expand Down Expand Up @@ -246,6 +247,7 @@ kotlin {
kotlin.srcDir(generateLibVersionTask.map { it.sourcesDir })
dependencies {
api(libs.kotlinx.serialization.json)
api(libs.kotlinx.collections.immutable)
api(libs.ktor.client.cio)
api(libs.ktor.server.cio)
api(libs.ktor.server.sse)
Expand Down
4 changes: 4 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
# plugins version
kotlin = "2.2.0"
dokka = "2.0.0"
atomicfu = "0.29.0"

# libraries version
serialization = "1.9.0"
collections-immutable = "0.4.0"
coroutines = "1.10.2"
ktor = "3.2.1"
mockk = "1.14.4"
Expand All @@ -17,6 +19,7 @@ kotest = "5.9.1"
[libraries]
# Kotlinx libraries
kotlinx-serialization-json = { group = "org.jetbrains.kotlinx", name = "kotlinx-serialization-json", version.ref = "serialization" }
kotlinx-collections-immutable = { group = "org.jetbrains.kotlinx", name = "kotlinx-collections-immutable", version.ref = "collections-immutable" }
kotlin-logging = { group = "io.github.oshai", name = "kotlin-logging", version.ref = "logging" }

# Ktor
Expand All @@ -36,6 +39,7 @@ kotest-assertions-json = { group = "io.kotest", name = "kotest-assertions-json",
[plugins]
kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref = "kotlin" }
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
kotlin-atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" }
dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" }
jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"}
kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" }
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesRequest
import io.modelcontextprotocol.kotlin.sdk.ListResourceTemplatesResult
import io.modelcontextprotocol.kotlin.sdk.ListResourcesRequest
import io.modelcontextprotocol.kotlin.sdk.ListResourcesResult
import io.modelcontextprotocol.kotlin.sdk.ListRootsRequest
import io.modelcontextprotocol.kotlin.sdk.ListRootsResult
import io.modelcontextprotocol.kotlin.sdk.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.ListToolsResult
Expand All @@ -41,6 +42,12 @@ import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.update
import kotlinx.collections.immutable.minus
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.toPersistentSet
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonObject
Expand Down Expand Up @@ -94,14 +101,14 @@ public open class Client(

private val capabilities: ClientCapabilities = options.capabilities

private val roots = mutableMapOf<String, Root>()
private val roots = atomic(persistentMapOf<String, Root>())

init {
logger.debug { "Initializing MCP client with capabilities: $capabilities" }

// Internal handlers for roots
if (capabilities.roots != null) {
setRequestHandler<ListToolsRequest>(Method.Defined.RootsList) { _, _ ->
setRequestHandler<ListRootsRequest>(Method.Defined.RootsList) { _, _ ->
handleListRoots()
}
}
Expand Down Expand Up @@ -483,7 +490,7 @@ public open class Client(
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Adding root: $name ($uri)" }
roots[uri] = Root(uri, name)
roots.update { current -> current.put(uri, Root(uri, name)) }
}

/**
Expand All @@ -498,10 +505,7 @@ public open class Client(
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Adding ${rootsToAdd.size} roots" }
for (r in rootsToAdd) {
logger.info { "Adding root: ${r.name} (${r.uri})" }
roots[r.uri] = r
}
roots.update { current -> current.putAll(rootsToAdd.associateBy { it.uri }) }
}

/**
Expand All @@ -517,7 +521,8 @@ public open class Client(
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Removing root: $uri" }
val removed = roots.remove(uri) != null
val oldMap = roots.getAndUpdate { current -> current.remove(uri) }
val removed = uri in oldMap
logger.debug {
if (removed) {
"Root removed: $uri"
Expand All @@ -541,13 +546,11 @@ public open class Client(
throw IllegalStateException("Client does not support roots capability.")
}
logger.info { "Removing ${uris.size} roots" }
var removedCount = 0
for (uri in uris) {
logger.debug { "Removing root: $uri" }
if (roots.remove(uri) != null) {
removedCount++
}
}

val oldMap = roots.getAndUpdate { current -> current - uris.toPersistentSet() }

val removedCount = uris.count { it in oldMap }

logger.info {
if (removedCount > 0) {
"Removed $removedCount roots"
Expand All @@ -571,7 +574,7 @@ public open class Client(
// --- Internal Handlers ---

private suspend fun handleListRoots(): ListRootsResult {
val rootList = roots.values.toList()
val rootList = roots.value.values.toList()
return ListRootsResult(rootList)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ import io.modelcontextprotocol.kotlin.sdk.ToolListChangedNotification
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.update
import kotlinx.collections.immutable.minus
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.toPersistentSet
import kotlinx.coroutines.CompletableDeferred
import kotlinx.serialization.json.JsonObject

Expand Down Expand Up @@ -91,9 +97,15 @@ public open class Server(

private val capabilities: ServerCapabilities = options.capabilities

private val tools = mutableMapOf<String, RegisteredTool>()
private val prompts = mutableMapOf<String, RegisteredPrompt>()
private val resources = mutableMapOf<String, RegisteredResource>()
private val _tools = atomic(persistentMapOf<String, RegisteredTool>())
private val _prompts = atomic(persistentMapOf<String, RegisteredPrompt>())
private val _resources = atomic(persistentMapOf<String, RegisteredResource>())
public val tools: Map<String, RegisteredTool>
get() = _tools.value
public val prompts: Map<String, RegisteredPrompt>
get() = _prompts.value
public val resources: Map<String, RegisteredResource>
get() = _resources.value

init {
logger.debug { "Initializing MCP server with capabilities: $capabilities" }
Expand Down Expand Up @@ -192,7 +204,9 @@ public open class Server(
throw IllegalStateException("Server does not support tools capability. Enable it in ServerOptions.")
}
logger.info { "Registering tool: $name" }
tools[name] = RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler)
_tools.update { current ->
current.put(name, RegisteredTool(Tool(name, description, inputSchema, toolAnnotations), handler))
}
}

/**
Expand All @@ -207,10 +221,7 @@ public open class Server(
throw IllegalStateException("Server does not support tools capability.")
}
logger.info { "Registering ${toolsToAdd.size} tools" }
for (rt in toolsToAdd) {
logger.debug { "Registering tool: ${rt.tool.name}" }
tools[rt.tool.name] = rt
}
_tools.update { current -> current.putAll(toolsToAdd.associateBy { it.tool.name }) }
}

/**
Expand All @@ -226,7 +237,10 @@ public open class Server(
throw IllegalStateException("Server does not support tools capability.")
}
logger.info { "Removing tool: $name" }
val removed = tools.remove(name) != null

val oldMap = _tools.getAndUpdate { current -> current.remove(name) }

val removed = name in oldMap
logger.debug {
if (removed) {
"Tool removed: $name"
Expand All @@ -250,18 +264,15 @@ public open class Server(
throw IllegalStateException("Server does not support tools capability.")
}
logger.info { "Removing ${toolNames.size} tools" }
var removedCount = 0
for (name in toolNames) {
logger.debug { "Removing tool: $name" }
if (tools.remove(name) != null) {
removedCount++
}
}

val oldMap = _tools.getAndUpdate { current -> current - toolNames.toPersistentSet() }

val removedCount = toolNames.count { it in oldMap }
logger.info {
if (removedCount > 0) {
"Removed $removedCount tools"
"Removed $removedCount tools"
} else {
"No tools were removed"
"No tools were removed"
}
}
return removedCount
Expand All @@ -280,7 +291,7 @@ public open class Server(
throw IllegalStateException("Server does not support prompts capability.")
}
logger.info { "Registering prompt: ${prompt.name}" }
prompts[prompt.name] = RegisteredPrompt(prompt, promptProvider)
_prompts.update { current -> current.put(prompt.name, RegisteredPrompt(prompt, promptProvider)) }
}

/**
Expand Down Expand Up @@ -314,10 +325,7 @@ public open class Server(
throw IllegalStateException("Server does not support prompts capability.")
}
logger.info { "Registering ${promptsToAdd.size} prompts" }
for (rp in promptsToAdd) {
logger.debug { "Registering prompt: ${rp.prompt.name}" }
prompts[rp.prompt.name] = rp
}
_prompts.update { current -> current.putAll(promptsToAdd.associateBy { it.prompt.name }) }
}

/**
Expand All @@ -333,7 +341,10 @@ public open class Server(
throw IllegalStateException("Server does not support prompts capability.")
}
logger.info { "Removing prompt: $name" }
val removed = prompts.remove(name) != null

val oldMap = _prompts.getAndUpdate { current -> current.remove(name) }

val removed = name in oldMap
logger.debug {
if (removed) {
"Prompt removed: $name"
Expand All @@ -357,13 +368,11 @@ public open class Server(
throw IllegalStateException("Server does not support prompts capability.")
}
logger.info { "Removing ${promptNames.size} prompts" }
var removedCount = 0
for (name in promptNames) {
logger.debug { "Removing prompt: $name" }
if (prompts.remove(name) != null) {
removedCount++
}
}

val oldMap = _prompts.getAndUpdate { current -> current - promptNames.toPersistentSet() }

val removedCount = promptNames.count { it in oldMap }

logger.info {
if (removedCount > 0) {
"Removed $removedCount prompts"
Expand Down Expand Up @@ -396,7 +405,12 @@ public open class Server(
throw IllegalStateException("Server does not support resources capability.")
}
logger.info { "Registering resource: $name ($uri)" }
resources[uri] = RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
_resources.update { current ->
current.put(
uri,
RegisteredResource(Resource(uri, name, description, mimeType), readHandler)
)
}
}

/**
Expand All @@ -411,10 +425,7 @@ public open class Server(
throw IllegalStateException("Server does not support resources capability.")
}
logger.info { "Registering ${resourcesToAdd.size} resources" }
for (r in resourcesToAdd) {
logger.debug { "Registering resource: ${r.resource.name} (${r.resource.uri})" }
resources[r.resource.uri] = r
}
_resources.update { current -> current.putAll(resourcesToAdd.associateBy { it.resource.uri }) }
}

/**
Expand All @@ -430,7 +441,10 @@ public open class Server(
throw IllegalStateException("Server does not support resources capability.")
}
logger.info { "Removing resource: $uri" }
val removed = resources.remove(uri) != null

val oldMap = _resources.getAndUpdate { current -> current.remove(uri) }

val removed = uri in oldMap
logger.debug {
if (removed) {
"Resource removed: $uri"
Expand All @@ -454,13 +468,11 @@ public open class Server(
throw IllegalStateException("Server does not support resources capability.")
}
logger.info { "Removing ${uris.size} resources" }
var removedCount = 0
for (uri in uris) {
logger.debug { "Removing resource: $uri" }
if (resources.remove(uri) != null) {
removedCount++
}
}

val oldMap = _resources.getAndUpdate { current -> current - uris.toPersistentSet() }

val removedCount = uris.count { it in oldMap }

logger.info {
if (removedCount > 0) {
"Removed $removedCount resources"
Expand Down Expand Up @@ -586,7 +598,7 @@ public open class Server(

private suspend fun handleCallTool(request: CallToolRequest): CallToolResult {
logger.debug { "Handling tool call request for tool: ${request.name}" }
val tool = tools[request.name]
val tool = _tools.value[request.name]
?: run {
logger.error { "Tool not found: ${request.name}" }
throw IllegalArgumentException("Tool not found: ${request.name}")
Expand Down
Loading
Loading