diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 572550c..b11c9b3 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -3,10 +3,14 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import fs from 'fs' +import path from 'path' +import os from 'os' // Connection constants export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport' +export const SHORT_TIMEOUT_DURATION = 50000 // Transport strategy types export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first' @@ -24,6 +28,27 @@ export function log(str: string, ...rest: unknown[]) { console.error(`[${pid}] ${str}`, ...rest) } +/** + * Clears MCP auth files after short timeout + * Used for clients like Cursor or Claude Desktop with short timeouts + */ +function clearMcpAuthFiles() { + try { + const mcpDir = path.join(os.homedir(), '.mcp-auth') + log('Short timeout reached, clearing MCP auth files') + // Check if the directory exists + if (fs.existsSync(mcpDir)) { + // Delete the entire directory and its contents + fs.rmSync(mcpDir, { recursive: true, force: true }) + log('MCP auth directory cleared successfully') + } else { + log('No MCP directory found, nothing to clear') + } + } catch (error) { + log('Error clearing MCP auth files:', error) + } +} + /** * Creates a bidirectional proxy between two transports * @param params The transport connections to proxy between @@ -97,6 +122,7 @@ export type AuthInitializer = () => Promise<{ * @param authInitializer Function to initialize authentication when needed * @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first') * @param recursionReasons Set of reasons for recursive calls (internal use) + * @param shortTimeout Whether to use a short timeout (for clients like Cursor or Claude Desktop) * @returns The connected transport */ export async function connectToRemoteServer( @@ -106,7 +132,8 @@ export async function connectToRemoteServer( headers: Record, authInitializer: AuthInitializer, transportStrategy: TransportStrategy = 'http-first', - recursionReasons: Set = new Set(), + shortTimeout: boolean = false, + recursionReasons: Set = new Set() ): Promise { log(`[${pid}] Connecting to remote server: ${serverUrl}`) const url = new URL(serverUrl) @@ -196,7 +223,8 @@ export async function connectToRemoteServer( headers, authInitializer, sseTransport ? 'http-only' : 'sse-only', - recursionReasons, + shortTimeout, + recursionReasons ) } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { log('Authentication required. Initializing auth...') @@ -204,6 +232,15 @@ export async function connectToRemoteServer( // Initialize authentication on-demand const { waitForAuthCode, skipBrowserAuth } = await authInitializer() + // Set up short timeout if enabled + let shortTimeoutTimer: NodeJS.Timeout | null = null + if (shortTimeout) { + log(`Short timeout enabled, will clear auth files after ${SHORT_TIMEOUT_DURATION / 1000} seconds`) + shortTimeoutTimer = setTimeout(() => { + clearMcpAuthFiles() + }, SHORT_TIMEOUT_DURATION) + } + if (skipBrowserAuth) { log('Authentication required but skipping browser auth - using shared auth') } else { @@ -214,6 +251,11 @@ export async function connectToRemoteServer( const code = await waitForAuthCode() try { + // Clear the timeout if auth completes successfully + if (shortTimeoutTimer) { + clearTimeout(shortTimeoutTimer) + } + log('Completing authorization...') await transport.finishAuth(code) @@ -228,8 +270,13 @@ export async function connectToRemoteServer( log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons) + return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, shortTimeout) } catch (authError) { + // Clear the timeout if auth fails + if (shortTimeoutTimer) { + clearTimeout(shortTimeoutTimer) + } + log('Authorization error:', authError) throw authError } @@ -412,6 +459,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') + const shortTimeout = args.includes('--short-timeout') // Parse transport strategy let transportStrategy: TransportStrategy = 'http-first' // Default @@ -468,7 +516,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, }) } - return { serverUrl, callbackPort, headers, transportStrategy } + return { serverUrl, callbackPort, headers, transportStrategy, shortTimeout } } /** diff --git a/src/proxy.ts b/src/proxy.ts index 7263a95..58e29f3 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -32,6 +32,7 @@ async function runProxy( callbackPort: number, headers: Record, transportStrategy: TransportStrategy = 'http-first', + shortTimeout: boolean = false, ) { // Set up event emitter for auth flow const events = new EventEmitter() @@ -78,7 +79,7 @@ async function runProxy( try { // Connect to remote server with lazy authentication - const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy) + const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy, shortTimeout) // Set up bidirectional proxy between local and remote transports mcpProxy({ @@ -136,8 +137,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look // Parse command-line arguments and run the proxy parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts [callback-port]') - .then(({ serverUrl, callbackPort, headers, transportStrategy }) => { - return runProxy(serverUrl, callbackPort, headers, transportStrategy) + .then(({ serverUrl, callbackPort, headers, transportStrategy, shortTimeout }) => { + return runProxy(serverUrl, callbackPort, headers, transportStrategy, shortTimeout) }) .catch((error) => { log('Fatal error:', error)