diff --git a/src/tools/args.ts b/src/tools/args.ts index 165f3da0d..653f72da2 100644 --- a/src/tools/args.ts +++ b/src/tools/args.ts @@ -1,4 +1,5 @@ import { z, type ZodString } from "zod"; +import { EJSON } from "bson"; const NO_UNICODE_REGEX = /^[\x20-\x7E]*$/; export const NO_UNICODE_ERROR = "String cannot contain special characters or Unicode symbols"; @@ -68,3 +69,15 @@ export const AtlasArgs = { password: (): z.ZodString => z.string().min(1, "Password is required").max(100, "Password must be 100 characters or less"), }; + +function toEJSON(value: T): T { + if (!value) { + return value; + } + + return EJSON.deserialize(value, { relaxed: false }) as T; +} + +export function zEJSON(): z.AnyZodObject { + return z.object({}).passthrough().transform(toEJSON) as unknown as z.AnyZodObject; +} diff --git a/src/tools/mongodb/create/insertMany.ts b/src/tools/mongodb/create/insertMany.ts index 3e5f9b8a1..46619568d 100644 --- a/src/tools/mongodb/create/insertMany.ts +++ b/src/tools/mongodb/create/insertMany.ts @@ -2,6 +2,7 @@ import { z } from "zod"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import type { ToolArgs, OperationType } from "../../tool.js"; +import { zEJSON } from "../../args.js"; export class InsertManyTool extends MongoDBToolBase { public name = "insert-many"; @@ -9,7 +10,7 @@ export class InsertManyTool extends MongoDBToolBase { protected argsShape = { ...DbOperationArgs, documents: z - .array(z.object({}).passthrough().describe("An individual MongoDB document")) + .array(zEJSON().describe("An individual MongoDB document")) .describe( "The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()" ), diff --git a/src/tools/mongodb/delete/deleteMany.ts b/src/tools/mongodb/delete/deleteMany.ts index 754b0381a..835cbb4ab 100644 --- a/src/tools/mongodb/delete/deleteMany.ts +++ b/src/tools/mongodb/delete/deleteMany.ts @@ -1,18 +1,16 @@ -import { z } from "zod"; import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import type { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; import { EJSON } from "bson"; +import { zEJSON } from "../../args.js"; export class DeleteManyTool extends MongoDBToolBase { public name = "delete-many"; protected description = "Removes all documents that match the filter from a MongoDB collection"; protected argsShape = { ...DbOperationArgs, - filter: z - .object({}) - .passthrough() + filter: zEJSON() .optional() .describe( "The query filter, specifying the deletion criteria. Matches the syntax of the filter argument of db.collection.deleteMany()" diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index 45df45471..29aa5fc1e 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -6,9 +6,10 @@ import { formatUntrustedData } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; import { EJSON } from "bson"; import { ErrorCodes, MongoDBError } from "../../../common/errors.js"; +import { zEJSON } from "../../args.js"; export const AggregateArgs = { - pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"), + pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"), }; export class AggregateTool extends MongoDBToolBase { diff --git a/src/tools/mongodb/read/count.ts b/src/tools/mongodb/read/count.ts index 9a746990c..435c2c772 100644 --- a/src/tools/mongodb/read/count.ts +++ b/src/tools/mongodb/read/count.ts @@ -1,13 +1,11 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import type { ToolArgs, OperationType } from "../../tool.js"; -import { z } from "zod"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; +import { zEJSON } from "../../args.js"; export const CountArgs = { - query: z - .object({}) - .passthrough() + query: zEJSON() .optional() .describe( "A filter/query parameter. Allows users to filter the documents to count. Matches the syntax of the filter argument of db.collection.count()." diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 38f3f5059..0373cef44 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -6,11 +6,10 @@ import { formatUntrustedData } from "../../tool.js"; import type { SortDirection } from "mongodb"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; import { EJSON } from "bson"; +import { zEJSON } from "../../args.js"; export const FindArgs = { - filter: z - .object({}) - .passthrough() + filter: zEJSON() .optional() .describe("The query filter, matching the syntax of the query argument of db.collection.find()"), projection: z diff --git a/src/tools/mongodb/update/updateMany.ts b/src/tools/mongodb/update/updateMany.ts index c48768aec..9d936757f 100644 --- a/src/tools/mongodb/update/updateMany.ts +++ b/src/tools/mongodb/update/updateMany.ts @@ -3,23 +3,21 @@ import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import type { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; +import { zEJSON } from "../../args.js"; export class UpdateManyTool extends MongoDBToolBase { public name = "update-many"; protected description = "Updates all documents that match the specified filter for a collection"; protected argsShape = { ...DbOperationArgs, - filter: z - .object({}) - .passthrough() + filter: zEJSON() .optional() .describe( "The selection criteria for the update, matching the syntax of the filter argument of db.collection.updateOne()" ), - update: z - .object({}) - .passthrough() - .describe("An update document describing the modifications to apply using update operator expressions"), + update: zEJSON().describe( + "An update document describing the modifications to apply using update operator expressions" + ), upsert: z .boolean() .optional() diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index 09a7490b9..f3f316855 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,55 +1,8 @@ -import { EJSON } from "bson"; -import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; -import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js"; import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; import { LogId } from "../common/logger.js"; import type { Server } from "../server.js"; import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js"; -// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk -// but it uses EJSON.parse instead of JSON.parse to handle BSON types -export class EJsonReadBuffer { - private _buffer?: Buffer; - - append(chunk: Buffer): void { - this._buffer = this._buffer ? Buffer.concat([this._buffer, chunk]) : chunk; - } - - readMessage(): JSONRPCMessage | null { - if (!this._buffer) { - return null; - } - - const index = this._buffer.indexOf("\n"); - if (index === -1) { - return null; - } - - const line = this._buffer.toString("utf8", 0, index).replace(/\r$/, ""); - this._buffer = this._buffer.subarray(index + 1); - - // This is using EJSON.parse instead of JSON.parse to handle BSON types - return JSONRPCMessageSchema.parse(EJSON.parse(line)); - } - - clear(): void { - this._buffer = undefined; - } -} - -// This is a hacky workaround for https://github.com/mongodb-js/mongodb-mcp-server/issues/211 -// The underlying issue is that StdioServerTransport uses JSON.parse to deserialize -// messages, but that doesn't handle bson types, such as ObjectId when serialized as EJSON. -// -// This function creates a StdioServerTransport and replaces the internal readBuffer with EJsonReadBuffer -// that uses EJson.parse instead. -export function createStdioTransport(): StdioServerTransport { - const server = new StdioServerTransport(); - server["_readBuffer"] = new EJsonReadBuffer(); - - return server; -} - export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; @@ -60,8 +13,7 @@ export class StdioRunner extends TransportRunnerBase { async start(): Promise { try { this.server = await this.setupServer(); - - const transport = createStdioTransport(); + const transport = new StdioServerTransport(); await this.server.connect(transport); } catch (error: unknown) { diff --git a/tests/integration/indexCheck.test.ts b/tests/integration/indexCheck.test.ts index 49bb06b08..0cb59d0b3 100644 --- a/tests/integration/indexCheck.test.ts +++ b/tests/integration/indexCheck.test.ts @@ -80,7 +80,7 @@ describe("IndexCheck integration tests", () => { arguments: { database: integration.randomDbName(), collection: "find-test-collection", - filter: { _id: docs[0]?._id }, // Uses _id index (IDHACK) + filter: { _id: { $oid: docs[0]?._id } }, // Uses _id index (IDHACK) }, }); diff --git a/tests/integration/tools/mongodb/create/insertMany.test.ts b/tests/integration/tools/mongodb/create/insertMany.test.ts index 739c39964..844cbcaef 100644 --- a/tests/integration/tools/mongodb/create/insertMany.test.ts +++ b/tests/integration/tools/mongodb/create/insertMany.test.ts @@ -76,7 +76,7 @@ describeWithMongoDB("insertMany tool", (integration) => { arguments: { database: integration.randomDbName(), collection: "coll1", - documents: [{ prop1: "value1", _id: insertedIds[0] }], + documents: [{ prop1: "value1", _id: { $oid: insertedIds[0] } }], }, }); diff --git a/tests/integration/tools/mongodb/mongodbHelpers.ts b/tests/integration/tools/mongodb/mongodbHelpers.ts index 327d5cdf9..60961df32 100644 --- a/tests/integration/tools/mongodb/mongodbHelpers.ts +++ b/tests/integration/tools/mongodb/mongodbHelpers.ts @@ -15,6 +15,7 @@ import { } from "../../helpers.js"; import type { UserConfig, DriverOptions } from "../../../../src/common/config.js"; import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it } from "vitest"; +import { EJSON } from "bson"; const __dirname = path.dirname(fileURLToPath(import.meta.url)); @@ -267,10 +268,9 @@ export function prepareTestData(integration: MongoDBIntegrationTest): { }; } -export function getDocsFromUntrustedContent(content: string): unknown[] { +export function getDocsFromUntrustedContent(content: string): T[] { const data = getDataFromUntrustedContent(content); - - return JSON.parse(data) as unknown[]; + return EJSON.parse(data, { relaxed: true }) as T[]; } export async function isCommunityServer(integration: MongoDBIntegrationTestCase): Promise { diff --git a/tests/integration/tools/mongodb/read/find.test.ts b/tests/integration/tools/mongodb/read/find.test.ts index fc192d8ba..ec94961b9 100644 --- a/tests/integration/tools/mongodb/read/find.test.ts +++ b/tests/integration/tools/mongodb/read/find.test.ts @@ -190,7 +190,7 @@ describeWithMongoDB("find tool", (integration) => { arguments: { database: integration.randomDbName(), collection: "foo", - filter: { _id: fooObject._id }, + filter: { _id: { $oid: fooObject._id } }, }, }); @@ -202,6 +202,36 @@ describeWithMongoDB("find tool", (integration) => { expect((docs[0] as { value: number }).value).toEqual(fooObject.value); }); + + it("can find objects by date", async () => { + await integration.connectMcpClient(); + + await integration + .mongoClient() + .db(integration.randomDbName()) + .collection("foo_with_dates") + .insertMany([ + { date: new Date("2025-05-10"), idx: 0 }, + { date: new Date("2025-05-11"), idx: 1 }, + ]); + + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { + database: integration.randomDbName(), + collection: "foo_with_dates", + filter: { date: { $gt: { $date: "2025-05-10" } } }, // only 2025-05-11 will match + }, + }); + + const content = getResponseContent(response); + expect(content).toContain('Found 1 documents in the collection "foo_with_dates".'); + + const docs = getDocsFromUntrustedContent<{ date: Date }>(content); + expect(docs.length).toEqual(1); + + expect(docs[0]?.date.toISOString()).toContain("2025-05-11"); + }); }); validateAutoConnectBehavior(integration, "find", () => { diff --git a/tests/unit/transports/stdio.test.ts b/tests/unit/transports/stdio.test.ts deleted file mode 100644 index bfc64c290..000000000 --- a/tests/unit/transports/stdio.test.ts +++ /dev/null @@ -1,71 +0,0 @@ -import { Decimal128, MaxKey, MinKey, ObjectId, Timestamp, UUID } from "bson"; -import { createStdioTransport, EJsonReadBuffer } from "../../../src/transports/stdio.js"; -import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; -import type { AuthInfo } from "@modelcontextprotocol/sdk/server/auth/types.js"; -import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js"; -import type { Readable } from "stream"; -import { ReadBuffer } from "@modelcontextprotocol/sdk/shared/stdio.js"; -import { describe, expect, it, beforeEach, afterEach } from "vitest"; -describe("stdioTransport", () => { - let transport: StdioServerTransport; - beforeEach(async () => { - transport = createStdioTransport(); - await transport.start(); - }); - - afterEach(async () => { - await transport.close(); - }); - - it("ejson deserializes messages", () => { - const messages: { message: JSONRPCMessage; extra?: { authInfo?: AuthInfo } }[] = []; - transport.onmessage = ( - message, - extra?: { - authInfo?: AuthInfo; - } - ): void => { - messages.push({ message, extra }); - }; - - (transport["_stdin"] as Readable).emit( - "data", - Buffer.from( - '{"jsonrpc":"2.0","id":1,"method":"testMethod","params":{"oid":{"$oid":"681b741f13aa74a0687b5110"},"uuid":{"$uuid":"f81d4fae-7dec-11d0-a765-00a0c91e6bf6"},"date":{"$date":"2025-05-07T14:54:23.973Z"},"decimal":{"$numberDecimal":"1234567890987654321"},"int32":123,"maxKey":{"$maxKey":1},"minKey":{"$minKey":1},"timestamp":{"$timestamp":{"t":123,"i":456}}}}\n', - "utf-8" - ) - ); - - expect(messages.length).toBe(1); - const message = messages[0]?.message; - - expect(message).toEqual({ - jsonrpc: "2.0", - id: 1, - method: "testMethod", - params: { - oid: new ObjectId("681b741f13aa74a0687b5110"), - uuid: new UUID("f81d4fae-7dec-11d0-a765-00a0c91e6bf6"), - date: new Date(Date.parse("2025-05-07T14:54:23.973Z")), - decimal: new Decimal128("1234567890987654321"), - int32: 123, - maxKey: new MaxKey(), - minKey: new MinKey(), - timestamp: new Timestamp({ t: 123, i: 456 }), - }, - }); - }); - - it("has _readBuffer field of type EJsonReadBuffer", () => { - expect(transport["_readBuffer"]).toBeDefined(); - expect(transport["_readBuffer"]).toBeInstanceOf(EJsonReadBuffer); - }); - - describe("standard StdioServerTransport", () => { - it("has a _readBuffer field", () => { - const standardTransport = new StdioServerTransport(); - expect(standardTransport["_readBuffer"]).toBeDefined(); - expect(standardTransport["_readBuffer"]).toBeInstanceOf(ReadBuffer); - }); - }); -});