From f80c6c4850da543e445d93287db6085d2ca1ea21 Mon Sep 17 00:00:00 2001 From: Glen Maddern Date: Wed, 16 Apr 2025 17:13:19 +1000 Subject: [PATCH] wip 3 --- src/lib/utils.ts | 67 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 280a489..dde1b4c 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -5,7 +5,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' // Connection constants export const REASON_AUTH_NEEDED = 'authentication-needed' -export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport' +export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport' + +// Transport strategy types +export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first' import { OAuthCallbackServerOptions } from './types' import express from 'express' import net from 'net' @@ -70,13 +73,15 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo } /** - * Creates and connects to a remote SSE server with OAuth authentication + * Creates and connects to a remote server with OAuth authentication * @param serverUrl The URL of the remote server * @param authProvider The OAuth client provider * @param headers Additional headers to send with the request * @param waitForAuthCode Function to wait for the auth code * @param skipBrowserAuth Whether to skip browser auth and use shared auth - * @returns The connected SSE client transport + * @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first') + * @param recursionReasons Set of reasons for recursive calls (internal use) + * @returns The connected transport */ export async function connectToRemoteServer( serverUrl: string, @@ -84,6 +89,7 @@ export async function connectToRemoteServer( headers: Record, waitForAuthCode: () => Promise, skipBrowserAuth: boolean = false, + transportStrategy: TransportStrategy = 'http-first', recursionReasons: Set = new Set(), ): Promise { log(`[${pid}] Connecting to remote server: ${serverUrl}`) @@ -106,21 +112,32 @@ export async function connectToRemoteServer( }, } - // Choose transport based on recursion history + // Choose transport based on user strategy and recursion history let transport; + let shouldAttemptFallback = false; - if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) { - log('Using SSEClientTransport due to previous protocol failure') + // If we've already tried falling back once, throw an error + if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { + const errorMessage = `Already attempted transport fallback. Giving up.`; + log(errorMessage); + throw new Error(errorMessage); + } + + log(`Using transport strategy: ${transportStrategy}`); + // Determine if we should attempt to fallback on error + shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'; + + // Create transport instance based on the strategy + if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') { transport = new SSEClientTransport(url, { authProvider, requestInit: { headers }, eventSourceInit, - }) - } else { - log('Trying StreamableHTTPClientTransport first') + }); + } else { // http-only or http-first transport = new StreamableHTTPClientTransport(url, { sessionId: crypto.randomUUID(), - }) + }); } try { @@ -128,24 +145,21 @@ export async function connectToRemoteServer( log(`Connected to remote server using ${transport.constructor.name}`) return transport } catch (error) { - // Check if it's a 405 Method Not Allowed error or similar protocol issue + // Check if it's a protocol error and we should attempt fallback if (error instanceof Error && - !recursionReasons.has(REASON_FALLBACK_TO_SSE) && + shouldAttemptFallback && (error.message.includes('405') || error.message.includes('Method Not Allowed') || error.message.toLowerCase().includes('protocol'))) { - // This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE - // But keeping it as a safeguard - log(`Received error: ${error.message}`) - log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`) + log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) // Add to recursion reasons set - recursionReasons.add(REASON_FALLBACK_TO_SSE) + recursionReasons.add(REASON_TRANSPORT_FALLBACK) // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) + return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (skipBrowserAuth) { log('Authentication required but skipping browser auth - using shared auth') @@ -171,7 +185,7 @@ export async function connectToRemoteServer( log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) // Recursively call connectToRemoteServer with the updated recursion tracking - return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) + return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons) } catch (authError) { log('Authorization error:', authError) throw authError @@ -342,6 +356,19 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') + + // Parse transport strategy + let transportStrategy: TransportStrategy = 'http-first' // Default + const transportIndex = args.indexOf('--transport') + if (transportIndex !== -1 && transportIndex < args.length - 1) { + const strategy = args[transportIndex + 1] + if (strategy === 'sse-only' || strategy === 'http-only' || strategy === 'sse-first' || strategy === 'http-first') { + transportStrategy = strategy as TransportStrategy + log(`Using transport strategy: ${transportStrategy}`) + } else { + log(`Warning: Ignoring invalid transport strategy: ${strategy}. Valid values are: sse-only, http-only, sse-first, http-first`) + } + } if (!serverUrl) { log(usage) @@ -385,7 +412,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number, }) } - return { serverUrl, callbackPort, headers } + return { serverUrl, callbackPort, headers, transportStrategy } } /**