From 470c9697af99f0ba4789194c7c05d18764a9be79 Mon Sep 17 00:00:00 2001 From: Taylor Howellsmith <0x74@arcn.ms> Date: Sun, 9 Nov 2025 09:00:54 -0500 Subject: [PATCH] feat: add proactive token refresh manager with disk-backed metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Objective: allow MCP Remote to renew OAuth access tokens automatically before expiry, ensuring users don’t get interrupted in long-running sessions. Tests mock all collaborators to validate scheduler decisions (expiring soon, lock contention, failure backoff) and the new helper functions; full CLI/proxy behavior was verified manually since Node tooling isn’t available in this environment. - add token_state.json, server.json, and refresh_lock.json plumbing in src/lib/mcp-auth-config.ts so we can track issued/expiry timestamps, remember which servers exist, and coordinate background refreshes without overloading the existing OAuth files (which must stay spec-compliant) - introduce TokenRefreshManager plus supporting CLI flags; it discovers every server hash with tokens, respects per-server locks/backoff, invokes the SDK’s refresh grant, and logs timing/expiry info so long-running proxies don’t require browser re-auth - extend NodeOAuthClientProvider.saveTokens() to persist derived timing metadata, invalidate it alongside tokens, and keep tokens.json untouched for compatibility - wire the manager into both CLI and proxy entrypoints (default on, opt-out via --disable-auto-refresh), persist server registrations during argument parsing, and document the new flags in the README so headless deployments behave predictably - add unit tests (src/lib/token-refresh-manager.test.ts) that mock disk IO and SDK calls to cover refresh-trigger logic, locking, and backoff, and mirror the repo’s logging/test-style conventions - enhanced debug logging now prints human-readable timestamps plus remaining durations whenever the refresh window is evaluated --- README.md | 21 ++ src/client.ts | 25 +- src/lib/mcp-auth-config.ts | 185 ++++++++++++++ src/lib/node-oauth-client-provider.ts | 16 +- src/lib/token-refresh-manager.test.ts | 226 +++++++++++++++++ src/lib/token-refresh-manager.ts | 351 ++++++++++++++++++++++++++ src/lib/utils.ts | 94 ++++++- src/proxy.ts | 24 +- 8 files changed, 936 insertions(+), 6 deletions(-) create mode 100644 src/lib/token-refresh-manager.test.ts create mode 100644 src/lib/token-refresh-manager.ts diff --git a/README.md b/README.md index d95aaaa..3afb213 100644 --- a/README.md +++ b/README.md @@ -209,6 +209,27 @@ You can specify multiple `--ignore-tool` flags to ignore different patterns. Exa ] ``` +* To automatically refresh access tokens before they expire, use the auto-refresh flags: + * `--enable-auto-refresh` / `--disable-auto-refresh` – turn the background refresher on or off (enabled by default for both the proxy and the CLI so they keep working in the background, opt out when you explicitly need to disable it). + * `--refresh-lead ` – how early to refresh before expiry (default `600`, i.e. 10 minutes). + * `--refresh-interval ` – how often to scan stored tokens (default `60`). + * `--refresh-backoff ` – how long to wait before retrying after a failure (default `300`). + + +```json + "args": [ + "mcp-remote", + "https://remote.mcp.server/sse", + "--enable-auto-refresh", + "--refresh-lead", + "300", + "--refresh-interval", + "30" + ] +``` + +The refresher scans all OAuth sessions stored under `~/.mcp-auth`, renews access tokens using their refresh tokens, and logs the outcome so long-running hosts keep working without forcing a browser re-auth. + ### Transport Strategies MCP Remote supports different transport strategies when connecting to an MCP server. This allows you to control whether it uses Server-Sent Events (SSE) or HTTP transport, and in what order it tries them. diff --git a/src/client.ts b/src/client.ts index 961cb22..2df4ebb 100644 --- a/src/client.ts +++ b/src/client.ts @@ -13,9 +13,18 @@ import { EventEmitter } from 'events' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' -import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, connectToRemoteServer, TransportStrategy } from './lib/utils' +import { + parseCommandLineArgs, + setupSignalHandlers, + log, + MCP_REMOTE_VERSION, + connectToRemoteServer, + TransportStrategy, + AutoRefreshOptions, +} from './lib/utils' import { StaticOAuthClientInformationFull, StaticOAuthClientMetadata } from './lib/types' import { createLazyAuthCoordinator } from './lib/coordination' +import { TokenRefreshManager } from './lib/token-refresh-manager' /** * Main function to run the client @@ -30,6 +39,7 @@ async function runClient( staticOAuthClientInfo: StaticOAuthClientInformationFull, authTimeoutMs: number, serverUrlHash: string, + autoRefresh: AutoRefreshOptions, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -48,6 +58,14 @@ async function runClient( serverUrlHash, }) + const refreshManager = new TokenRefreshManager({ + enabled: autoRefresh.enabled, + intervalMs: autoRefresh.intervalMs, + leadTimeMs: autoRefresh.leadTimeMs, + failureBackoffMs: autoRefresh.backoffMs, + }) + refreshManager.start() + // Create the client const client = new Client( { @@ -103,6 +121,7 @@ async function runClient( // Set up cleanup handler const cleanup = async () => { + refreshManager.stop() log('\nClosing connection...') await client.close() // If auth was initialized and server was created, close it @@ -134,6 +153,7 @@ async function runClient( // log('Listening for messages. Press Ctrl+C to exit.') log('Exiting OK...') + refreshManager.stop() // Only close the server if it was initialized if (server) { server.close() @@ -141,6 +161,7 @@ async function runClient( process.exit(0) } catch (error) { log('Fatal error:', error) + refreshManager.stop() // Only close the server if it was initialized if (server) { server.close() @@ -162,6 +183,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx client.ts { return runClient( serverUrl, @@ -173,6 +195,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx client.ts { await deleteConfigFile(serverUrlHash, 'lock.json') } +/** + * Saves persistent information about a registered server + * @param serverUrlHash The hash identifying the server configuration + * @param registration The registration metadata to store + */ +export async function saveServerRegistration(serverUrlHash: string, registration: ServerRegistration): Promise { + await writeJsonFile(serverUrlHash, 'server.json', registration) +} + +/** + * Reads server registration data if available + * @param serverUrlHash The hash identifying the server configuration + * @returns The stored registration metadata, if present + */ +export async function readServerRegistration(serverUrlHash: string): Promise { + return await readJsonFile(serverUrlHash, 'server.json', serverRegistrationSchema) +} + +/** + * Lists server hashes that currently have token files on disk + * @returns An array of server hashes with stored tokens + */ +export async function listServerHashesWithTokens(): Promise { + try { + const configDir = getConfigDir() + const entries = await fs.readdir(configDir) + + return entries + .filter((filename) => filename.endsWith('_tokens.json')) + .map((filename) => filename.replace(/_tokens\.json$/, '')) + } catch (error) { + if ((error as NodeJS.ErrnoException).code === 'ENOENT') { + return [] + } + log('Error listing stored tokens:', error) + return [] + } +} + /** * Gets the configuration directory path * @returns The path to the configuration directory @@ -171,6 +301,31 @@ export async function writeJsonFile(serverUrlHash: string, filename: string, dat } } +/** + * Writes token metadata for a server (e.g., issued/expiry timestamps) + * @param serverUrlHash The hash identifying the server configuration + * @param state The token state fields to persist/merge + */ +export async function writeTokenState(serverUrlHash: string, state: Partial): Promise { + const current = await readTokenState(serverUrlHash) + const nextState: TokenState = { + issuedAt: state.issuedAt ?? current?.issuedAt ?? Date.now(), + expiresAt: state.expiresAt ?? current?.expiresAt, + lastRefreshAttempt: state.lastRefreshAttempt ?? current?.lastRefreshAttempt, + lastRefreshError: state.lastRefreshError ?? current?.lastRefreshError, + } + await writeJsonFile(serverUrlHash, 'token_state.json', nextState) +} + +/** + * Reads token metadata for a server + * @param serverUrlHash The hash identifying the server configuration + * @returns The stored token state, if available + */ +export async function readTokenState(serverUrlHash: string): Promise { + return await readJsonFile(serverUrlHash, 'token_state.json', tokenStateSchema) +} + /** * Reads a text file * @param serverUrlHash The hash of the server URL @@ -204,3 +359,33 @@ export async function writeTextFile(serverUrlHash: string, filename: string, tex throw error } } + +/** + * Attempts to acquire a refresh lock for a server. Returns true if acquired. + * @param serverUrlHash The hash identifying the server configuration + * @param ttlMs The duration in milliseconds before the lock expires automatically + */ +export async function tryAcquireRefreshLock(serverUrlHash: string, ttlMs: number): Promise { + const existingLock = await readJsonFile(serverUrlHash, 'refresh_lock.json', refreshLockSchema) + const now = Date.now() + + if (existingLock && existingLock.expiresAt > now && existingLock.pid !== process.pid) { + return false + } + + const newLock: RefreshLockData = { + pid: process.pid, + expiresAt: now + ttlMs, + } + + await writeJsonFile(serverUrlHash, 'refresh_lock.json', newLock) + return true +} + +/** + * Releases the refresh lock for a server + * @param serverUrlHash The hash identifying the server configuration + */ +export async function releaseRefreshLock(serverUrlHash: string): Promise { + await deleteConfigFile(serverUrlHash, 'refresh_lock.json') +} diff --git a/src/lib/node-oauth-client-provider.ts b/src/lib/node-oauth-client-provider.ts index 26a0c02..066b1c5 100644 --- a/src/lib/node-oauth-client-provider.ts +++ b/src/lib/node-oauth-client-provider.ts @@ -7,7 +7,7 @@ import { OAuthTokensSchema, } from '@modelcontextprotocol/sdk/shared/auth.js' import type { OAuthProviderOptions, StaticOAuthClientMetadata } from './types' -import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, deleteConfigFile } from './mcp-auth-config' +import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, deleteConfigFile, writeTokenState } from './mcp-auth-config' import { StaticOAuthClientInformationFull } from './types' import { log, debugLog, MCP_REMOTE_VERSION } from './utils' import { sanitizeUrl } from 'strict-url-sanitise' @@ -28,6 +28,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { private staticOAuthClientInfo: StaticOAuthClientInformationFull private authorizeResource: string | undefined private _state: string + addClientAuthentication?: OAuthClientProvider['addClientAuthentication'] /** * Creates a new NodeOAuthClientProvider @@ -157,6 +158,13 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { }) await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens) + + const issuedAt = Date.now() + const expiresAt = typeof tokens.expires_in === 'number' ? issuedAt + tokens.expires_in * 1000 : undefined + await writeTokenState(this.serverUrlHash, { + issuedAt, + expiresAt, + }) } /** @@ -213,6 +221,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { await Promise.all([ deleteConfigFile(this.serverUrlHash, 'client_info.json'), deleteConfigFile(this.serverUrlHash, 'tokens.json'), + deleteConfigFile(this.serverUrlHash, 'token_state.json'), deleteConfigFile(this.serverUrlHash, 'code_verifier.txt'), ]) debugLog('All credentials invalidated') @@ -224,7 +233,10 @@ export class NodeOAuthClientProvider implements OAuthClientProvider { break case 'tokens': - await deleteConfigFile(this.serverUrlHash, 'tokens.json') + await Promise.all([ + deleteConfigFile(this.serverUrlHash, 'tokens.json'), + deleteConfigFile(this.serverUrlHash, 'token_state.json'), + ]) debugLog('OAuth tokens invalidated') break diff --git a/src/lib/token-refresh-manager.test.ts b/src/lib/token-refresh-manager.test.ts new file mode 100644 index 0000000..c857ddb --- /dev/null +++ b/src/lib/token-refresh-manager.test.ts @@ -0,0 +1,226 @@ +import { afterAll, beforeEach, describe, expect, it, vi } from 'vitest' + +const mockTokenStore = new Map() +const mockTokenStates = new Map() +const mockRegistrations = new Map() +const mockLocks = new Map() +let currentTime = 1_700_000_000_000 + +vi.mock('./mcp-auth-config', () => { + return { + listServerHashesWithTokens: vi.fn(async () => Array.from(mockTokenStore.keys())), + readServerRegistration: vi.fn(async (hash: string) => mockRegistrations.get(hash)), + saveServerRegistration: vi.fn(async (hash: string, registration: any) => { + mockRegistrations.set(hash, registration) + }), + readTokenState: vi.fn(async (hash: string) => mockTokenStates.get(hash)), + writeTokenState: vi.fn(async (hash: string, state: any) => { + const existing = mockTokenStates.get(hash) ?? {} + mockTokenStates.set(hash, { ...existing, ...state }) + }), + tryAcquireRefreshLock: vi.fn(async (hash: string, ttlMs: number) => { + const now = Date.now() + const lockUntil = mockLocks.get(hash) + if (lockUntil && lockUntil > now) { + return false + } + mockLocks.set(hash, now + ttlMs) + return true + }), + releaseRefreshLock: vi.fn(async (hash: string) => { + mockLocks.delete(hash) + }), + } +}) + +vi.mock('./node-oauth-client-provider', () => { + class MockNodeOAuthClientProvider { + private readonly serverUrlHash: string + + constructor(options: any) { + this.serverUrlHash = options.serverUrlHash + } + + async clientInformation() { + return { + client_id: 'test-client', + redirect_uris: ['http://127.0.0.1:3335/oauth/callback'], + token_endpoint_auth_method: 'none', + } + } + + async tokens() { + return mockTokenStore.get(this.serverUrlHash) + } + + async saveTokens(tokens: any) { + mockTokenStore.set(this.serverUrlHash, tokens) + } + + get addClientAuthentication() { + return undefined + } + } + + return { NodeOAuthClientProvider: MockNodeOAuthClientProvider } +}) + +const clientMocks = vi.hoisted(() => { + const refreshAuthorization = vi.fn(async (_url: string | URL, { refreshToken }: { refreshToken: string }) => ({ + access_token: `new-token-for-${refreshToken}`, + refresh_token: refreshToken, + token_type: 'Bearer', + expires_in: 3600, + })) + const discoverAuthorizationServerMetadata = vi.fn(async () => ({ token_endpoint: 'https://auth.example/token' })) + const discoverOAuthProtectedResourceMetadata = vi.fn(async () => ({ + authorization_servers: ['https://auth.example'], + resource: 'https://resource.example', + })) + const selectResourceURL = vi.fn(async () => new URL('https://resource.example')) + + return { + refreshAuthorization, + discoverAuthorizationServerMetadata, + discoverOAuthProtectedResourceMetadata, + selectResourceURL, + } +}) + +vi.mock('@modelcontextprotocol/sdk/client/auth.js', () => clientMocks) +const { + refreshAuthorization, + discoverAuthorizationServerMetadata, + discoverOAuthProtectedResourceMetadata, + selectResourceURL, +} = clientMocks + +const errorMocks = vi.hoisted(() => { + class MockOAuthError extends Error { + constructor(message?: string, public errorCode?: string) { + super(message) + this.name = 'OAuthError' + } + } + + return { MockOAuthError } +}) + +vi.mock('@modelcontextprotocol/sdk/server/auth/errors.js', () => ({ + OAuthError: errorMocks.MockOAuthError, +})) + +vi.mock('./utils', async () => { + const originalModule = await vi.importActual('./utils') + return originalModule +}) + +import { TokenRefreshManager, isTokenExpiringSoon } from './token-refresh-manager' + +const dateNowSpy = vi.spyOn(Date, 'now').mockImplementation(() => currentTime) + +function seedServer(serverHash: string, { token, state }: { token: any; state: any }) { + mockRegistrations.set(serverHash, { + serverUrl: 'https://remote.example/sse', + host: 'localhost', + }) + mockTokenStore.set(serverHash, token) + mockTokenStates.set(serverHash, state) +} + +describe('Feature: Token Expiration Helper', () => { + it('Scenario: Returns false when no token state exists', () => { + expect(isTokenExpiringSoon(undefined, 600_000, currentTime)).toBe(false) + }) + + it('Scenario: Returns true when token already expired', () => { + const state = { issuedAt: currentTime - 10_000, expiresAt: currentTime - 1000 } + expect(isTokenExpiringSoon(state, 600_000, currentTime)).toBe(true) + }) + + it('Scenario: Returns true when token expires within lead window', () => { + const state = { issuedAt: currentTime - 1000, expiresAt: currentTime + 30_000 } + expect(isTokenExpiringSoon(state, 60_000, currentTime)).toBe(true) + }) + + it('Scenario: Returns false when token expires beyond lead window', () => { + const state = { issuedAt: currentTime - 1000, expiresAt: currentTime + 120_000 } + expect(isTokenExpiringSoon(state, 60_000, currentTime)).toBe(false) + }) +}) + +describe('Feature: Token Refresh Manager', () => { + beforeEach(() => { + mockTokenStore.clear() + mockTokenStates.clear() + mockRegistrations.clear() + mockLocks.clear() + refreshAuthorization.mockClear() + discoverAuthorizationServerMetadata.mockClear() + discoverOAuthProtectedResourceMetadata.mockClear() + selectResourceURL.mockClear() + currentTime = 1_700_000_000_000 + }) + + afterAll(() => { + dateNowSpy.mockRestore() + }) + + it('Scenario: Refreshes tokens when they are expiring soon', async () => { + const serverHash = 'hash-success' + seedServer(serverHash, { + token: { access_token: 'old', refresh_token: 'refresh-1', token_type: 'Bearer', expires_in: 30 }, + state: { issuedAt: currentTime - 1000, expiresAt: currentTime + 30_000 }, + }) + + const manager = new TokenRefreshManager({ enabled: true, leadTimeMs: 60_000 }) + await (manager as any).refreshIfNeeded(serverHash) + + expect(refreshAuthorization).toHaveBeenCalledTimes(1) + expect(mockTokenStore.get(serverHash)?.access_token).toBe('new-token-for-refresh-1') + expect(mockTokenStates.get(serverHash)?.lastRefreshAttempt).toBe(currentTime) + expect(mockTokenStates.get(serverHash)?.lastRefreshError).toBeUndefined() + expect(mockLocks.size).toBe(0) + }) + + it('Scenario: Skips refresh when another process holds the lock', async () => { + const serverHash = 'hash-lock' + seedServer(serverHash, { + token: { access_token: 'old', refresh_token: 'refresh-lock', token_type: 'Bearer', expires_in: 30 }, + state: { issuedAt: currentTime - 1000, expiresAt: currentTime + 10_000 }, + }) + mockLocks.set(serverHash, currentTime + 60_000) + + const manager = new TokenRefreshManager({ enabled: true, leadTimeMs: 60_000 }) + await (manager as any).refreshIfNeeded(serverHash) + + expect(refreshAuthorization).not.toHaveBeenCalled() + expect(mockLocks.get(serverHash)).toBe(currentTime + 60_000) + }) + + it('Scenario: Backs off and retries after a failed refresh attempt', async () => { + const serverHash = 'hash-backoff' + seedServer(serverHash, { + token: { access_token: 'old', refresh_token: 'refresh-backoff', token_type: 'Bearer', expires_in: 30 }, + state: { issuedAt: currentTime - 1000, expiresAt: currentTime + 10_000 }, + }) + + const manager = new TokenRefreshManager({ enabled: true, leadTimeMs: 60_000, failureBackoffMs: 60_000 }) + + refreshAuthorization.mockRejectedValueOnce(new Error('refresh failed')) + + await (manager as any).refreshIfNeeded(serverHash) + + expect(mockTokenStates.get(serverHash)?.lastRefreshError).toBe('refresh failed') + expect(refreshAuthorization).toHaveBeenCalledTimes(1) + + await (manager as any).refreshIfNeeded(serverHash) + expect(refreshAuthorization).toHaveBeenCalledTimes(1) + + currentTime += 120_000 + await (manager as any).refreshIfNeeded(serverHash) + expect(refreshAuthorization).toHaveBeenCalledTimes(2) + expect(mockTokenStore.get(serverHash)?.access_token).toBe('new-token-for-refresh-backoff') + expect(mockTokenStates.get(serverHash)?.lastRefreshError).toBeUndefined() + }) +}) diff --git a/src/lib/token-refresh-manager.ts b/src/lib/token-refresh-manager.ts new file mode 100644 index 0000000..6515d69 --- /dev/null +++ b/src/lib/token-refresh-manager.ts @@ -0,0 +1,351 @@ +import { OAuthError } from '@modelcontextprotocol/sdk/server/auth/errors.js' +import { + discoverAuthorizationServerMetadata, + discoverOAuthProtectedResourceMetadata, + refreshAuthorization, + selectResourceURL, +} from '@modelcontextprotocol/sdk/client/auth.js' +import { NodeOAuthClientProvider } from './node-oauth-client-provider' +import { + listServerHashesWithTokens, + readServerRegistration, + readTokenState, + releaseRefreshLock, + ServerRegistration, + TokenState, + tryAcquireRefreshLock, + writeTokenState, +} from './mcp-auth-config' +import { debugLog, log } from './utils' + +/** + * Periodically scans persisted OAuth sessions and proactively refreshes their access tokens + * using the stored refresh tokens. This keeps both the CLI and proxy processes authenticated + * without forcing users back through the interactive browser flow when tokens expire. + */ + +const DEFAULT_INTERVAL_MS = 60_000 +const DEFAULT_LEAD_MS = 10 * 60_000 +const DEFAULT_LOCK_TTL_MS = 2 * 60_000 +const DEFAULT_FAILURE_BACKOFF_MS = 5 * 60_000 + +export interface TokenRefreshManagerOptions { + enabled?: boolean + intervalMs?: number + leadTimeMs?: number + lockTtlMs?: number + failureBackoffMs?: number +} + +export class TokenRefreshManager { + private timer: NodeJS.Timeout | null = null + private stopped = true + private readonly intervalMs: number + private readonly leadTimeMs: number + private readonly lockTtlMs: number + private readonly failureBackoffMs: number + private readonly enabled: boolean + private failureBackoff = new Map() + + constructor(options: TokenRefreshManagerOptions = {}) { + this.intervalMs = options.intervalMs ?? DEFAULT_INTERVAL_MS + this.leadTimeMs = options.leadTimeMs ?? DEFAULT_LEAD_MS + this.lockTtlMs = options.lockTtlMs ?? DEFAULT_LOCK_TTL_MS + this.failureBackoffMs = options.failureBackoffMs ?? DEFAULT_FAILURE_BACKOFF_MS + this.enabled = options.enabled ?? false + } + + /** + * Begins running the background scan loop if auto-refresh is enabled. + */ + start() { + if (!this.enabled) { + debugLog('Token refresh manager disabled') + return + } + if (!this.stopped) { + return + } + this.stopped = false + this.scheduleNextScan(0) + } + + /** + * Stops the background scan loop and clears any pending timers. + */ + stop() { + if (this.timer) { + clearTimeout(this.timer) + this.timer = null + } + this.stopped = true + } + + /** + * Queues the next scan after the provided delay. No-op when the manager has been stopped. + */ + private scheduleNextScan(delayMs: number) { + if (this.stopped) { + return + } + this.timer = setTimeout(() => { + this.runScan() + .catch((error) => { + log('Token refresh manager error:', error) + }) + .finally(() => { + this.scheduleNextScan(this.intervalMs) + }) + }, delayMs) + } + + /** + * Iterates over every server hash that currently has saved tokens and attempts a refresh + * when the tokens are nearing expiration. + */ + private async runScan() { + const serverHashes = await listServerHashesWithTokens() + if (serverHashes.length === 0) { + debugLog('Token refresh manager: no servers with stored tokens') + return + } + + for (const serverUrlHash of serverHashes) { + try { + await this.refreshIfNeeded(serverUrlHash) + } catch (error) { + log(`Token refresh failed for server hash ${serverUrlHash}: ${this.formatError(error)}`) + debugLog('Token refresh failure details', { + serverUrlHash, + stack: (error as Error).stack, + }) + } + } + } + + /** + * Checks a single server entry and refreshes its tokens when it is close to expiration, + * respecting inter-process locks and failure backoff windows. + */ + private async refreshIfNeeded(serverUrlHash: string) { + const now = Date.now() + debugLog('Refresh scan evaluating server', { serverUrlHash, isoNow: new Date(now).toISOString() }) + const failureUntil = this.failureBackoff.get(serverUrlHash) + if (failureUntil && failureUntil > now) { + debugLog('Skipping refresh due to backoff', { + serverUrlHash, + isoNow: new Date(now).toISOString(), + nextEligibleAt: new Date(failureUntil).toISOString(), + millisUntilRetry: failureUntil - now, + }) + return + } + + const registration = await readServerRegistration(serverUrlHash) + if (!registration) { + debugLog('Skipping refresh - server registration missing', { serverUrlHash }) + return + } + + const provider = this.createProvider(serverUrlHash, registration) + const tokens = await provider.tokens() + + if (!tokens) { + debugLog('No tokens available for refresh', { serverUrlHash }) + return + } + + if (!tokens.refresh_token) { + debugLog('Stored tokens do not include a refresh_token', { serverUrlHash }) + return + } + + const state = await readTokenState(serverUrlHash) + if (!isTokenExpiringSoon(state, this.leadTimeMs)) { + if (state?.expiresAt) { + debugLog('Token not yet within refresh window', { + serverUrlHash, + ...formatTimingDebug(state.expiresAt, this.leadTimeMs, now), + }) + } else { + debugLog('Token state unavailable or missing expiry, skipping refresh window check', { + serverUrlHash, + isoNow: new Date(now).toISOString(), + }) + } + return + } + + if (state?.expiresAt) { + debugLog('Token requires refresh', { + serverUrlHash, + ...formatTimingDebug(state.expiresAt, this.leadTimeMs, now), + }) + } else { + debugLog('Token marked for refresh despite missing expiry metadata', { + serverUrlHash, + isoNow: new Date(now).toISOString(), + }) + } + + const acquired = await tryAcquireRefreshLock(serverUrlHash, this.lockTtlMs) + if (!acquired) { + debugLog('Skipped refresh because another process holds the lock', { serverUrlHash }) + return + } + + if (state?.expiresAt) { + log(`Refreshing OAuth tokens for ${registration.serverUrl}`, formatTimingDebug(state.expiresAt, this.leadTimeMs, now)) + } else { + log(`Refreshing OAuth tokens for ${registration.serverUrl}`, { isoNow: new Date(now).toISOString() }) + } + + try { + await this.performRefresh(serverUrlHash, registration, provider, tokens.refresh_token) + this.failureBackoff.delete(serverUrlHash) + log(`Refreshed OAuth tokens for ${registration.serverUrl}`) + } catch (error) { + this.failureBackoff.set(serverUrlHash, Date.now() + this.failureBackoffMs) + const message = this.formatError(error) + log(`Failed to refresh OAuth tokens for ${registration.serverUrl}: ${message}`) + await writeTokenState(serverUrlHash, { + lastRefreshAttempt: Date.now(), + lastRefreshError: message, + }) + debugLog('Token refresh attempt failed', { + serverUrlHash, + error: message, + stack: (error as Error).stack, + }) + return + } finally { + await releaseRefreshLock(serverUrlHash) + } + } + + /** + * Constructs a minimal OAuth client provider used solely for the refresh exchange. + */ + private createProvider(serverUrlHash: string, registration: ServerRegistration) { + return new NodeOAuthClientProvider({ + serverUrl: registration.serverUrl, + callbackPort: registration.callbackPort ?? 0, + host: registration.host ?? 'localhost', + authorizeResource: registration.authorizeResource, + staticOAuthClientMetadata: registration.staticOAuthClientMetadata, + staticOAuthClientInfo: registration.staticOAuthClientInfo, + serverUrlHash, + clientName: 'MCP CLI Auto Refresh', + }) + } + + /** + * Executes the OAuth refresh token grant and persists any returned credentials/metadata. + */ + private async performRefresh( + serverUrlHash: string, + registration: ServerRegistration, + provider: NodeOAuthClientProvider, + refreshToken: string, + ) { + const clientInformation = await provider.clientInformation() + if (!clientInformation) { + throw new Error('Missing OAuth client registration information') + } + + const { authorizationServerUrl, metadata, resource } = await this.resolveAuthorizationContext(registration.serverUrl, provider) + + debugLog('Attempting token refresh', { + serverUrlHash, + authorizationServerUrl: authorizationServerUrl.toString(), + resource: resource?.toString(), + }) + + const newTokens = await refreshAuthorization(authorizationServerUrl, { + metadata, + clientInformation, + refreshToken, + resource, + addClientAuthentication: provider.addClientAuthentication, + }) + + await provider.saveTokens(newTokens) + await writeTokenState(serverUrlHash, { + lastRefreshAttempt: Date.now(), + lastRefreshError: undefined, + }) + } + + /** + * Discovers the relevant authorization server metadata and resource indicators to reuse + * during refresh exchanges. + */ + private async resolveAuthorizationContext(serverUrl: string, provider: NodeOAuthClientProvider) { + let resourceMetadata: Awaited> | undefined + let authorizationServerUrl: string | URL | undefined + + try { + resourceMetadata = await discoverOAuthProtectedResourceMetadata(serverUrl) + if (resourceMetadata?.authorization_servers?.length) { + authorizationServerUrl = resourceMetadata.authorization_servers[0] + } + } catch (error) { + debugLog('Failed to load protected resource metadata', { + serverUrl, + error: (error as Error).message, + }) + } + + if (!authorizationServerUrl) { + authorizationServerUrl = serverUrl + } + + const metadata = await discoverAuthorizationServerMetadata(authorizationServerUrl, {}) + const resource = await selectResourceURL(serverUrl, provider, resourceMetadata) + + return { authorizationServerUrl, metadata, resource } + } + + /** + * Produces a concise error string for logging/backoff bookkeeping. + */ + private formatError(error: unknown): string { + if (error instanceof OAuthError) { + const code = (error as any).errorCode ? ` (${(error as any).errorCode})` : '' + return `${error.name}${code}: ${error.message}` + } + if (error instanceof Error) { + return error.message + } + return String(error) + } +} + +function formatTimingDebug(expiresAt: number, leadTimeMs: number, now: number) { + const refreshThreshold = expiresAt - leadTimeMs + const millisUntilExpiry = expiresAt - now + const millisUntilRefreshWindow = refreshThreshold - now + + return { + isoNow: new Date(now).toISOString(), + isoExpiresAt: new Date(expiresAt).toISOString(), + isoRefreshThreshold: new Date(refreshThreshold).toISOString(), + millisUntilExpiry, + millisUntilRefreshWindow, + secondsUntilExpiry: Math.max(0, Math.round(millisUntilExpiry / 1000)), + secondsUntilRefreshWindow: Math.max(0, Math.round(millisUntilRefreshWindow / 1000)), + } +} + +/** + * Helper that determines whether a token is expired or will expire within the provided lead time. + */ +export function isTokenExpiringSoon(state: TokenState | undefined, leadTimeMs: number, now: number = Date.now()): boolean { + if (!state || typeof state.expiresAt !== 'number') { + return false + } + if (state.expiresAt <= now) { + return true + } + return state.expiresAt - now <= leadTimeMs +} diff --git a/src/lib/utils.ts b/src/lib/utils.ts index e48fb0c..6c2a6ff 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -6,7 +6,7 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' import { OAuthError } from '@modelcontextprotocol/sdk/server/auth/errors.js' import { OAuthClientInformationFull, OAuthClientInformationFullSchema } from '@modelcontextprotocol/sdk/shared/auth.js' import { OAuthCallbackServerOptions, StaticOAuthClientInformationFull, StaticOAuthClientMetadata } from './types' -import { getConfigDir, getConfigFilePath, readJsonFile } from './mcp-auth-config' +import { getConfigDir, getConfigFilePath, readJsonFile, saveServerRegistration } from './mcp-auth-config' import express from 'express' import net from 'net' import crypto from 'crypto' @@ -595,13 +595,40 @@ export async function findAvailablePort(preferredPort?: number): Promise }) } +/** + * Normalized token refresh configuration returned from CLI parsing. We keep this separate from + * `parseCommandLineArgs`' internal state so the caller (client/proxy entrypoint) can pass the values + * directly into `TokenRefreshManager` without knowing the individual CLI flag names. + */ +export interface AutoRefreshOptions { + enabled: boolean + intervalMs: number + leadTimeMs: number + backoffMs: number +} + +/** + * Optional knobs that control how `parseCommandLineArgs` behaves depending on the binary that calls it. + * We expose this as a typed interface instead of hard-coding defaults inside the function so both + * `client.ts` and `proxy.ts` can choose different default behaviors (e.g., opt-out vs. opt-in refresh) + * while still sharing the same parser implementation. + */ +export interface ParseCommandLineArgsOptions { + + /** + * Whether background token refresh should be enabled by default. + * Defaults to true. + */ + defaultAutoRefreshEnabled?: boolean +} + /** * Parses command line arguments for MCP clients and proxies * @param args Command line arguments * @param usage Usage message to show on error * @returns A promise that resolves to an object with parsed serverUrl, callbackPort and headers */ -export async function parseCommandLineArgs(args: string[], usage: string) { +export async function parseCommandLineArgs(args: string[], usage: string, defaults: ParseCommandLineArgsOptions = {}) { // Process headers const headers: Record = {} let i = 0 @@ -625,6 +652,14 @@ export async function parseCommandLineArgs(args: string[], usage: string) { const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') + let enableAutoRefresh = defaults.defaultAutoRefreshEnabled ?? true + if (args.includes('--enable-auto-refresh')) { + enableAutoRefresh = true + } + if (args.includes('--disable-auto-refresh')) { + enableAutoRefresh = false + } + // Check for debug flag const debug = args.includes('--debug') if (debug) { @@ -726,6 +761,42 @@ export async function parseCommandLineArgs(args: string[], usage: string) { } } + let refreshLeadMs = 10 * 60 * 1000 + const refreshLeadIndex = args.indexOf('--refresh-lead') + if (refreshLeadIndex !== -1 && refreshLeadIndex < args.length - 1) { + const leadSeconds = parseInt(args[refreshLeadIndex + 1], 10) + if (!isNaN(leadSeconds) && leadSeconds > 0) { + refreshLeadMs = leadSeconds * 1000 + log(`Using token refresh lead time: ${leadSeconds} seconds`) + } else { + log(`Warning: Ignoring invalid refresh lead value: ${args[refreshLeadIndex + 1]}. Must be a positive number.`) + } + } + + let refreshIntervalMs = 60 * 1000 + const refreshIntervalIndex = args.indexOf('--refresh-interval') + if (refreshIntervalIndex !== -1 && refreshIntervalIndex < args.length - 1) { + const intervalSeconds = parseInt(args[refreshIntervalIndex + 1], 10) + if (!isNaN(intervalSeconds) && intervalSeconds > 0) { + refreshIntervalMs = intervalSeconds * 1000 + log(`Using token refresh interval: ${intervalSeconds} seconds`) + } else { + log(`Warning: Ignoring invalid refresh interval value: ${args[refreshIntervalIndex + 1]}. Must be a positive number.`) + } + } + + let refreshBackoffMs = 5 * 60 * 1000 + const refreshBackoffIndex = args.indexOf('--refresh-backoff') + if (refreshBackoffIndex !== -1 && refreshBackoffIndex < args.length - 1) { + const backoffSeconds = parseInt(args[refreshBackoffIndex + 1], 10) + if (!isNaN(backoffSeconds) && backoffSeconds > 0) { + refreshBackoffMs = backoffSeconds * 1000 + log(`Using token refresh backoff: ${backoffSeconds} seconds`) + } else { + log(`Warning: Ignoring invalid refresh backoff value: ${args[refreshBackoffIndex + 1]}. Must be a positive number.`) + } + } + if (!serverUrl) { log(usage) process.exit(1) @@ -789,6 +860,19 @@ export async function parseCommandLineArgs(args: string[], usage: string) { }) } + try { + await saveServerRegistration(serverUrlHash, { + serverUrl, + host, + callbackPort, + authorizeResource: authorizeResource || undefined, + staticOAuthClientMetadata, + staticOAuthClientInfo, + }) + } catch (error) { + log('Warning: Unable to persist server registration data:', error) + } + return { serverUrl, callbackPort, @@ -802,6 +886,12 @@ export async function parseCommandLineArgs(args: string[], usage: string) { ignoredTools, authTimeoutMs, serverUrlHash, + autoRefresh: { + enabled: enableAutoRefresh, + intervalMs: refreshIntervalMs, + leadTimeMs: refreshLeadMs, + backoffMs: refreshBackoffMs, + } satisfies AutoRefreshOptions, } } diff --git a/src/proxy.ts b/src/proxy.ts index 6a6ad5b..3722638 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -11,10 +11,19 @@ import { EventEmitter } from 'events' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' -import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, TransportStrategy } from './lib/utils' +import { + connectToRemoteServer, + log, + mcpProxy, + parseCommandLineArgs, + setupSignalHandlers, + TransportStrategy, + AutoRefreshOptions, +} from './lib/utils' import { StaticOAuthClientInformationFull, StaticOAuthClientMetadata } from './lib/types' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { createLazyAuthCoordinator } from './lib/coordination' +import { TokenRefreshManager } from './lib/token-refresh-manager' /** * Main function to run the proxy @@ -31,6 +40,7 @@ async function runProxy( ignoredTools: string[], authTimeoutMs: number, serverUrlHash: string, + autoRefresh: AutoRefreshOptions, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -50,6 +60,14 @@ async function runProxy( serverUrlHash, }) + const refreshManager = new TokenRefreshManager({ + enabled: autoRefresh.enabled, + intervalMs: autoRefresh.intervalMs, + leadTimeMs: autoRefresh.leadTimeMs, + failureBackoffMs: autoRefresh.backoffMs, + }) + refreshManager.start() + // Create the STDIO transport for local connections const localTransport = new StdioServerTransport() @@ -96,6 +114,7 @@ async function runProxy( // Setup cleanup handler const cleanup = async () => { + refreshManager.stop() await remoteTransport.close() await localTransport.close() // Only close the server if it was initialized @@ -106,6 +125,7 @@ async function runProxy( setupSignalHandlers(cleanup) } catch (error) { log('Fatal error:', error) + refreshManager.stop() if (error instanceof Error && error.message.includes('self-signed certificate in certificate chain')) { log(`You may be behind a VPN! @@ -152,6 +172,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts { return runProxy( serverUrl, @@ -165,6 +186,7 @@ parseCommandLineArgs(process.argv.slice(2), 'Usage: npx tsx proxy.ts