Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 101 additions & 9 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,125 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";

const vectorSearchIndexDefinition = z.object({
type: z.literal("vectorSearch"),
fields: z
.array(
z.discriminatedUnion("type", [
z
.object({
type: z.literal("filter"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
})
.strict()
.describe("Definition for a field that will be used for pre-filtering results."),
z
.object({
type: z.literal("vector"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
numDimensions: z
.number()
.min(1)
.max(8192)
.describe(
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
),
similarity: z
.enum(["cosine", "euclidean", "dotProduct"])
.default("cosine")
.describe(
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
),
quantization: z
.enum(["none", "scalar", "binary"])
.optional()
.default("none")
.describe(
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
),
})
.strict()
.describe("Definition for a field that contains vector embeddings."),
])
)
.nonempty()
.refine((fields) => fields.some((f) => f.type === "vector"), {
message: "At least one vector field must be defined",
})
.describe(
"Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required."
),
});

export class CreateIndexTool extends MongoDBToolBase {
public name = "create-index";
protected description = "Create an index for a collection";
protected argsShape = {
...DbOperationArgs,
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
name: z.string().optional().describe("The name of the index"),
definition: z
.array(
z.discriminatedUnion("type", [
z.object({
type: z.literal("classic"),
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
}),
...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [vectorSearchIndexDefinition] : []),
])
)
.describe(
"The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes"
),
};

public operationType: OperationType = "create";

protected async execute({
database,
collection,
keys,
name,
definition: definitions,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const indexes = await provider.createIndexes(database, collection, [
{
key: keys,
name,
},
]);
let indexes: string[] = [];
const definition = definitions[0];
if (!definition) {
throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`");
}

switch (definition.type) {
case "classic":
indexes = await provider.createIndexes(database, collection, [
{
key: definition.keys,
name,
},
]);
break;
case "vectorSearch":
indexes = await provider.createSearchIndexes(database, collection, [
{
name,
definition: {
fields: definition.fields,
},
type: "vectorSearch",
},
]);

break;
}

return {
content: [
Expand Down
14 changes: 14 additions & 0 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ export type ToolCallbackArgs<Args extends ZodRawShape> = Parameters<ToolCallback

export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];

export const enum FeatureFlags {
VectorSearch = "vectorSearch",
}

/**
* The type of operation the tool performs. This is used when evaluating if a tool is allowed to run based on
* the config's `disabledTools` and `readOnly` settings.
Expand Down Expand Up @@ -314,6 +318,16 @@ export abstract class ToolBase {

this.telemetry.emitEvents([event]);
}

// TODO: Move this to a separate file
protected isFeatureFlagEnabled(flag: FeatureFlags): boolean {
switch (flag) {
case FeatureFlags.VectorSearch:
return this.config.voyageApiKey !== "";
default:
return false;
}
}
}

/**
Expand Down
114 changes: 108 additions & 6 deletions tests/accuracy/createIndex.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { describeAccuracyTests } from "./sdk/describeAccuracyTests.js";
import { Matcher } from "./sdk/matcher.js";

process.env.MDB_VOYAGE_API_KEY = "valid-key";

describeAccuracyTests([
{
prompt: "Create an index that covers the following query on 'mflix.movies' namespace - { \"release_year\": 1992 }",
Expand All @@ -11,9 +13,14 @@ describeAccuracyTests([
database: "mflix",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
keys: {
release_year: 1,
},
definition: [
{
type: "classic",
keys: {
release_year: 1,
},
},
],
},
},
],
Expand All @@ -27,9 +34,104 @@ describeAccuracyTests([
database: "mflix",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
keys: {
title: "text",
},
definition: [
{
type: "classic",
keys: {
title: "text",
},
},
],
},
},
],
},
{
prompt: "Create a vector search index on 'mydb.movies' namespace on the 'plotSummary' field. The index should use 1024 dimensions.",
expectedToolCalls: [
{
toolName: "create-index",
parameters: {
database: "mydb",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
definition: [
{
type: "vectorSearch",
fields: [
{
type: "vector",
path: "plotSummary",
numDimensions: 1024,
},
],
},
],
},
},
],
},
{
prompt: "Create a vector search index on 'mydb.movies' namespace with on the 'plotSummary' field and 'genre' field, both of which contain vector embeddings. Pick a sensible number of dimensions for a voyage 3.5 model.",
expectedToolCalls: [
{
toolName: "create-index",
parameters: {
database: "mydb",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
definition: [
{
type: "vectorSearch",
fields: [
{
type: "vector",
path: "plotSummary",
numDimensions: Matcher.number(
(value) => value % 2 === 0 && value >= 256 && value <= 8192
),
similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()),
},
{
type: "vector",
path: "genre",
numDimensions: Matcher.number(
(value) => value % 2 === 0 && value >= 256 && value <= 8192
),
similarity: Matcher.anyOf(Matcher.undefined, Matcher.string()),
},
],
},
],
},
},
],
},
{
prompt: "Create a vector search index on 'mydb.movies' namespace where the 'plotSummary' field is indexed as a 1024-dimensional vector and the 'releaseDate' field is indexed as a regular field.",
expectedToolCalls: [
{
toolName: "create-index",
parameters: {
database: "mydb",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
definition: [
{
type: "vectorSearch",
fields: [
{
type: "vector",
path: "plotSummary",
numDimensions: 1024,
},
{
type: "filter",
path: "releaseDate",
},
],
},
],
},
},
],
Expand Down
11 changes: 8 additions & 3 deletions tests/accuracy/dropIndex.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ describeAccuracyTests([
database: "mflix",
collection: "movies",
name: Matcher.anyOf(Matcher.undefined, Matcher.string()),
keys: {
title: "text",
},
definition: [
{
keys: {
title: "text",
},
type: "classic",
},
],
},
},
{
Expand Down
4 changes: 3 additions & 1 deletion tests/accuracy/sdk/accuracyTestingClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,16 @@ export class AccuracyTestingClient {
static async initializeClient(
mdbConnectionString: string,
atlasApiClientId?: string,
atlasApiClientSecret?: string
atlasApiClientSecret?: string,
voyageApiKey?: string
): Promise<AccuracyTestingClient> {
const args = [
MCP_SERVER_CLI_SCRIPT,
"--connectionString",
mdbConnectionString,
...(atlasApiClientId ? ["--apiClientId", atlasApiClientId] : []),
...(atlasApiClientSecret ? ["--apiClientSecret", atlasApiClientSecret] : []),
...(voyageApiKey ? ["--voyageApiKey", voyageApiKey] : []),
];

const clientTransport = new StdioClientTransport({
Expand Down
4 changes: 3 additions & 1 deletion tests/accuracy/sdk/describeAccuracyTests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[])

const atlasApiClientId = process.env.MDB_MCP_API_CLIENT_ID;
const atlasApiClientSecret = process.env.MDB_MCP_API_CLIENT_SECRET;
const voyageApiKey = process.env.MDB_VOYAGE_API_KEY;

let commitSHA: string;
let accuracyResultStorage: AccuracyResultStorage;
Expand All @@ -85,7 +86,8 @@ export function describeAccuracyTests(accuracyTestConfigs: AccuracyTestConfig[])
testMCPClient = await AccuracyTestingClient.initializeClient(
mdbIntegration.connectionString(),
atlasApiClientId,
atlasApiClientSecret
atlasApiClientSecret,
voyageApiKey
);
agent = getVercelToolCallingAgent();
});
Expand Down
Loading
Loading