wip 2
This commit is contained in:
parent
ec058d240d
commit
dee974b8b2
1 changed files with 53 additions and 18 deletions
|
@ -2,6 +2,10 @@ import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sd
|
||||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
|
||||||
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
|
||||||
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||||
|
|
||||||
|
// Connection constants
|
||||||
|
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
||||||
|
export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport'
|
||||||
import { OAuthCallbackServerOptions } from './types'
|
import { OAuthCallbackServerOptions } from './types'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
import net from 'net'
|
import net from 'net'
|
||||||
|
@ -80,6 +84,7 @@ export async function connectToRemoteServer(
|
||||||
headers: Record<string, string>,
|
headers: Record<string, string>,
|
||||||
waitForAuthCode: () => Promise<string>,
|
waitForAuthCode: () => Promise<string>,
|
||||||
skipBrowserAuth: boolean = false,
|
skipBrowserAuth: boolean = false,
|
||||||
|
recursionReasons: Set<string> = new Set(),
|
||||||
): Promise<Transport> {
|
): Promise<Transport> {
|
||||||
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
||||||
const url = new URL(serverUrl)
|
const url = new URL(serverUrl)
|
||||||
|
@ -101,23 +106,47 @@ export async function connectToRemoteServer(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
const TESTING_NEW_TRANSPORT = true
|
// Choose transport based on recursion history
|
||||||
const transport = TESTING_NEW_TRANSPORT
|
let transport;
|
||||||
? new StreamableHTTPClientTransport(url, {
|
|
||||||
sessionId: crypto.randomUUID(),
|
if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) {
|
||||||
})
|
log('Using SSEClientTransport due to previous protocol failure')
|
||||||
: new SSEClientTransport(url, {
|
transport = new SSEClientTransport(url, {
|
||||||
authProvider,
|
authProvider,
|
||||||
requestInit: { headers },
|
requestInit: { headers },
|
||||||
eventSourceInit,
|
eventSourceInit,
|
||||||
})
|
})
|
||||||
|
} else {
|
||||||
|
log('Trying StreamableHTTPClientTransport first')
|
||||||
|
transport = new StreamableHTTPClientTransport(url, {
|
||||||
|
sessionId: crypto.randomUUID(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await transport.start()
|
await transport.start()
|
||||||
log('Connected to remote server')
|
log(`Connected to remote server using ${transport.constructor.name}`)
|
||||||
return transport
|
return transport
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
// Check if it's a 405 Method Not Allowed error or similar protocol issue
|
||||||
|
if (error instanceof Error &&
|
||||||
|
!recursionReasons.has(REASON_FALLBACK_TO_SSE) &&
|
||||||
|
(error.message.includes('405') ||
|
||||||
|
error.message.includes('Method Not Allowed') ||
|
||||||
|
error.message.toLowerCase().includes('protocol'))) {
|
||||||
|
|
||||||
|
// This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE
|
||||||
|
// But keeping it as a safeguard
|
||||||
|
|
||||||
|
log(`Received error: ${error.message}`)
|
||||||
|
log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`)
|
||||||
|
|
||||||
|
// Add to recursion reasons set
|
||||||
|
recursionReasons.add(REASON_FALLBACK_TO_SSE)
|
||||||
|
|
||||||
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
|
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
||||||
|
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
if (skipBrowserAuth) {
|
if (skipBrowserAuth) {
|
||||||
log('Authentication required but skipping browser auth - using shared auth')
|
log('Authentication required but skipping browser auth - using shared auth')
|
||||||
} else {
|
} else {
|
||||||
|
@ -131,12 +160,18 @@ export async function connectToRemoteServer(
|
||||||
log('Completing authorization...')
|
log('Completing authorization...')
|
||||||
await transport.finishAuth(code)
|
await transport.finishAuth(code)
|
||||||
|
|
||||||
// Create a new transport after auth
|
if (recursionReasons.has(REASON_AUTH_NEEDED)) {
|
||||||
// TODO: this needs to be the same transport type as the originals
|
const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.`
|
||||||
const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
|
log(errorMessage)
|
||||||
await newTransport.start()
|
throw new Error(errorMessage)
|
||||||
log('Connected to remote server after authentication')
|
}
|
||||||
return newTransport
|
|
||||||
|
// Track this reason for recursion
|
||||||
|
recursionReasons.add(REASON_AUTH_NEEDED)
|
||||||
|
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
||||||
|
|
||||||
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
|
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
||||||
} catch (authError) {
|
} catch (authError) {
|
||||||
log('Authorization error:', authError)
|
log('Authorization error:', authError)
|
||||||
throw authError
|
throw authError
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue