diff --git a/src/lib/utils.ts b/src/lib/utils.ts index 940a8ca..280a489 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -2,6 +2,10 @@ import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sd import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' + +// Connection constants +export const REASON_AUTH_NEEDED = 'authentication-needed' +export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport' import { OAuthCallbackServerOptions } from './types' import express from 'express' import net from 'net' @@ -80,6 +84,7 @@ export async function connectToRemoteServer( headers: Record, waitForAuthCode: () => Promise, skipBrowserAuth: boolean = false, + recursionReasons: Set = new Set(), ): Promise { log(`[${pid}] Connecting to remote server: ${serverUrl}`) const url = new URL(serverUrl) @@ -101,23 +106,47 @@ export async function connectToRemoteServer( }, } - const TESTING_NEW_TRANSPORT = true - const transport = TESTING_NEW_TRANSPORT - ? new StreamableHTTPClientTransport(url, { - sessionId: crypto.randomUUID(), - }) - : new SSEClientTransport(url, { - authProvider, - requestInit: { headers }, - eventSourceInit, - }) + // Choose transport based on recursion history + let transport; + + if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) { + log('Using SSEClientTransport due to previous protocol failure') + transport = new SSEClientTransport(url, { + authProvider, + requestInit: { headers }, + eventSourceInit, + }) + } else { + log('Trying StreamableHTTPClientTransport first') + transport = new StreamableHTTPClientTransport(url, { + sessionId: crypto.randomUUID(), + }) + } try { await transport.start() - log('Connected to remote server') + log(`Connected to remote server using ${transport.constructor.name}`) return transport } catch (error) { - if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { + // Check if it's a 405 Method Not Allowed error or similar protocol issue + if (error instanceof Error && + !recursionReasons.has(REASON_FALLBACK_TO_SSE) && + (error.message.includes('405') || + error.message.includes('Method Not Allowed') || + error.message.toLowerCase().includes('protocol'))) { + + // This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE + // But keeping it as a safeguard + + log(`Received error: ${error.message}`) + log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`) + + // Add to recursion reasons set + recursionReasons.add(REASON_FALLBACK_TO_SSE) + + // Recursively call connectToRemoteServer with the updated recursion tracking + return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) + } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (skipBrowserAuth) { log('Authentication required but skipping browser auth - using shared auth') } else { @@ -131,12 +160,18 @@ export async function connectToRemoteServer( log('Completing authorization...') await transport.finishAuth(code) - // Create a new transport after auth - // TODO: this needs to be the same transport type as the originals - const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } }) - await newTransport.start() - log('Connected to remote server after authentication') - return newTransport + if (recursionReasons.has(REASON_AUTH_NEEDED)) { + const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.` + log(errorMessage) + throw new Error(errorMessage) + } + + // Track this reason for recursion + recursionReasons.add(REASON_AUTH_NEEDED) + log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) + + // Recursively call connectToRemoteServer with the updated recursion tracking + return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons) } catch (authError) { log('Authorization error:', authError) throw authError