This commit is contained in:
Glen Maddern 2025-04-16 16:59:44 +10:00
parent ec058d240d
commit dee974b8b2

View file

@ -2,6 +2,10 @@ import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sd
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
// Connection constants
export const REASON_AUTH_NEEDED = 'authentication-needed'
export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport'
import { OAuthCallbackServerOptions } from './types'
import express from 'express'
import net from 'net'
@ -80,6 +84,7 @@ export async function connectToRemoteServer(
headers: Record<string, string>,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
recursionReasons: Set<string> = new Set(),
): Promise<Transport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
@ -101,23 +106,47 @@ export async function connectToRemoteServer(
},
}
const TESTING_NEW_TRANSPORT = true
const transport = TESTING_NEW_TRANSPORT
? new StreamableHTTPClientTransport(url, {
sessionId: crypto.randomUUID(),
})
: new SSEClientTransport(url, {
authProvider,
requestInit: { headers },
eventSourceInit,
})
// Choose transport based on recursion history
let transport;
if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) {
log('Using SSEClientTransport due to previous protocol failure')
transport = new SSEClientTransport(url, {
authProvider,
requestInit: { headers },
eventSourceInit,
})
} else {
log('Trying StreamableHTTPClientTransport first')
transport = new StreamableHTTPClientTransport(url, {
sessionId: crypto.randomUUID(),
})
}
try {
await transport.start()
log('Connected to remote server')
log(`Connected to remote server using ${transport.constructor.name}`)
return transport
} catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
// Check if it's a 405 Method Not Allowed error or similar protocol issue
if (error instanceof Error &&
!recursionReasons.has(REASON_FALLBACK_TO_SSE) &&
(error.message.includes('405') ||
error.message.includes('Method Not Allowed') ||
error.message.toLowerCase().includes('protocol'))) {
// This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE
// But keeping it as a safeguard
log(`Received error: ${error.message}`)
log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`)
// Add to recursion reasons set
recursionReasons.add(REASON_FALLBACK_TO_SSE)
// Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
@ -131,12 +160,18 @@ export async function connectToRemoteServer(
log('Completing authorization...')
await transport.finishAuth(code)
// Create a new transport after auth
// TODO: this needs to be the same transport type as the originals
const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
await newTransport.start()
log('Connected to remote server after authentication')
return newTransport
if (recursionReasons.has(REASON_AUTH_NEEDED)) {
const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.`
log(errorMessage)
throw new Error(errorMessage)
}
// Track this reason for recursion
recursionReasons.add(REASON_AUTH_NEEDED)
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
// Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
} catch (authError) {
log('Authorization error:', authError)
throw authError