diff --git a/client/src/components/OAuthCallback.tsx b/client/src/components/OAuthCallback.tsx index 80ed62bb7..9c19e6957 100644 --- a/client/src/components/OAuthCallback.tsx +++ b/client/src/components/OAuthCallback.tsx @@ -36,6 +36,9 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { }); const params = parseOAuthCallbackParams(window.location.search); + // Extract state from query params + const urlParams = new URLSearchParams(window.location.search); + const returnedState = urlParams.get("state"); if (!params.successful) { return notifyError(generateOAuthErrorDescription(params)); } @@ -45,6 +48,20 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => { return notifyError("Missing Server URL"); } + // Validate state parameter + const serverAuthProvider = new InspectorOAuthClientProvider( + serverUrl, + undefined, + undefined, + ); + const expectedState = serverAuthProvider.getState(); + serverAuthProvider.clearState(); // Always clear after checking + if (!returnedState || !expectedState || returnedState !== expectedState) { + return notifyError( + "Invalid or missing OAuth state parameter. Please try logging in again.", + ); + } + const clientInformation = await getClientInformationFromSessionStorage(serverUrl); diff --git a/client/src/components/__tests__/Sidebar.test.tsx b/client/src/components/__tests__/Sidebar.test.tsx index 980dc0120..ad7c2125b 100644 --- a/client/src/components/__tests__/Sidebar.test.tsx +++ b/client/src/components/__tests__/Sidebar.test.tsx @@ -27,6 +27,10 @@ describe("Sidebar Environment Variables", () => { setEnv: jest.fn(), bearerToken: "", setBearerToken: jest.fn(), + oauthClientId: "", + setOauthClientId: jest.fn(), + oauthParams: "", + setOauthParams: jest.fn(), onConnect: jest.fn(), onDisconnect: jest.fn(), stdErrNotifications: [], diff --git a/client/src/lib/auth.authorize-url.test.ts b/client/src/lib/auth.authorize-url.test.ts new file mode 100644 index 000000000..53c6404db --- /dev/null +++ b/client/src/lib/auth.authorize-url.test.ts @@ -0,0 +1,45 @@ +import { InspectorOAuthClientProvider } from "./auth"; + +describe("OAuth /authorize URL includes state parameter", () => { + const serverUrl = "https://example.com"; + let provider: InspectorOAuthClientProvider; + + // Suppress console.log for this test suite + beforeAll(() => { + jest.spyOn(console, "log").mockImplementation(() => {}); + }); + + beforeEach(() => { + provider = new InspectorOAuthClientProvider(serverUrl); + sessionStorage.clear(); + }); + + it("includes state parameter in the authorization URL", () => { + // Mock window.location.href using Object.defineProperty + const originalLocation = window.location; + + Object.defineProperty(window, "location", { + value: { href: "" }, + writable: true, + }); + + const url = new URL("https://authserver.com/authorize"); + provider.redirectToAuthorization(url); + + // Check that the URL contains the state parameter + expect(window.location.href).toContain("state="); + const stateInUrl = new URL(window.location.href).searchParams.get("state"); + expect(stateInUrl).toBeDefined(); + expect(stateInUrl!.length).toBeGreaterThan(0); + + // Restore window.location + Object.defineProperty(window, "location", { + value: originalLocation, + writable: true, + }); + }); + + afterAll(() => { + (console.log as jest.Mock).mockRestore(); + }); +}); diff --git a/client/src/lib/auth.state.test.ts b/client/src/lib/auth.state.test.ts new file mode 100644 index 000000000..6aa3fd3ae --- /dev/null +++ b/client/src/lib/auth.state.test.ts @@ -0,0 +1,31 @@ +import { InspectorOAuthClientProvider } from "./auth"; + +describe("InspectorOAuthClientProvider state parameter", () => { + const serverUrl = "https://example.com"; + let provider: InspectorOAuthClientProvider; + + beforeEach(() => { + provider = new InspectorOAuthClientProvider(serverUrl); + sessionStorage.clear(); + }); + + it("generates, stores, and retrieves state", () => { + const state = provider.generateAndStoreState(); + expect(state).toBeDefined(); + expect(state).toEqual(provider.getState()); + expect(state).toHaveLength(32); + }); + + it("clears state from sessionStorage", () => { + provider.generateAndStoreState(); + provider.clearState(); + expect(provider.getState()).toBeNull(); + }); + + it("generates a new state each time", () => { + const state1 = provider.generateAndStoreState(); + provider.clearState(); + const state2 = provider.generateAndStoreState(); + expect(state1).not.toEqual(state2); + }); +}); diff --git a/client/src/lib/auth.ts b/client/src/lib/auth.ts index 8d034296a..a41340c28 100644 --- a/client/src/lib/auth.ts +++ b/client/src/lib/auth.ts @@ -6,6 +6,7 @@ import { OAuthTokensSchema, } from "@modelcontextprotocol/sdk/shared/auth.js"; import { SESSION_KEYS, getServerSpecificKey } from "./constants"; +import { generateRandomState } from "@/utils/oauthUtils"; export const getClientInformationFromSessionStorage = async ( serverUrl: string, @@ -98,7 +99,36 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider { sessionStorage.setItem(key, JSON.stringify(tokens)); } + /** + * Generate, store, and return a new state parameter for OAuth. + */ + generateAndStoreState(): string { + const state = generateRandomState(32); + const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl); + sessionStorage.setItem(key, state); + return state; + } + + /** + * Retrieve the stored state parameter for this serverUrl. + */ + getState(): string | null { + const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl); + return sessionStorage.getItem(key); + } + + /** + * Remove the stored state parameter for this serverUrl. + */ + clearState() { + const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl); + sessionStorage.removeItem(key); + } + redirectToAuthorization(authorizationUrl: URL) { + // Generate and store a new state parameter + const state = this.generateAndStoreState(); + authorizationUrl.searchParams.set("state", state); const authParams = this.authParams(); console.log("authParams", authParams); if (authParams) { diff --git a/client/src/lib/constants.ts b/client/src/lib/constants.ts index d1d3e0787..b574250ae 100644 --- a/client/src/lib/constants.ts +++ b/client/src/lib/constants.ts @@ -7,6 +7,7 @@ export const SESSION_KEYS = { TOKENS: "mcp_tokens", CLIENT_INFORMATION: "mcp_client_information", OAUTH_PARAMS: "mcp_oauth_params", + OAUTH_STATE: "oauth_state", } as const; // Generate server-specific session storage keys diff --git a/client/src/utils/__tests__/oauthUtils.ts b/client/src/utils/__tests__/oauthUtils.ts index cc9674cb2..22e3757fe 100644 --- a/client/src/utils/__tests__/oauthUtils.ts +++ b/client/src/utils/__tests__/oauthUtils.ts @@ -1,7 +1,8 @@ import { generateOAuthErrorDescription, parseOAuthCallbackParams, -} from "@/utils/oauthUtils.ts"; + generateRandomState, +} from "@/utils/oauthUtils"; describe("parseOAuthCallbackParams", () => { it("Returns successful: true and code when present", () => { @@ -76,3 +77,27 @@ describe("generateOAuthErrorDescription", () => { ); }); }); + +describe("generateRandomState", () => { + it("generates a string of the correct length", () => { + const state = generateRandomState(32); + expect(state).toHaveLength(32); + const state16 = generateRandomState(16); + expect(state16).toHaveLength(16); + }); + + it("generates a string with only allowed characters", () => { + const charset = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + const state = generateRandomState(64); + for (const char of state) { + expect(charset.includes(char)).toBe(true); + } + }); + + it("generates different values on subsequent calls (randomness)", () => { + const state1 = generateRandomState(32); + const state2 = generateRandomState(32); + expect(state1).not.toEqual(state2); + }); +}); diff --git a/client/src/utils/oauthUtils.ts b/client/src/utils/oauthUtils.ts index c971271e3..db9ba22ff 100644 --- a/client/src/utils/oauthUtils.ts +++ b/client/src/utils/oauthUtils.ts @@ -63,3 +63,15 @@ export const generateOAuthErrorDescription = ( .filter(Boolean) .join("\n"); }; + +/** + * Generates a cryptographically secure random string for use as OAuth state or PKCE code_verifier. + * @param length Number of characters in the generated string (default: 32) + */ +export function generateRandomState(length = 32): string { + const charset = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + const array = new Uint8Array(length); + window.crypto.getRandomValues(array); + return Array.from(array, (byte) => charset[byte % charset.length]).join(""); +}