|
1 | 1 | import { z } from "zod"; |
2 | 2 | import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; |
3 | 3 | import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; |
4 | | -import type { ToolArgs, OperationType } from "../../tool.js"; |
| 4 | +import type { ToolCategory } from "../../tool.js"; |
| 5 | +import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js"; |
5 | 6 | import type { IndexDirection } from "mongodb"; |
6 | 7 |
|
7 | 8 | export class CreateIndexTool extends MongoDBToolBase { |
| 9 | + private vectorSearchIndexDefinition = z.object({ |
| 10 | + type: z.literal("vectorSearch"), |
| 11 | + fields: z |
| 12 | + .array( |
| 13 | + z.discriminatedUnion("type", [ |
| 14 | + z |
| 15 | + .object({ |
| 16 | + type: z.literal("filter"), |
| 17 | + path: z |
| 18 | + .string() |
| 19 | + .describe( |
| 20 | + "Name of the field to index. For nested fields, use dot notation to specify path to embedded fields" |
| 21 | + ), |
| 22 | + }) |
| 23 | + .strict() |
| 24 | + .describe("Definition for a field that will be used for pre-filtering results."), |
| 25 | + z |
| 26 | + .object({ |
| 27 | + type: z.literal("vector"), |
| 28 | + path: z |
| 29 | + .string() |
| 30 | + .describe( |
| 31 | + "Name of the field to index. For nested fields, use dot notation to specify path to embedded fields" |
| 32 | + ), |
| 33 | + numDimensions: z |
| 34 | + .number() |
| 35 | + .min(1) |
| 36 | + .max(8192) |
| 37 | + .default(this.config.vectorSearchDimensions) |
| 38 | + .describe( |
| 39 | + "Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time" |
| 40 | + ), |
| 41 | + similarity: z |
| 42 | + .enum(["cosine", "euclidean", "dotProduct"]) |
| 43 | + .default(this.config.vectorSearchSimilarityFunction) |
| 44 | + .describe( |
| 45 | + "Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields." |
| 46 | + ), |
| 47 | + quantization: z |
| 48 | + .enum(["none", "scalar", "binary"]) |
| 49 | + .optional() |
| 50 | + .default("none") |
| 51 | + .describe( |
| 52 | + "Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors." |
| 53 | + ), |
| 54 | + }) |
| 55 | + .strict() |
| 56 | + .describe("Definition for a field that contains vector embeddings."), |
| 57 | + ]) |
| 58 | + ) |
| 59 | + .nonempty() |
| 60 | + .refine((fields) => fields.some((f) => f.type === "vector"), { |
| 61 | + message: "At least one vector field must be defined", |
| 62 | + }) |
| 63 | + .describe( |
| 64 | + "Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required." |
| 65 | + ), |
| 66 | + }); |
| 67 | + |
8 | 68 | public name = "create-index"; |
9 | 69 | protected description = "Create an index for a collection"; |
10 | 70 | protected argsShape = { |
11 | 71 | ...DbOperationArgs, |
12 | | - keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"), |
13 | 72 | name: z.string().optional().describe("The name of the index"), |
| 73 | + definition: z |
| 74 | + .array( |
| 75 | + z.discriminatedUnion("type", [ |
| 76 | + z.object({ |
| 77 | + type: z.literal("classic"), |
| 78 | + keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"), |
| 79 | + }), |
| 80 | + ...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [this.vectorSearchIndexDefinition] : []), |
| 81 | + ]) |
| 82 | + ) |
| 83 | + .describe( |
| 84 | + "The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes" |
| 85 | + ), |
14 | 86 | }; |
15 | 87 |
|
16 | 88 | public operationType: OperationType = "create"; |
17 | 89 |
|
18 | 90 | protected async execute({ |
19 | 91 | database, |
20 | 92 | collection, |
21 | | - keys, |
22 | 93 | name, |
| 94 | + definition: definitions, |
23 | 95 | }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> { |
24 | 96 | const provider = await this.ensureConnected(); |
25 | | - const indexes = await provider.createIndexes(database, collection, [ |
26 | | - { |
27 | | - key: keys, |
28 | | - name, |
29 | | - }, |
30 | | - ]); |
| 97 | + let indexes: string[] = []; |
| 98 | + const definition = definitions[0]; |
| 99 | + if (!definition) { |
| 100 | + throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`"); |
| 101 | + } |
| 102 | + |
| 103 | + let responseClarification = ""; |
| 104 | + |
| 105 | + switch (definition.type) { |
| 106 | + case "classic": |
| 107 | + indexes = await provider.createIndexes(database, collection, [ |
| 108 | + { |
| 109 | + key: definition.keys, |
| 110 | + name, |
| 111 | + }, |
| 112 | + ]); |
| 113 | + break; |
| 114 | + case "vectorSearch": |
| 115 | + { |
| 116 | + const isVectorSearchSupported = await this.session.isSearchSupported(); |
| 117 | + if (!isVectorSearchSupported) { |
| 118 | + // TODO: remove hacky casts once we merge the local dev tools |
| 119 | + const isLocalAtlasAvailable = |
| 120 | + (this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory)) |
| 121 | + .length ?? 0) > 0; |
| 122 | + |
| 123 | + const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI"; |
| 124 | + return { |
| 125 | + content: [ |
| 126 | + { |
| 127 | + text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`, |
| 128 | + type: "text", |
| 129 | + }, |
| 130 | + ], |
| 131 | + isError: true, |
| 132 | + }; |
| 133 | + } |
| 134 | + |
| 135 | + indexes = await provider.createSearchIndexes(database, collection, [ |
| 136 | + { |
| 137 | + name, |
| 138 | + definition: { |
| 139 | + fields: definition.fields, |
| 140 | + }, |
| 141 | + type: "vectorSearch", |
| 142 | + }, |
| 143 | + ]); |
| 144 | + |
| 145 | + responseClarification = |
| 146 | + " Since this is a vector search index, it may take a while for the index to build. Use the `list-indexes` tool to check the index status."; |
| 147 | + } |
| 148 | + |
| 149 | + break; |
| 150 | + } |
31 | 151 |
|
32 | 152 | return { |
33 | 153 | content: [ |
34 | 154 | { |
35 | | - text: `Created the index "${indexes[0]}" on collection "${collection}" in database "${database}"`, |
| 155 | + text: `Created the index "${indexes[0]}" on collection "${collection}" in database "${database}".${responseClarification}`, |
36 | 156 | type: "text", |
37 | 157 | }, |
38 | 158 | ], |
|
0 commit comments