Skip to content
17 changes: 17 additions & 0 deletions client/src/components/OAuthCallback.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => {
});

const params = parseOAuthCallbackParams(window.location.search);
// Extract state from query params
const urlParams = new URLSearchParams(window.location.search);
const returnedState = urlParams.get("state");
if (!params.successful) {
return notifyError(generateOAuthErrorDescription(params));
}
Expand All @@ -45,6 +48,20 @@ const OAuthCallback = ({ onConnect }: OAuthCallbackProps) => {
return notifyError("Missing Server URL");
}

// Validate state parameter
const serverAuthProvider = new InspectorOAuthClientProvider(
serverUrl,
undefined,
undefined,
);
const expectedState = serverAuthProvider.getState();
serverAuthProvider.clearState(); // Always clear after checking
if (!returnedState || !expectedState || returnedState !== expectedState) {
return notifyError(
"Invalid or missing OAuth state parameter. Please try logging in again.",
);
}

const clientInformation =
await getClientInformationFromSessionStorage(serverUrl);

Expand Down
4 changes: 4 additions & 0 deletions client/src/components/__tests__/Sidebar.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ describe("Sidebar Environment Variables", () => {
setEnv: jest.fn(),
bearerToken: "",
setBearerToken: jest.fn(),
oauthClientId: "",
setOauthClientId: jest.fn(),
oauthParams: "",
setOauthParams: jest.fn(),
onConnect: jest.fn(),
onDisconnect: jest.fn(),
stdErrNotifications: [],
Expand Down
45 changes: 45 additions & 0 deletions client/src/lib/auth.authorize-url.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { InspectorOAuthClientProvider } from "./auth";

describe("OAuth /authorize URL includes state parameter", () => {
const serverUrl = "https://example.com";
let provider: InspectorOAuthClientProvider;

// Suppress console.log for this test suite
beforeAll(() => {
jest.spyOn(console, "log").mockImplementation(() => {});
});

beforeEach(() => {
provider = new InspectorOAuthClientProvider(serverUrl);
sessionStorage.clear();
});

it("includes state parameter in the authorization URL", () => {
// Mock window.location.href using Object.defineProperty
const originalLocation = window.location;

Object.defineProperty(window, "location", {
value: { href: "" },
writable: true,
});

const url = new URL("https://authserver.com/authorize");
provider.redirectToAuthorization(url);

// Check that the URL contains the state parameter
expect(window.location.href).toContain("state=");
const stateInUrl = new URL(window.location.href).searchParams.get("state");
expect(stateInUrl).toBeDefined();
expect(stateInUrl!.length).toBeGreaterThan(0);

// Restore window.location
Object.defineProperty(window, "location", {
value: originalLocation,
writable: true,
});
});

afterAll(() => {
(console.log as jest.Mock).mockRestore();
});
});
31 changes: 31 additions & 0 deletions client/src/lib/auth.state.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { InspectorOAuthClientProvider } from "./auth";

describe("InspectorOAuthClientProvider state parameter", () => {
const serverUrl = "https://example.com";
let provider: InspectorOAuthClientProvider;

beforeEach(() => {
provider = new InspectorOAuthClientProvider(serverUrl);
sessionStorage.clear();
});

it("generates, stores, and retrieves state", () => {
const state = provider.generateAndStoreState();
expect(state).toBeDefined();
expect(state).toEqual(provider.getState());
expect(state).toHaveLength(32);
});

it("clears state from sessionStorage", () => {
provider.generateAndStoreState();
provider.clearState();
expect(provider.getState()).toBeNull();
});

it("generates a new state each time", () => {
const state1 = provider.generateAndStoreState();
provider.clearState();
const state2 = provider.generateAndStoreState();
expect(state1).not.toEqual(state2);
});
});
30 changes: 30 additions & 0 deletions client/src/lib/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
OAuthTokensSchema,
} from "@modelcontextprotocol/sdk/shared/auth.js";
import { SESSION_KEYS, getServerSpecificKey } from "./constants";
import { generateRandomState } from "@/utils/oauthUtils";

export const getClientInformationFromSessionStorage = async (
serverUrl: string,
Expand Down Expand Up @@ -98,7 +99,36 @@ export class InspectorOAuthClientProvider implements OAuthClientProvider {
sessionStorage.setItem(key, JSON.stringify(tokens));
}

/**
* Generate, store, and return a new state parameter for OAuth.
*/
generateAndStoreState(): string {
const state = generateRandomState(32);
const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl);
sessionStorage.setItem(key, state);
return state;
}

/**
* Retrieve the stored state parameter for this serverUrl.
*/
getState(): string | null {
const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl);
return sessionStorage.getItem(key);
}

/**
* Remove the stored state parameter for this serverUrl.
*/
clearState() {
const key = getServerSpecificKey(SESSION_KEYS.OAUTH_STATE, this.serverUrl);
sessionStorage.removeItem(key);
}

redirectToAuthorization(authorizationUrl: URL) {
// Generate and store a new state parameter
const state = this.generateAndStoreState();
authorizationUrl.searchParams.set("state", state);
const authParams = this.authParams();
console.log("authParams", authParams);
if (authParams) {
Expand Down
1 change: 1 addition & 0 deletions client/src/lib/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ export const SESSION_KEYS = {
TOKENS: "mcp_tokens",
CLIENT_INFORMATION: "mcp_client_information",
OAUTH_PARAMS: "mcp_oauth_params",
OAUTH_STATE: "oauth_state",
} as const;

// Generate server-specific session storage keys
Expand Down
27 changes: 26 additions & 1 deletion client/src/utils/__tests__/oauthUtils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import {
generateOAuthErrorDescription,
parseOAuthCallbackParams,
} from "@/utils/oauthUtils.ts";
generateRandomState,
} from "@/utils/oauthUtils";

describe("parseOAuthCallbackParams", () => {
it("Returns successful: true and code when present", () => {
Expand Down Expand Up @@ -76,3 +77,27 @@ describe("generateOAuthErrorDescription", () => {
);
});
});

describe("generateRandomState", () => {
it("generates a string of the correct length", () => {
const state = generateRandomState(32);
expect(state).toHaveLength(32);
const state16 = generateRandomState(16);
expect(state16).toHaveLength(16);
});

it("generates a string with only allowed characters", () => {
const charset =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
const state = generateRandomState(64);
for (const char of state) {
expect(charset.includes(char)).toBe(true);
}
});

it("generates different values on subsequent calls (randomness)", () => {
const state1 = generateRandomState(32);
const state2 = generateRandomState(32);
expect(state1).not.toEqual(state2);
});
});
12 changes: 12 additions & 0 deletions client/src/utils/oauthUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,15 @@ export const generateOAuthErrorDescription = (
.filter(Boolean)
.join("\n");
};

/**
* Generates a cryptographically secure random string for use as OAuth state or PKCE code_verifier.
* @param length Number of characters in the generated string (default: 32)
*/
export function generateRandomState(length = 32): string {
const charset =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
const array = new Uint8Array(length);
window.crypto.getRandomValues(array);
return Array.from(array, (byte) => charset[byte % charset.length]).join("");
}