From 14109a309f24594928ee0c8ef8437ae499764ec2 Mon Sep 17 00:00:00 2001 From: Glen Maddern Date: Thu, 17 Apr 2025 16:04:50 +1000 Subject: [PATCH] Able to catch exceptions from one transport and fall back to the other --- package.json | 5 ++- pnpm-lock.yaml | 6 +-- src/client.ts | 18 +++++--- src/lib/utils.ts | 104 +++++++++++++++++++++++++++++------------------ src/proxy.ts | 22 +++++++++- 5 files changed, 104 insertions(+), 51 deletions(-) diff --git a/package.json b/package.json index 17700f0..ebf114d 100644 --- a/package.json +++ b/package.json @@ -28,11 +28,11 @@ "check": "prettier --check . && tsc" }, "dependencies": { + "@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk", "express": "^4.21.2", "open": "^10.1.0" }, "devDependencies": { - "@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk", "@types/express": "^5.0.0", "@types/node": "^22.13.10", "@types/react": "^19.0.12", @@ -54,7 +54,8 @@ "clean": true, "outDir": "dist", "external": [ - "react" + "react", + "@modelcontextprotocol/sdk" ] } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b1813cb..dbdb80d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -8,6 +8,9 @@ importers: .: dependencies: + '@modelcontextprotocol/sdk': + specifier: link:/Users/glen/src/hax/mcp/typescript-sdk + version: link:../../hax/mcp/typescript-sdk express: specifier: ^4.21.2 version: 4.21.2 @@ -15,9 +18,6 @@ importers: specifier: ^10.1.0 version: 10.1.0 devDependencies: - '@modelcontextprotocol/sdk': - specifier: link:/Users/glen/src/hax/mcp/typescript-sdk - version: link:../../hax/mcp/typescript-sdk '@types/express': specifier: ^5.0.0 version: 5.0.0 diff --git a/src/client.ts b/src/client.ts index ac58fb8..669fd3b 100644 --- a/src/client.ts +++ b/src/client.ts @@ -70,7 +70,15 @@ async function runClient( try { // Connect to remote server with authentication - const transport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy) + const transport = await connectToRemoteServer( + client, + serverUrl, + authProvider, + headers, + waitForAuthCode, + skipBrowserAuth, + transportStrategy, + ) // Set up message and error handlers transport.onmessage = (message) => { @@ -94,9 +102,6 @@ async function runClient( } setupSignalHandlers(cleanup) - // Connect the client - log('Connecting client...') - await client.connect(transport) log('Connected successfully!') try { @@ -117,7 +122,10 @@ async function runClient( log('Error requesting resources list:', e) } - log('Listening for messages. Press Ctrl+C to exit.') + // log('Listening for messages. Press Ctrl+C to exit.') + log('Exiting OK...') + server.close() + process.exit(0) } catch (error) { log('Fatal error:', error) server.close() diff --git a/src/lib/utils.ts b/src/lib/utils.ts index dde1b4c..43d2086 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -1,4 +1,5 @@ import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' +import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' @@ -74,6 +75,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo /** * Creates and connects to a remote server with OAuth authentication + * @param client The client to connect with * @param serverUrl The URL of the remote server * @param authProvider The OAuth client provider * @param headers Additional headers to send with the request @@ -84,6 +86,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo * @returns The connected transport */ export async function connectToRemoteServer( + client: Client, serverUrl: string, authProvider: OAuthClientProvider, headers: Record, @@ -112,54 +115,68 @@ export async function connectToRemoteServer( }, } - // Choose transport based on user strategy and recursion history - let transport; - let shouldAttemptFallback = false; - - // If we've already tried falling back once, throw an error - if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { - const errorMessage = `Already attempted transport fallback. Giving up.`; - log(errorMessage); - throw new Error(errorMessage); - } - - log(`Using transport strategy: ${transportStrategy}`); + log(`Using transport strategy: ${transportStrategy}`) // Determine if we should attempt to fallback on error - shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'; - + // Choose transport based on user strategy and recursion history + const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first' + // Create transport instance based on the strategy - if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') { - transport = new SSEClientTransport(url, { - authProvider, - requestInit: { headers }, - eventSourceInit, - }); - } else { // http-only or http-first - transport = new StreamableHTTPClientTransport(url, { - sessionId: crypto.randomUUID(), - }); - } + const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first' + const transport = sseTransport + ? new SSEClientTransport(url, { + authProvider, + requestInit: { headers }, + eventSourceInit, + }) + : new StreamableHTTPClientTransport(url, { + authProvider, + requestInit: { headers }, + }) try { - await transport.start() + await client.connect(transport) log(`Connected to remote server using ${transport.constructor.name}`) + + if (!sseTransport) { + console.log({ serverCapabilities: await client.getServerCapabilities() }) + } + return transport } catch (error) { + console.log('DID I CATCH OR WHAT?') // Check if it's a protocol error and we should attempt fallback - if (error instanceof Error && - shouldAttemptFallback && - (error.message.includes('405') || - error.message.includes('Method Not Allowed') || - error.message.toLowerCase().includes('protocol'))) { - + if ( + error instanceof Error && + shouldAttemptFallback && + (sseTransport + ? error.message.includes('405') || error.message.includes('Method Not Allowed') + : error.message.includes('404') || error.message.includes('Not Found')) + ) { log(`Received error: ${error.message}`) + + // If we've already tried falling back once, throw an error + if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { + const errorMessage = `Already attempted transport fallback. Giving up.` + log(errorMessage) + throw new Error(errorMessage) + } + log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) - + // Add to recursion reasons set recursionReasons.add(REASON_TRANSPORT_FALLBACK) - + // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) + return connectToRemoteServer( + client, + serverUrl, + authProvider, + headers, + waitForAuthCode, + skipBrowserAuth, + sseTransport ? 'http-only' : 'sse-only', + recursionReasons, + ) } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (skipBrowserAuth) { log('Authentication required but skipping browser auth - using shared auth') @@ -179,13 +196,22 @@ export async function connectToRemoteServer( log(errorMessage) throw new Error(errorMessage) } - + // Track this reason for recursion recursionReasons.add(REASON_AUTH_NEEDED) log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) - + // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) + return connectToRemoteServer( + client, + serverUrl, + authProvider, + headers, + waitForAuthCode, + skipBrowserAuth, + transportStrategy, + recursionReasons, + ) } catch (authError) { log('Authorization error:', authError) throw authError @@ -356,7 +382,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') - + // Parse transport strategy let transportStrategy: TransportStrategy = 'http-first' // Default const transportIndex = args.indexOf('--transport') diff --git a/src/proxy.ts b/src/proxy.ts index 9fd87d1..a914bb4 100644 --- a/src/proxy.ts +++ b/src/proxy.ts @@ -11,9 +11,18 @@ import { EventEmitter } from 'events' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' -import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils' +import { + connectToRemoteServer, + log, + mcpProxy, + parseCommandLineArgs, + setupSignalHandlers, + getServerUrlHash, + MCP_REMOTE_VERSION, +} from './lib/utils' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { coordinateAuth } from './lib/coordination' +import { Client } from '@modelcontextprotocol/sdk/client/index.js' /** * Main function to run the proxy @@ -47,8 +56,17 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record const localTransport = new StdioServerTransport() try { + const client = new Client( + { + name: 'mcp-remote', + version: MCP_REMOTE_VERSION, + }, + { + capabilities: {}, + }, + ) // Connect to remote server with authentication - const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth) + const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth) // Set up bidirectional proxy between local and remote transports mcpProxy({