Able to catch exceptions from one transport and fall back to the other

This commit is contained in:
Glen Maddern 2025-04-17 16:04:50 +10:00
parent 0bf84d5d22
commit 14109a309f
5 changed files with 104 additions and 51 deletions

View file

@ -28,11 +28,11 @@
"check": "prettier --check . && tsc" "check": "prettier --check . && tsc"
}, },
"dependencies": { "dependencies": {
"@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk",
"express": "^4.21.2", "express": "^4.21.2",
"open": "^10.1.0" "open": "^10.1.0"
}, },
"devDependencies": { "devDependencies": {
"@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk",
"@types/express": "^5.0.0", "@types/express": "^5.0.0",
"@types/node": "^22.13.10", "@types/node": "^22.13.10",
"@types/react": "^19.0.12", "@types/react": "^19.0.12",
@ -54,7 +54,8 @@
"clean": true, "clean": true,
"outDir": "dist", "outDir": "dist",
"external": [ "external": [
"react" "react",
"@modelcontextprotocol/sdk"
] ]
} }
} }

6
pnpm-lock.yaml generated
View file

@ -8,6 +8,9 @@ importers:
.: .:
dependencies: dependencies:
'@modelcontextprotocol/sdk':
specifier: link:/Users/glen/src/hax/mcp/typescript-sdk
version: link:../../hax/mcp/typescript-sdk
express: express:
specifier: ^4.21.2 specifier: ^4.21.2
version: 4.21.2 version: 4.21.2
@ -15,9 +18,6 @@ importers:
specifier: ^10.1.0 specifier: ^10.1.0
version: 10.1.0 version: 10.1.0
devDependencies: devDependencies:
'@modelcontextprotocol/sdk':
specifier: link:/Users/glen/src/hax/mcp/typescript-sdk
version: link:../../hax/mcp/typescript-sdk
'@types/express': '@types/express':
specifier: ^5.0.0 specifier: ^5.0.0
version: 5.0.0 version: 5.0.0

View file

@ -70,7 +70,15 @@ async function runClient(
try { try {
// Connect to remote server with authentication // Connect to remote server with authentication
const transport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy) const transport = await connectToRemoteServer(
client,
serverUrl,
authProvider,
headers,
waitForAuthCode,
skipBrowserAuth,
transportStrategy,
)
// Set up message and error handlers // Set up message and error handlers
transport.onmessage = (message) => { transport.onmessage = (message) => {
@ -94,9 +102,6 @@ async function runClient(
} }
setupSignalHandlers(cleanup) setupSignalHandlers(cleanup)
// Connect the client
log('Connecting client...')
await client.connect(transport)
log('Connected successfully!') log('Connected successfully!')
try { try {
@ -117,7 +122,10 @@ async function runClient(
log('Error requesting resources list:', e) log('Error requesting resources list:', e)
} }
log('Listening for messages. Press Ctrl+C to exit.') // log('Listening for messages. Press Ctrl+C to exit.')
log('Exiting OK...')
server.close()
process.exit(0)
} catch (error) { } catch (error) {
log('Fatal error:', error) log('Fatal error:', error)
server.close() server.close()

View file

@ -1,4 +1,5 @@
import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
@ -74,6 +75,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
/** /**
* Creates and connects to a remote server with OAuth authentication * Creates and connects to a remote server with OAuth authentication
* @param client The client to connect with
* @param serverUrl The URL of the remote server * @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider * @param authProvider The OAuth client provider
* @param headers Additional headers to send with the request * @param headers Additional headers to send with the request
@ -84,6 +86,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* @returns The connected transport * @returns The connected transport
*/ */
export async function connectToRemoteServer( export async function connectToRemoteServer(
client: Client,
serverUrl: string, serverUrl: string,
authProvider: OAuthClientProvider, authProvider: OAuthClientProvider,
headers: Record<string, string>, headers: Record<string, string>,
@ -112,54 +115,68 @@ export async function connectToRemoteServer(
}, },
} }
// Choose transport based on user strategy and recursion history log(`Using transport strategy: ${transportStrategy}`)
let transport;
let shouldAttemptFallback = false;
// If we've already tried falling back once, throw an error
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
const errorMessage = `Already attempted transport fallback. Giving up.`;
log(errorMessage);
throw new Error(errorMessage);
}
log(`Using transport strategy: ${transportStrategy}`);
// Determine if we should attempt to fallback on error // Determine if we should attempt to fallback on error
shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'; // Choose transport based on user strategy and recursion history
const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'
// Create transport instance based on the strategy // Create transport instance based on the strategy
if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') { const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first'
transport = new SSEClientTransport(url, { const transport = sseTransport
? new SSEClientTransport(url, {
authProvider, authProvider,
requestInit: { headers }, requestInit: { headers },
eventSourceInit, eventSourceInit,
}); })
} else { // http-only or http-first : new StreamableHTTPClientTransport(url, {
transport = new StreamableHTTPClientTransport(url, { authProvider,
sessionId: crypto.randomUUID(), requestInit: { headers },
}); })
}
try { try {
await transport.start() await client.connect(transport)
log(`Connected to remote server using ${transport.constructor.name}`) log(`Connected to remote server using ${transport.constructor.name}`)
if (!sseTransport) {
console.log({ serverCapabilities: await client.getServerCapabilities() })
}
return transport return transport
} catch (error) { } catch (error) {
console.log('DID I CATCH OR WHAT?')
// Check if it's a protocol error and we should attempt fallback // Check if it's a protocol error and we should attempt fallback
if (error instanceof Error && if (
error instanceof Error &&
shouldAttemptFallback && shouldAttemptFallback &&
(error.message.includes('405') || (sseTransport
error.message.includes('Method Not Allowed') || ? error.message.includes('405') || error.message.includes('Method Not Allowed')
error.message.toLowerCase().includes('protocol'))) { : error.message.includes('404') || error.message.includes('Not Found'))
) {
log(`Received error: ${error.message}`) log(`Received error: ${error.message}`)
// If we've already tried falling back once, throw an error
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
const errorMessage = `Already attempted transport fallback. Giving up.`
log(errorMessage)
throw new Error(errorMessage)
}
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
// Add to recursion reasons set // Add to recursion reasons set
recursionReasons.add(REASON_TRANSPORT_FALLBACK) recursionReasons.add(REASON_TRANSPORT_FALLBACK)
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) return connectToRemoteServer(
client,
serverUrl,
authProvider,
headers,
waitForAuthCode,
skipBrowserAuth,
sseTransport ? 'http-only' : 'sse-only',
recursionReasons,
)
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
if (skipBrowserAuth) { if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth') log('Authentication required but skipping browser auth - using shared auth')
@ -185,7 +202,16 @@ export async function connectToRemoteServer(
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) return connectToRemoteServer(
client,
serverUrl,
authProvider,
headers,
waitForAuthCode,
skipBrowserAuth,
transportStrategy,
recursionReasons,
)
} catch (authError) { } catch (authError) {
log('Authorization error:', authError) log('Authorization error:', authError)
throw authError throw authError

View file

@ -11,9 +11,18 @@
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils' import {
connectToRemoteServer,
log,
mcpProxy,
parseCommandLineArgs,
setupSignalHandlers,
getServerUrlHash,
MCP_REMOTE_VERSION,
} from './lib/utils'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { coordinateAuth } from './lib/coordination' import { coordinateAuth } from './lib/coordination'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
/** /**
* Main function to run the proxy * Main function to run the proxy
@ -47,8 +56,17 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
const localTransport = new StdioServerTransport() const localTransport = new StdioServerTransport()
try { try {
const client = new Client(
{
name: 'mcp-remote',
version: MCP_REMOTE_VERSION,
},
{
capabilities: {},
},
)
// Connect to remote server with authentication // Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth) const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({