diff --git a/src/common/config.ts b/src/common/config.ts index aebd6e73a..414e91589 100644 --- a/src/common/config.ts +++ b/src/common/config.ts @@ -116,6 +116,7 @@ export interface UserConfig extends CliOptions { transport: "stdio" | "http"; httpPort: number; httpHost: string; + httpHeaders: Record; loggers: Array<"stderr" | "disk" | "mcp">; idleTimeoutMs: number; notificationTimeoutMs: number; @@ -137,6 +138,7 @@ export const defaultUserConfig: UserConfig = { loggers: ["disk", "mcp"], idleTimeoutMs: 600000, // 10 minutes notificationTimeoutMs: 540000, // 9 minutes + httpHeaders: {}, }; export const config = setupUserConfig({ diff --git a/src/common/logger.ts b/src/common/logger.ts index b172ec54c..1cdd0c4a3 100644 --- a/src/common/logger.ts +++ b/src/common/logger.ts @@ -64,7 +64,7 @@ export const LogId = { oidcFlow: mongoLogId(1_008_001), } as const; -interface LogPayload { +export interface LogPayload { id: MongoLogId; context: string; message: string; @@ -152,6 +152,26 @@ export abstract class LoggerBase = DefaultEventMap> extend public emergency(payload: LogPayload): void { this.log("emergency", payload); } + + protected mapToMongoDBLogLevel(level: LogLevel): "info" | "warn" | "error" | "debug" | "fatal" { + switch (level) { + case "info": + return "info"; + case "warning": + return "warn"; + case "error": + return "error"; + case "notice": + case "debug": + return "debug"; + case "critical": + case "alert": + case "emergency": + return "fatal"; + default: + return "info"; + } + } } export class ConsoleLogger extends LoggerBase { @@ -225,26 +245,6 @@ export class DiskLogger extends LoggerBase<{ initialized: [] }> { this.logWriter[mongoDBLevel]("MONGODB-MCP", id, context, message, payload.attributes); } - - private mapToMongoDBLogLevel(level: LogLevel): "info" | "warn" | "error" | "debug" | "fatal" { - switch (level) { - case "info": - return "info"; - case "warning": - return "warn"; - case "error": - return "error"; - case "notice": - case "debug": - return "debug"; - case "critical": - case "alert": - case "emergency": - return "fatal"; - default: - return "info"; - } - } } export class McpLogger extends LoggerBase { @@ -286,7 +286,11 @@ export class CompositeLogger extends LoggerBase { public log(level: LogLevel, payload: LogPayload): void { // Override the public method to avoid the base logger redacting the message payload for (const logger of this.loggers) { - logger.log(level, { ...payload, attributes: { ...this.attributes, ...payload.attributes } }); + const attributes = + Object.keys(this.attributes).length > 0 || payload.attributes + ? { ...this.attributes, ...payload.attributes } + : undefined; + logger.log(level, { ...payload, attributes }); } } diff --git a/src/helpers/deviceId.ts b/src/helpers/deviceId.ts index f4173ff87..1282e1b79 100644 --- a/src/helpers/deviceId.ts +++ b/src/helpers/deviceId.ts @@ -1,54 +1,51 @@ import { getDeviceId } from "@mongodb-js/device-id"; -import nodeMachineId from "node-machine-id"; +import * as nodeMachineId from "node-machine-id"; import type { LoggerBase } from "../common/logger.js"; import { LogId } from "../common/logger.js"; export const DEVICE_ID_TIMEOUT = 3000; export class DeviceId { - private deviceId: string | undefined = undefined; - private deviceIdPromise: Promise | undefined = undefined; - private abortController: AbortController | undefined = undefined; + private static readonly UnknownDeviceId = Promise.resolve("unknown"); + + private deviceIdPromise: Promise; + private abortController: AbortController; private logger: LoggerBase; private readonly getMachineId: () => Promise; private timeout: number; - private static instance: DeviceId | undefined = undefined; private constructor(logger: LoggerBase, timeout: number = DEVICE_ID_TIMEOUT) { this.logger = logger; this.timeout = timeout; this.getMachineId = (): Promise => nodeMachineId.machineId(true); + this.abortController = new AbortController(); + + this.deviceIdPromise = DeviceId.UnknownDeviceId; } - public static create(logger: LoggerBase, timeout?: number): DeviceId { - if (this.instance) { - throw new Error("DeviceId instance already exists, use get() to retrieve the device ID"); - } + private initialize(): void { + this.deviceIdPromise = getDeviceId({ + getMachineId: this.getMachineId, + onError: (reason, error) => { + this.handleDeviceIdError(reason, String(error)); + }, + timeout: this.timeout, + abortSignal: this.abortController.signal, + }); + } + public static create(logger: LoggerBase, timeout?: number): DeviceId { const instance = new DeviceId(logger, timeout ?? DEVICE_ID_TIMEOUT); - instance.setup(); - - this.instance = instance; + instance.initialize(); return instance; } - private setup(): void { - this.deviceIdPromise = this.calculateDeviceId(); - } - /** * Closes the device ID calculation promise and abort controller. */ public close(): void { - if (this.abortController) { - this.abortController.abort(); - this.abortController = undefined; - } - - this.deviceId = undefined; - this.deviceIdPromise = undefined; - DeviceId.instance = undefined; + this.abortController.abort(); } /** @@ -56,39 +53,11 @@ export class DeviceId { * @returns Promise that resolves to the device ID string */ public get(): Promise { - if (this.deviceId) { - return Promise.resolve(this.deviceId); - } - - if (this.deviceIdPromise) { - return this.deviceIdPromise; - } - - return this.calculateDeviceId(); - } - - /** - * Internal method that performs the actual device ID calculation. - */ - private async calculateDeviceId(): Promise { - if (!this.abortController) { - this.abortController = new AbortController(); - } - - this.deviceIdPromise = getDeviceId({ - getMachineId: this.getMachineId, - onError: (reason, error) => { - this.handleDeviceIdError(reason, String(error)); - }, - timeout: this.timeout, - abortSignal: this.abortController.signal, - }); - return this.deviceIdPromise; } private handleDeviceIdError(reason: string, error: string): void { - this.deviceIdPromise = Promise.resolve("unknown"); + this.deviceIdPromise = DeviceId.UnknownDeviceId; switch (reason) { case "resolutionError": diff --git a/src/lib.ts b/src/lib.ts index 7843a9cda..9fd921e4c 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,4 +1,7 @@ export { Server, type ServerOptions } from "./server.js"; export { Telemetry } from "./telemetry/telemetry.js"; export { Session, type SessionOptions } from "./common/session.js"; -export type { UserConfig } from "./common/config.js"; +export { type UserConfig, defaultUserConfig } from "./common/config.js"; +export { StreamableHttpRunner } from "./transports/streamableHttp.js"; +export { LoggerBase } from "./common/logger.js"; +export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 17a0ff5e7..4cbcc293e 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -16,9 +16,10 @@ export abstract class TransportRunnerBase { protected constructor( protected readonly userConfig: UserConfig, - private readonly driverOptions: DriverOptions + private readonly driverOptions: DriverOptions, + additionalLoggers: LoggerBase[] ) { - const loggers: LoggerBase[] = []; + const loggers: LoggerBase[] = [...additionalLoggers]; if (this.userConfig.loggers.includes("stderr")) { loggers.push(new ConsoleLogger()); } diff --git a/src/transports/stdio.ts b/src/transports/stdio.ts index d0619da6c..0751cac7b 100644 --- a/src/transports/stdio.ts +++ b/src/transports/stdio.ts @@ -1,3 +1,4 @@ +import type { LoggerBase } from "../common/logger.js"; import { LogId } from "../common/logger.js"; import type { Server } from "../server.js"; import { TransportRunnerBase } from "./base.js"; @@ -54,8 +55,8 @@ export function createStdioTransport(): StdioServerTransport { export class StdioRunner extends TransportRunnerBase { private server: Server | undefined; - constructor(userConfig: UserConfig, driverOptions: DriverOptions) { - super(userConfig, driverOptions); + constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { + super(userConfig, driverOptions, additionalLoggers); } async start(): Promise { diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index f50019ef4..74ad3062c 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -4,6 +4,7 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { TransportRunnerBase } from "./base.js"; import type { DriverOptions, UserConfig } from "../common/config.js"; +import type { LoggerBase } from "../common/logger.js"; import { LogId } from "../common/logger.js"; import { randomUUID } from "crypto"; import { SessionStore } from "../common/sessionStore.js"; @@ -18,8 +19,20 @@ export class StreamableHttpRunner extends TransportRunnerBase { private httpServer: http.Server | undefined; private sessionStore!: SessionStore; - constructor(userConfig: UserConfig, driverOptions: DriverOptions) { - super(userConfig, driverOptions); + public get serverAddress(): string { + const result = this.httpServer?.address(); + if (typeof result === "string") { + return result; + } + if (typeof result === "object" && result) { + return `http://${result.address}:${result.port}`; + } + + throw new Error("Server is not started yet"); + } + + constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) { + super(userConfig, driverOptions, additionalLoggers); } async start(): Promise { @@ -32,6 +45,17 @@ export class StreamableHttpRunner extends TransportRunnerBase { app.enable("trust proxy"); // needed for reverse proxy support app.use(express.json()); + app.use((req, res, next) => { + for (const [key, value] of Object.entries(this.userConfig.httpHeaders)) { + const header = req.headers[key.toLowerCase()]; + if (!header || header !== value) { + res.status(403).send({ error: `Invalid value for header "${key}"` }); + return; + } + } + + next(); + }); const handleSessionRequest = async (req: express.Request, res: express.Response): Promise => { const sessionId = req.headers["mcp-session-id"]; @@ -142,7 +166,7 @@ export class StreamableHttpRunner extends TransportRunnerBase { this.logger.info({ id: LogId.streamableHttpTransportStarted, context: "streamableHttpTransport", - message: `Server started on http://${this.userConfig.httpHost}:${this.userConfig.httpPort}`, + message: `Server started on ${this.serverAddress}`, noRedaction: true, }); } diff --git a/tests/integration/build.test.ts b/tests/integration/build.test.ts index 7282efe45..f5b26827e 100644 --- a/tests/integration/build.test.ts +++ b/tests/integration/build.test.ts @@ -41,6 +41,13 @@ describe("Build Test", () => { const esmKeys = Object.keys(esmModule).sort(); expect(cjsKeys).toEqual(esmKeys); - expect(cjsKeys).toEqual(["Server", "Session", "Telemetry"]); + expect(cjsKeys).toIncludeSameMembers([ + "Server", + "Session", + "Telemetry", + "StreamableHttpRunner", + "defaultUserConfig", + "LoggerBase", + ]); }); }); diff --git a/tests/integration/transports/streamableHttp.test.ts b/tests/integration/transports/streamableHttp.test.ts index 396879b1b..f45ce3cd3 100644 --- a/tests/integration/transports/streamableHttp.test.ts +++ b/tests/integration/transports/streamableHttp.test.ts @@ -1,56 +1,156 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { describe, expect, it, beforeAll, afterAll } from "vitest"; +import { describe, expect, it, beforeAll, afterAll, beforeEach } from "vitest"; import { config, driverOptions } from "../../../src/common/config.js"; +import type { LoggerType, LogLevel, LogPayload } from "../../../src/common/logger.js"; +import { LoggerBase, LogId } from "../../../src/common/logger.js"; describe("StreamableHttpRunner", () => { let runner: StreamableHttpRunner; let oldTelemetry: "enabled" | "disabled"; let oldLoggers: ("stderr" | "disk" | "mcp")[]; - beforeAll(async () => { + beforeAll(() => { oldTelemetry = config.telemetry; oldLoggers = config.loggers; config.telemetry = "disabled"; config.loggers = ["stderr"]; - runner = new StreamableHttpRunner(config, driverOptions); - await runner.start(); + config.httpPort = 0; // Use a random port for testing }); - afterAll(async () => { - await runner.close(); - config.telemetry = oldTelemetry; - config.loggers = oldLoggers; - }); + const headerTestCases: { headers: Record; description: string }[] = [ + { headers: {}, description: "without headers" }, + { headers: { "x-custom-header": "test-value" }, description: "with headers" }, + ]; - describe("client connects successfully", () => { - let client: Client; - let transport: StreamableHTTPClientTransport; - beforeAll(async () => { - transport = new StreamableHTTPClientTransport(new URL("http://127.0.0.1:3000/mcp")); + for (const { headers, description } of headerTestCases) { + describe(description, () => { + beforeAll(async () => { + config.httpHeaders = headers; + runner = new StreamableHttpRunner(config, driverOptions); + await runner.start(); + }); - client = new Client({ - name: "test", - version: "0.0.0", + afterAll(async () => { + await runner.close(); + config.telemetry = oldTelemetry; + config.loggers = oldLoggers; + config.httpHeaders = {}; }); - await client.connect(transport); + + const clientHeaderTestCases = [ + { + headers: {}, + description: "without client headers", + expectSuccess: Object.keys(headers).length === 0, + }, + { headers, description: "with matching client headers", expectSuccess: true }, + { headers: { ...headers, foo: "bar" }, description: "with extra client headers", expectSuccess: true }, + { + headers: { foo: "bar" }, + description: "with non-matching client headers", + expectSuccess: Object.keys(headers).length === 0, + }, + ]; + + for (const { + headers: clientHeaders, + description: clientDescription, + expectSuccess, + } of clientHeaderTestCases) { + describe(clientDescription, () => { + let client: Client; + let transport: StreamableHTTPClientTransport; + beforeAll(() => { + client = new Client({ + name: "test", + version: "0.0.0", + }); + transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`), { + requestInit: { + headers: clientHeaders, + }, + }); + }); + + afterAll(async () => { + await client.close(); + await transport.close(); + }); + + it(`should ${expectSuccess ? "succeed" : "fail"}`, async () => { + try { + await client.connect(transport); + const response = await client.listTools(); + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + expect(response.tools.length).toBeGreaterThan(0); + + const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); + expect(sortedTools[0]?.name).toBe("aggregate"); + expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + } catch (err) { + if (expectSuccess) { + throw err; + } else { + expect(err).toBeDefined(); + expect(err?.toString()).toContain("HTTP 403"); + } + } + }); + }); + } }); + } + + it("can create multiple runners", async () => { + const runners: StreamableHttpRunner[] = []; + try { + for (let i = 0; i < 3; i++) { + config.httpPort = 0; // Use a random port for each runner + const runner = new StreamableHttpRunner(config, driverOptions); + await runner.start(); + runners.push(runner); + } + + const addresses = new Set(runners.map((r) => r.serverAddress)); + expect(addresses.size).toBe(runners.length); + } finally { + for (const runner of runners) { + await runner.close(); + } + } + }); - afterAll(async () => { - await client.close(); - await transport.close(); + describe("with custom logger", () => { + beforeEach(() => { + config.loggers = []; }); - it("handles requests and sends responses", async () => { - const response = await client.listTools(); - expect(response).toBeDefined(); - expect(response.tools).toBeDefined(); - expect(response.tools.length).toBeGreaterThan(0); + class CustomLogger extends LoggerBase { + protected type?: LoggerType = "console"; + public messages: { level: LogLevel; payload: LogPayload }[] = []; + protected logCore(level: LogLevel, payload: LogPayload): void { + this.messages.push({ level, payload }); + } + } + + it("can provide custom logger", async () => { + const logger = new CustomLogger(); + const runner = new StreamableHttpRunner(config, driverOptions, [logger]); + await runner.start(); + + const messages = logger.messages; + expect(messages.length).toBeGreaterThan(0); - const sortedTools = response.tools.sort((a, b) => a.name.localeCompare(b.name)); - expect(sortedTools[0]?.name).toBe("aggregate"); - expect(sortedTools[0]?.description).toBe("Run an aggregation against a MongoDB collection"); + const serverStartedMessage = messages.filter( + (m) => m.payload.id === LogId.streamableHttpTransportStarted + )[0]; + expect(serverStartedMessage).toBeDefined(); + expect(serverStartedMessage?.payload.message).toContain("Server started on"); + expect(serverStartedMessage?.payload.context).toBe("streamableHttpTransport"); + expect(serverStartedMessage?.level).toBe("info"); }); }); }); diff --git a/tests/unit/helpers/deviceId.test.ts b/tests/unit/helpers/deviceId.test.ts index 68fd54e08..3b6112f72 100644 --- a/tests/unit/helpers/deviceId.test.ts +++ b/tests/unit/helpers/deviceId.test.ts @@ -22,11 +22,16 @@ describe("deviceId", () => { deviceId.close(); }); - it("should fail to create separate instances", () => { + it("should return different instance from create", async () => { deviceId = DeviceId.create(testLogger); - - // try to create a new device id and see it raises an error - expect(() => DeviceId.create(testLogger)).toThrow("DeviceId instance already exists"); + let second: DeviceId | undefined; + try { + second = DeviceId.create(testLogger); + expect(second === deviceId).toBe(false); + expect(await second.get()).toBe(await deviceId.get()); + } finally { + second?.close(); + } }); it("should successfully retrieve device ID", async () => {