wip 2
This commit is contained in:
parent
ec058d240d
commit
dee974b8b2
1 changed files with 53 additions and 18 deletions
|
@ -2,6 +2,10 @@ import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sd
|
|||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||
|
||||
// Connection constants
|
||||
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
||||
export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport'
|
||||
import { OAuthCallbackServerOptions } from './types'
|
||||
import express from 'express'
|
||||
import net from 'net'
|
||||
|
@ -80,6 +84,7 @@ export async function connectToRemoteServer(
|
|||
headers: Record<string, string>,
|
||||
waitForAuthCode: () => Promise<string>,
|
||||
skipBrowserAuth: boolean = false,
|
||||
recursionReasons: Set<string> = new Set(),
|
||||
): Promise<Transport> {
|
||||
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
||||
const url = new URL(serverUrl)
|
||||
|
@ -101,23 +106,47 @@ export async function connectToRemoteServer(
|
|||
},
|
||||
}
|
||||
|
||||
const TESTING_NEW_TRANSPORT = true
|
||||
const transport = TESTING_NEW_TRANSPORT
|
||||
? new StreamableHTTPClientTransport(url, {
|
||||
sessionId: crypto.randomUUID(),
|
||||
})
|
||||
: new SSEClientTransport(url, {
|
||||
authProvider,
|
||||
requestInit: { headers },
|
||||
eventSourceInit,
|
||||
})
|
||||
// Choose transport based on recursion history
|
||||
let transport;
|
||||
|
||||
if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) {
|
||||
log('Using SSEClientTransport due to previous protocol failure')
|
||||
transport = new SSEClientTransport(url, {
|
||||
authProvider,
|
||||
requestInit: { headers },
|
||||
eventSourceInit,
|
||||
})
|
||||
} else {
|
||||
log('Trying StreamableHTTPClientTransport first')
|
||||
transport = new StreamableHTTPClientTransport(url, {
|
||||
sessionId: crypto.randomUUID(),
|
||||
})
|
||||
}
|
||||
|
||||
try {
|
||||
await transport.start()
|
||||
log('Connected to remote server')
|
||||
log(`Connected to remote server using ${transport.constructor.name}`)
|
||||
return transport
|
||||
} catch (error) {
|
||||
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
// Check if it's a 405 Method Not Allowed error or similar protocol issue
|
||||
if (error instanceof Error &&
|
||||
!recursionReasons.has(REASON_FALLBACK_TO_SSE) &&
|
||||
(error.message.includes('405') ||
|
||||
error.message.includes('Method Not Allowed') ||
|
||||
error.message.toLowerCase().includes('protocol'))) {
|
||||
|
||||
// This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE
|
||||
// But keeping it as a safeguard
|
||||
|
||||
log(`Received error: ${error.message}`)
|
||||
log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`)
|
||||
|
||||
// Add to recursion reasons set
|
||||
recursionReasons.add(REASON_FALLBACK_TO_SSE)
|
||||
|
||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||
if (skipBrowserAuth) {
|
||||
log('Authentication required but skipping browser auth - using shared auth')
|
||||
} else {
|
||||
|
@ -131,12 +160,18 @@ export async function connectToRemoteServer(
|
|||
log('Completing authorization...')
|
||||
await transport.finishAuth(code)
|
||||
|
||||
// Create a new transport after auth
|
||||
// TODO: this needs to be the same transport type as the originals
|
||||
const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
|
||||
await newTransport.start()
|
||||
log('Connected to remote server after authentication')
|
||||
return newTransport
|
||||
if (recursionReasons.has(REASON_AUTH_NEEDED)) {
|
||||
const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.`
|
||||
log(errorMessage)
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
// Track this reason for recursion
|
||||
recursionReasons.add(REASON_AUTH_NEEDED)
|
||||
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
||||
|
||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
||||
} catch (authError) {
|
||||
log('Authorization error:', authError)
|
||||
throw authError
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue