Able to catch exceptions from one transport and fall back to the other
This commit is contained in:
parent
0bf84d5d22
commit
14109a309f
5 changed files with 104 additions and 51 deletions
104
src/lib/utils.ts
104
src/lib/utils.ts
|
@ -1,4 +1,5 @@
|
|||
import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
|
||||
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
|
@ -74,6 +75,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
|||
|
||||
/**
|
||||
* Creates and connects to a remote server with OAuth authentication
|
||||
* @param client The client to connect with
|
||||
* @param serverUrl The URL of the remote server
|
||||
* @param authProvider The OAuth client provider
|
||||
* @param headers Additional headers to send with the request
|
||||
|
@ -84,6 +86,7 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
|||
* @returns The connected transport
|
||||
*/
|
||||
export async function connectToRemoteServer(
|
||||
client: Client,
|
||||
serverUrl: string,
|
||||
authProvider: OAuthClientProvider,
|
||||
headers: Record<string, string>,
|
||||
|
@ -112,54 +115,68 @@ export async function connectToRemoteServer(
|
|||
},
|
||||
}
|
||||
|
||||
// Choose transport based on user strategy and recursion history
|
||||
let transport;
|
||||
let shouldAttemptFallback = false;
|
||||
|
||||
// If we've already tried falling back once, throw an error
|
||||
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
|
||||
const errorMessage = `Already attempted transport fallback. Giving up.`;
|
||||
log(errorMessage);
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
log(`Using transport strategy: ${transportStrategy}`);
|
||||
log(`Using transport strategy: ${transportStrategy}`)
|
||||
// Determine if we should attempt to fallback on error
|
||||
shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first';
|
||||
|
||||
// Choose transport based on user strategy and recursion history
|
||||
const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'
|
||||
|
||||
// Create transport instance based on the strategy
|
||||
if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') {
|
||||
transport = new SSEClientTransport(url, {
|
||||
authProvider,
|
||||
requestInit: { headers },
|
||||
eventSourceInit,
|
||||
});
|
||||
} else { // http-only or http-first
|
||||
transport = new StreamableHTTPClientTransport(url, {
|
||||
sessionId: crypto.randomUUID(),
|
||||
});
|
||||
}
|
||||
const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first'
|
||||
const transport = sseTransport
|
||||
? new SSEClientTransport(url, {
|
||||
authProvider,
|
||||
requestInit: { headers },
|
||||
eventSourceInit,
|
||||
})
|
||||
: new StreamableHTTPClientTransport(url, {
|
||||
authProvider,
|
||||
requestInit: { headers },
|
||||
})
|
||||
|
||||
try {
|
||||
await transport.start()
|
||||
await client.connect(transport)
|
||||
log(`Connected to remote server using ${transport.constructor.name}`)
|
||||
|
||||
if (!sseTransport) {
|
||||
console.log({ serverCapabilities: await client.getServerCapabilities() })
|
||||
}
|
||||
|
||||
return transport
|
||||
} catch (error) {
|
||||
console.log('DID I CATCH OR WHAT?')
|
||||
// Check if it's a protocol error and we should attempt fallback
|
||||
if (error instanceof Error &&
|
||||
shouldAttemptFallback &&
|
||||
(error.message.includes('405') ||
|
||||
error.message.includes('Method Not Allowed') ||
|
||||
error.message.toLowerCase().includes('protocol'))) {
|
||||
|
||||
if (
|
||||
error instanceof Error &&
|
||||
shouldAttemptFallback &&
|
||||
(sseTransport
|
||||
? error.message.includes('405') || error.message.includes('Method Not Allowed')
|
||||
: error.message.includes('404') || error.message.includes('Not Found'))
|
||||
) {
|
||||
log(`Received error: ${error.message}`)
|
||||
|
||||
// If we've already tried falling back once, throw an error
|
||||
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
|
||||
const errorMessage = `Already attempted transport fallback. Giving up.`
|
||||
log(errorMessage)
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
|
||||
|
||||
|
||||
// Add to recursion reasons set
|
||||
recursionReasons.add(REASON_TRANSPORT_FALLBACK)
|
||||
|
||||
|
||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
||||
return connectToRemoteServer(
|
||||
client,
|
||||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
sseTransport ? 'http-only' : 'sse-only',
|
||||
recursionReasons,
|
||||
)
|
||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication required but skipping browser auth - using shared auth')
|
||||
|
@ -179,13 +196,22 @@ export async function connectToRemoteServer(
|
|||
log(errorMessage)
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
|
||||
// Track this reason for recursion
|
||||
recursionReasons.add(REASON_AUTH_NEEDED)
|
||||
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
||||
|
||||
|
||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
||||
return connectToRemoteServer(
|
||||
client,
|
||||
serverUrl,
|
||||
authProvider,
|
||||
headers,
|
||||
waitForAuthCode,
|
||||
skipBrowserAuth,
|
||||
transportStrategy,
|
||||
recursionReasons,
|
||||
)
|
||||
} catch (authError) {
|
||||
log('Authorization error:', authError)
|
||||
throw authError
|
||||
|
@ -356,7 +382,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
|
|||
const serverUrl = args[0]
|
||||
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
||||
const allowHttp = args.includes('--allow-http')
|
||||
|
||||
|
||||
// Parse transport strategy
|
||||
let transportStrategy: TransportStrategy = 'http-first' // Default
|
||||
const transportIndex = args.indexOf('--transport')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue