This commit is contained in:
Glen Maddern 2025-04-16 17:13:19 +10:00
parent dee974b8b2
commit f80c6c4850

View file

@ -5,7 +5,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
// Connection constants // Connection constants
export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_AUTH_NEEDED = 'authentication-needed'
export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport' export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
// Transport strategy types
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
import { OAuthCallbackServerOptions } from './types' import { OAuthCallbackServerOptions } from './types'
import express from 'express' import express from 'express'
import net from 'net' import net from 'net'
@ -70,13 +73,15 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
} }
/** /**
* Creates and connects to a remote SSE server with OAuth authentication * Creates and connects to a remote server with OAuth authentication
* @param serverUrl The URL of the remote server * @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider * @param authProvider The OAuth client provider
* @param headers Additional headers to send with the request * @param headers Additional headers to send with the request
* @param waitForAuthCode Function to wait for the auth code * @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth * @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport * @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
* @param recursionReasons Set of reasons for recursive calls (internal use)
* @returns The connected transport
*/ */
export async function connectToRemoteServer( export async function connectToRemoteServer(
serverUrl: string, serverUrl: string,
@ -84,6 +89,7 @@ export async function connectToRemoteServer(
headers: Record<string, string>, headers: Record<string, string>,
waitForAuthCode: () => Promise<string>, waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false, skipBrowserAuth: boolean = false,
transportStrategy: TransportStrategy = 'http-first',
recursionReasons: Set<string> = new Set(), recursionReasons: Set<string> = new Set(),
): Promise<Transport> { ): Promise<Transport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) log(`[${pid}] Connecting to remote server: ${serverUrl}`)
@ -106,21 +112,32 @@ export async function connectToRemoteServer(
}, },
} }
// Choose transport based on recursion history // Choose transport based on user strategy and recursion history
let transport; let transport;
let shouldAttemptFallback = false;
if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) { // If we've already tried falling back once, throw an error
log('Using SSEClientTransport due to previous protocol failure') if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
const errorMessage = `Already attempted transport fallback. Giving up.`;
log(errorMessage);
throw new Error(errorMessage);
}
log(`Using transport strategy: ${transportStrategy}`);
// Determine if we should attempt to fallback on error
shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first';
// Create transport instance based on the strategy
if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') {
transport = new SSEClientTransport(url, { transport = new SSEClientTransport(url, {
authProvider, authProvider,
requestInit: { headers }, requestInit: { headers },
eventSourceInit, eventSourceInit,
}) });
} else { } else { // http-only or http-first
log('Trying StreamableHTTPClientTransport first')
transport = new StreamableHTTPClientTransport(url, { transport = new StreamableHTTPClientTransport(url, {
sessionId: crypto.randomUUID(), sessionId: crypto.randomUUID(),
}) });
} }
try { try {
@ -128,24 +145,21 @@ export async function connectToRemoteServer(
log(`Connected to remote server using ${transport.constructor.name}`) log(`Connected to remote server using ${transport.constructor.name}`)
return transport return transport
} catch (error) { } catch (error) {
// Check if it's a 405 Method Not Allowed error or similar protocol issue // Check if it's a protocol error and we should attempt fallback
if (error instanceof Error && if (error instanceof Error &&
!recursionReasons.has(REASON_FALLBACK_TO_SSE) && shouldAttemptFallback &&
(error.message.includes('405') || (error.message.includes('405') ||
error.message.includes('Method Not Allowed') || error.message.includes('Method Not Allowed') ||
error.message.toLowerCase().includes('protocol'))) { error.message.toLowerCase().includes('protocol'))) {
// This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE
// But keeping it as a safeguard
log(`Received error: ${error.message}`) log(`Received error: ${error.message}`)
log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`) log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
// Add to recursion reasons set // Add to recursion reasons set
recursionReasons.add(REASON_FALLBACK_TO_SSE) recursionReasons.add(REASON_TRANSPORT_FALLBACK)
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
if (skipBrowserAuth) { if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth') log('Authentication required but skipping browser auth - using shared auth')
@ -171,7 +185,7 @@ export async function connectToRemoteServer(
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
// Recursively call connectToRemoteServer with the updated recursion tracking // Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
} catch (authError) { } catch (authError) {
log('Authorization error:', authError) log('Authorization error:', authError)
throw authError throw authError
@ -343,6 +357,19 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
const specifiedPort = args[1] ? parseInt(args[1]) : undefined const specifiedPort = args[1] ? parseInt(args[1]) : undefined
const allowHttp = args.includes('--allow-http') const allowHttp = args.includes('--allow-http')
// Parse transport strategy
let transportStrategy: TransportStrategy = 'http-first' // Default
const transportIndex = args.indexOf('--transport')
if (transportIndex !== -1 && transportIndex < args.length - 1) {
const strategy = args[transportIndex + 1]
if (strategy === 'sse-only' || strategy === 'http-only' || strategy === 'sse-first' || strategy === 'http-first') {
transportStrategy = strategy as TransportStrategy
log(`Using transport strategy: ${transportStrategy}`)
} else {
log(`Warning: Ignoring invalid transport strategy: ${strategy}. Valid values are: sse-only, http-only, sse-first, http-first`)
}
}
if (!serverUrl) { if (!serverUrl) {
log(usage) log(usage)
process.exit(1) process.exit(1)
@ -385,7 +412,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
}) })
} }
return { serverUrl, callbackPort, headers } return { serverUrl, callbackPort, headers, transportStrategy }
} }
/** /**