feat: short timeout has now been added for clients like cursor

This commit is contained in:
stevehuynh 2025-05-14 20:40:43 +10:00
parent bd75a1cdf0
commit c1fe647a48
2 changed files with 56 additions and 7 deletions

View file

@ -3,10 +3,14 @@ import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import fs from 'fs'
import path from 'path'
import os from 'os'
// Connection constants // Connection constants
export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_AUTH_NEEDED = 'authentication-needed'
export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport' export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
export const SHORT_TIMEOUT_DURATION = 50000
// Transport strategy types // Transport strategy types
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first' export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
@ -24,6 +28,27 @@ export function log(str: string, ...rest: unknown[]) {
console.error(`[${pid}] ${str}`, ...rest) console.error(`[${pid}] ${str}`, ...rest)
} }
/**
* Clears MCP auth files after short timeout
* Used for clients like Cursor or Claude Desktop with short timeouts
*/
function clearMcpAuthFiles() {
try {
const mcpDir = path.join(os.homedir(), '.mcp-auth')
log('Short timeout reached, clearing MCP auth files')
// Check if the directory exists
if (fs.existsSync(mcpDir)) {
// Delete the entire directory and its contents
fs.rmSync(mcpDir, { recursive: true, force: true })
log('MCP auth directory cleared successfully')
} else {
log('No MCP directory found, nothing to clear')
}
} catch (error) {
log('Error clearing MCP auth files:', error)
}
}
/** /**
* Creates a bidirectional proxy between two transports * Creates a bidirectional proxy between two transports
* @param params The transport connections to proxy between * @param params The transport connections to proxy between
@ -97,6 +122,7 @@ export type AuthInitializer = () => Promise<{
* @param authInitializer Function to initialize authentication when needed * @param authInitializer Function to initialize authentication when needed
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first') * @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
* @param recursionReasons Set of reasons for recursive calls (internal use) * @param recursionReasons Set of reasons for recursive calls (internal use)
* @param shortTimeout Whether to use a short timeout (for clients like Cursor or Claude Desktop)
* @returns The connected transport * @returns The connected transport
*/ */
export async function connectToRemoteServer( export async function connectToRemoteServer(
@ -106,7 +132,8 @@ export async function connectToRemoteServer(
headers: Record<string, string>, headers: Record<string, string>,
authInitializer: AuthInitializer, authInitializer: AuthInitializer,
transportStrategy: TransportStrategy = 'http-first', transportStrategy: TransportStrategy = 'http-first',
recursionReasons: Set<string> = new Set(), shortTimeout: boolean = false,
recursionReasons: Set<string> = new Set()
): Promise<Transport> { ): Promise<Transport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl) const url = new URL(serverUrl)
@ -196,7 +223,8 @@ export async function connectToRemoteServer(
headers, headers,
authInitializer, authInitializer,
sseTransport ? 'http-only' : 'sse-only', sseTransport ? 'http-only' : 'sse-only',
recursionReasons, shortTimeout,
recursionReasons
) )
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Initializing auth...') log('Authentication required. Initializing auth...')
@ -204,6 +232,15 @@ export async function connectToRemoteServer(
// Initialize authentication on-demand // Initialize authentication on-demand
const { waitForAuthCode, skipBrowserAuth } = await authInitializer() const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
// Set up short timeout if enabled
let shortTimeoutTimer: NodeJS.Timeout | null = null
if (shortTimeout) {
log(`Short timeout enabled, will clear auth files after ${SHORT_TIMEOUT_DURATION / 1000} seconds`)
shortTimeoutTimer = setTimeout(() => {
clearMcpAuthFiles()
}, SHORT_TIMEOUT_DURATION)
}
if (skipBrowserAuth) { if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth') log('Authentication required but skipping browser auth - using shared auth')
} else { } else {
@ -214,6 +251,11 @@ export async function connectToRemoteServer(
const code = await waitForAuthCode() const code = await waitForAuthCode()
try { try {
// Clear the timeout if auth completes successfully
if (shortTimeoutTimer) {
clearTimeout(shortTimeoutTimer)
}
log('Completing authorization...') log('Completing authorization...')
await transport.finishAuth(code) await transport.finishAuth(code)
@ -228,8 +270,13 @@ export async function connectToRemoteServer(
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons) return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, shortTimeout)
} catch (authError) { } catch (authError) {
// Clear the timeout if auth fails
if (shortTimeoutTimer) {
clearTimeout(shortTimeoutTimer)
}
log('Authorization error:', authError) log('Authorization error:', authError)
throw authError throw authError
} }
@ -412,6 +459,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
const serverUrl = args[0] const serverUrl = args[0]
const specifiedPort = args[1] ? parseInt(args[1]) : undefined const specifiedPort = args[1] ? parseInt(args[1]) : undefined
const allowHttp = args.includes('--allow-http') const allowHttp = args.includes('--allow-http')
const shortTimeout = args.includes('--short-timeout')
// Parse transport strategy // Parse transport strategy
let transportStrategy: TransportStrategy = 'http-first' // Default let transportStrategy: TransportStrategy = 'http-first' // Default
@ -468,7 +516,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
}) })
} }
return { serverUrl, callbackPort, headers, transportStrategy } return { serverUrl, callbackPort, headers, transportStrategy, shortTimeout }
} }
/** /**

View file

@ -32,6 +32,7 @@ async function runProxy(
callbackPort: number, callbackPort: number,
headers: Record<string, string>, headers: Record<string, string>,
transportStrategy: TransportStrategy = 'http-first', transportStrategy: TransportStrategy = 'http-first',
shortTimeout: boolean = false,
) { ) {
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
@ -78,7 +79,7 @@ async function runProxy(
try { try {
// Connect to remote server with lazy authentication // Connect to remote server with lazy authentication
const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy) const remoteTransport = await connectToRemoteServer(null, serverUrl, authProvider, headers, authInitializer, transportStrategy, shortTimeout)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({
@ -136,8 +137,8 @@ to the CA certificate file. If using claude_desktop_config.json, this might look
// Parse command-line arguments and run the proxy // Parse command-line arguments and run the proxy
parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts <https://server-url> [callback-port]') parseCommandLineArgs(process.argv.slice(2), 3334, 'Usage: npx tsx proxy.ts <https://server-url> [callback-port]')
.then(({ serverUrl, callbackPort, headers, transportStrategy }) => { .then(({ serverUrl, callbackPort, headers, transportStrategy, shortTimeout }) => {
return runProxy(serverUrl, callbackPort, headers, transportStrategy) return runProxy(serverUrl, callbackPort, headers, transportStrategy, shortTimeout)
}) })
.catch((error) => { .catch((error) => {
log('Fatal error:', error) log('Fatal error:', error)