Able to catch exceptions from one transport and fall back to the other
This commit is contained in:
parent
0bf84d5d22
commit
14109a309f
5 changed files with 104 additions and 51 deletions
|
@ -28,11 +28,11 @@
|
||||||
"check": "prettier --check . && tsc"
|
"check": "prettier --check . && tsc"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk",
|
||||||
"express": "^4.21.2",
|
"express": "^4.21.2",
|
||||||
"open": "^10.1.0"
|
"open": "^10.1.0"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@modelcontextprotocol/sdk": "link:/Users/glen/src/hax/mcp/typescript-sdk",
|
|
||||||
"@types/express": "^5.0.0",
|
"@types/express": "^5.0.0",
|
||||||
"@types/node": "^22.13.10",
|
"@types/node": "^22.13.10",
|
||||||
"@types/react": "^19.0.12",
|
"@types/react": "^19.0.12",
|
||||||
|
@ -54,7 +54,8 @@
|
||||||
"clean": true,
|
"clean": true,
|
||||||
"outDir": "dist",
|
"outDir": "dist",
|
||||||
"external": [
|
"external": [
|
||||||
"react"
|
"react",
|
||||||
|
"@modelcontextprotocol/sdk"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
6
pnpm-lock.yaml
generated
6
pnpm-lock.yaml
generated
|
@ -8,6 +8,9 @@ importers:
|
||||||
|
|
||||||
.:
|
.:
|
||||||
dependencies:
|
dependencies:
|
||||||
|
'@modelcontextprotocol/sdk':
|
||||||
|
specifier: link:/Users/glen/src/hax/mcp/typescript-sdk
|
||||||
|
version: link:../../hax/mcp/typescript-sdk
|
||||||
express:
|
express:
|
||||||
specifier: ^4.21.2
|
specifier: ^4.21.2
|
||||||
version: 4.21.2
|
version: 4.21.2
|
||||||
|
@ -15,9 +18,6 @@ importers:
|
||||||
specifier: ^10.1.0
|
specifier: ^10.1.0
|
||||||
version: 10.1.0
|
version: 10.1.0
|
||||||
devDependencies:
|
devDependencies:
|
||||||
'@modelcontextprotocol/sdk':
|
|
||||||
specifier: link:/Users/glen/src/hax/mcp/typescript-sdk
|
|
||||||
version: link:../../hax/mcp/typescript-sdk
|
|
||||||
'@types/express':
|
'@types/express':
|
||||||
specifier: ^5.0.0
|
specifier: ^5.0.0
|
||||||
version: 5.0.0
|
version: 5.0.0
|
||||||
|
|
|
@ -70,7 +70,15 @@ async function runClient(
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Connect to remote server with authentication
|
// Connect to remote server with authentication
|
||||||
const transport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy)
|
const transport = await connectToRemoteServer(
|
||||||
|
client,
|
||||||
|
serverUrl,
|
||||||
|
authProvider,
|
||||||
|
headers,
|
||||||
|
waitForAuthCode,
|
||||||
|
skipBrowserAuth,
|
||||||
|
transportStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
// Set up message and error handlers
|
// Set up message and error handlers
|
||||||
transport.onmessage = (message) => {
|
transport.onmessage = (message) => {
|
||||||
|
@ -94,9 +102,6 @@ async function runClient(
|
||||||
}
|
}
|
||||||
setupSignalHandlers(cleanup)
|
setupSignalHandlers(cleanup)
|
||||||
|
|
||||||
// Connect the client
|
|
||||||
log('Connecting client...')
|
|
||||||
await client.connect(transport)
|
|
||||||
log('Connected successfully!')
|
log('Connected successfully!')
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
@ -117,7 +122,10 @@ async function runClient(
|
||||||
log('Error requesting resources list:', e)
|
log('Error requesting resources list:', e)
|
||||||
}
|
}
|
||||||
|
|
||||||
log('Listening for messages. Press Ctrl+C to exit.')
|
// log('Listening for messages. Press Ctrl+C to exit.')
|
||||||
|
log('Exiting OK...')
|
||||||
|
server.close()
|
||||||
|
process.exit(0)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
log('Fatal error:', error)
|
log('Fatal error:', error)
|
||||||
server.close()
|
server.close()
|
||||||
|
|
104
src/lib/utils.ts
104
src/lib/utils.ts
|
@ -1,4 +1,5 @@
|
||||||
import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
|
import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
|
||||||
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||||
|
@ -74,6 +75,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and connects to a remote server with OAuth authentication
|
* Creates and connects to a remote server with OAuth authentication
|
||||||
|
* @param client The client to connect with
|
||||||
* @param serverUrl The URL of the remote server
|
* @param serverUrl The URL of the remote server
|
||||||
* @param authProvider The OAuth client provider
|
* @param authProvider The OAuth client provider
|
||||||
* @param headers Additional headers to send with the request
|
* @param headers Additional headers to send with the request
|
||||||
|
@ -84,6 +86,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
||||||
* @returns The connected transport
|
* @returns The connected transport
|
||||||
*/
|
*/
|
||||||
export async function connectToRemoteServer(
|
export async function connectToRemoteServer(
|
||||||
|
client: Client,
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
authProvider: OAuthClientProvider,
|
authProvider: OAuthClientProvider,
|
||||||
headers: Record<string, string>,
|
headers: Record<string, string>,
|
||||||
|
@ -112,54 +115,68 @@ export async function connectToRemoteServer(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Choose transport based on user strategy and recursion history
|
log(`Using transport strategy: ${transportStrategy}`)
|
||||||
let transport;
|
|
||||||
let shouldAttemptFallback = false;
|
|
||||||
|
|
||||||
// If we've already tried falling back once, throw an error
|
|
||||||
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
|
|
||||||
const errorMessage = `Already attempted transport fallback. Giving up.`;
|
|
||||||
log(errorMessage);
|
|
||||||
throw new Error(errorMessage);
|
|
||||||
}
|
|
||||||
|
|
||||||
log(`Using transport strategy: ${transportStrategy}`);
|
|
||||||
// Determine if we should attempt to fallback on error
|
// Determine if we should attempt to fallback on error
|
||||||
shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first';
|
// Choose transport based on user strategy and recursion history
|
||||||
|
const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'
|
||||||
|
|
||||||
// Create transport instance based on the strategy
|
// Create transport instance based on the strategy
|
||||||
if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') {
|
const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first'
|
||||||
transport = new SSEClientTransport(url, {
|
const transport = sseTransport
|
||||||
authProvider,
|
? new SSEClientTransport(url, {
|
||||||
requestInit: { headers },
|
authProvider,
|
||||||
eventSourceInit,
|
requestInit: { headers },
|
||||||
});
|
eventSourceInit,
|
||||||
} else { // http-only or http-first
|
})
|
||||||
transport = new StreamableHTTPClientTransport(url, {
|
: new StreamableHTTPClientTransport(url, {
|
||||||
sessionId: crypto.randomUUID(),
|
authProvider,
|
||||||
});
|
requestInit: { headers },
|
||||||
}
|
})
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await transport.start()
|
await client.connect(transport)
|
||||||
log(`Connected to remote server using ${transport.constructor.name}`)
|
log(`Connected to remote server using ${transport.constructor.name}`)
|
||||||
|
|
||||||
|
if (!sseTransport) {
|
||||||
|
console.log({ serverCapabilities: await client.getServerCapabilities() })
|
||||||
|
}
|
||||||
|
|
||||||
return transport
|
return transport
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
console.log('DID I CATCH OR WHAT?')
|
||||||
// Check if it's a protocol error and we should attempt fallback
|
// Check if it's a protocol error and we should attempt fallback
|
||||||
if (error instanceof Error &&
|
if (
|
||||||
shouldAttemptFallback &&
|
error instanceof Error &&
|
||||||
(error.message.includes('405') ||
|
shouldAttemptFallback &&
|
||||||
error.message.includes('Method Not Allowed') ||
|
(sseTransport
|
||||||
error.message.toLowerCase().includes('protocol'))) {
|
? error.message.includes('405') || error.message.includes('Method Not Allowed')
|
||||||
|
: error.message.includes('404') || error.message.includes('Not Found'))
|
||||||
|
) {
|
||||||
log(`Received error: ${error.message}`)
|
log(`Received error: ${error.message}`)
|
||||||
|
|
||||||
|
// If we've already tried falling back once, throw an error
|
||||||
|
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
|
||||||
|
const errorMessage = `Already attempted transport fallback. Giving up.`
|
||||||
|
log(errorMessage)
|
||||||
|
throw new Error(errorMessage)
|
||||||
|
}
|
||||||
|
|
||||||
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
|
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
|
||||||
|
|
||||||
// Add to recursion reasons set
|
// Add to recursion reasons set
|
||||||
recursionReasons.add(REASON_TRANSPORT_FALLBACK)
|
recursionReasons.add(REASON_TRANSPORT_FALLBACK)
|
||||||
|
|
||||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
return connectToRemoteServer(
|
||||||
|
client,
|
||||||
|
serverUrl,
|
||||||
|
authProvider,
|
||||||
|
headers,
|
||||||
|
waitForAuthCode,
|
||||||
|
skipBrowserAuth,
|
||||||
|
sseTransport ? 'http-only' : 'sse-only',
|
||||||
|
recursionReasons,
|
||||||
|
)
|
||||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
if (skipBrowserAuth) {
|
if (skipBrowserAuth) {
|
||||||
log('Authentication required but skipping browser auth - using shared auth')
|
log('Authentication required but skipping browser auth - using shared auth')
|
||||||
|
@ -179,13 +196,22 @@ export async function connectToRemoteServer(
|
||||||
log(errorMessage)
|
log(errorMessage)
|
||||||
throw new Error(errorMessage)
|
throw new Error(errorMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track this reason for recursion
|
// Track this reason for recursion
|
||||||
recursionReasons.add(REASON_AUTH_NEEDED)
|
recursionReasons.add(REASON_AUTH_NEEDED)
|
||||||
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
||||||
|
|
||||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
return connectToRemoteServer(
|
||||||
|
client,
|
||||||
|
serverUrl,
|
||||||
|
authProvider,
|
||||||
|
headers,
|
||||||
|
waitForAuthCode,
|
||||||
|
skipBrowserAuth,
|
||||||
|
transportStrategy,
|
||||||
|
recursionReasons,
|
||||||
|
)
|
||||||
} catch (authError) {
|
} catch (authError) {
|
||||||
log('Authorization error:', authError)
|
log('Authorization error:', authError)
|
||||||
throw authError
|
throw authError
|
||||||
|
@ -356,7 +382,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
|
||||||
const serverUrl = args[0]
|
const serverUrl = args[0]
|
||||||
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
||||||
const allowHttp = args.includes('--allow-http')
|
const allowHttp = args.includes('--allow-http')
|
||||||
|
|
||||||
// Parse transport strategy
|
// Parse transport strategy
|
||||||
let transportStrategy: TransportStrategy = 'http-first' // Default
|
let transportStrategy: TransportStrategy = 'http-first' // Default
|
||||||
const transportIndex = args.indexOf('--transport')
|
const transportIndex = args.indexOf('--transport')
|
||||||
|
|
22
src/proxy.ts
22
src/proxy.ts
|
@ -11,9 +11,18 @@
|
||||||
|
|
||||||
import { EventEmitter } from 'events'
|
import { EventEmitter } from 'events'
|
||||||
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
|
||||||
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils'
|
import {
|
||||||
|
connectToRemoteServer,
|
||||||
|
log,
|
||||||
|
mcpProxy,
|
||||||
|
parseCommandLineArgs,
|
||||||
|
setupSignalHandlers,
|
||||||
|
getServerUrlHash,
|
||||||
|
MCP_REMOTE_VERSION,
|
||||||
|
} from './lib/utils'
|
||||||
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
|
||||||
import { coordinateAuth } from './lib/coordination'
|
import { coordinateAuth } from './lib/coordination'
|
||||||
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Main function to run the proxy
|
* Main function to run the proxy
|
||||||
|
@ -47,8 +56,17 @@ async function runProxy(serverUrl: string, callbackPort: number, headers: Record
|
||||||
const localTransport = new StdioServerTransport()
|
const localTransport = new StdioServerTransport()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
const client = new Client(
|
||||||
|
{
|
||||||
|
name: 'mcp-remote',
|
||||||
|
version: MCP_REMOTE_VERSION,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: {},
|
||||||
|
},
|
||||||
|
)
|
||||||
// Connect to remote server with authentication
|
// Connect to remote server with authentication
|
||||||
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
|
const remoteTransport = await connectToRemoteServer(client, serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth)
|
||||||
|
|
||||||
// Set up bidirectional proxy between local and remote transports
|
// Set up bidirectional proxy between local and remote transports
|
||||||
mcpProxy({
|
mcpProxy({
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue