wip 3
This commit is contained in:
parent
dee974b8b2
commit
f80c6c4850
1 changed files with 47 additions and 20 deletions
|
@ -5,7 +5,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
|
||||||
|
|
||||||
// Connection constants
|
// Connection constants
|
||||||
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
export const REASON_AUTH_NEEDED = 'authentication-needed'
|
||||||
export const REASON_FALLBACK_TO_SSE = 'falling-back-to-sse-transport'
|
export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
|
||||||
|
|
||||||
|
// Transport strategy types
|
||||||
|
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
|
||||||
import { OAuthCallbackServerOptions } from './types'
|
import { OAuthCallbackServerOptions } from './types'
|
||||||
import express from 'express'
|
import express from 'express'
|
||||||
import net from 'net'
|
import net from 'net'
|
||||||
|
@ -70,13 +73,15 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and connects to a remote SSE server with OAuth authentication
|
* Creates and connects to a remote server with OAuth authentication
|
||||||
* @param serverUrl The URL of the remote server
|
* @param serverUrl The URL of the remote server
|
||||||
* @param authProvider The OAuth client provider
|
* @param authProvider The OAuth client provider
|
||||||
* @param headers Additional headers to send with the request
|
* @param headers Additional headers to send with the request
|
||||||
* @param waitForAuthCode Function to wait for the auth code
|
* @param waitForAuthCode Function to wait for the auth code
|
||||||
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
|
||||||
* @returns The connected SSE client transport
|
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
|
||||||
|
* @param recursionReasons Set of reasons for recursive calls (internal use)
|
||||||
|
* @returns The connected transport
|
||||||
*/
|
*/
|
||||||
export async function connectToRemoteServer(
|
export async function connectToRemoteServer(
|
||||||
serverUrl: string,
|
serverUrl: string,
|
||||||
|
@ -84,6 +89,7 @@ export async function connectToRemoteServer(
|
||||||
headers: Record<string, string>,
|
headers: Record<string, string>,
|
||||||
waitForAuthCode: () => Promise<string>,
|
waitForAuthCode: () => Promise<string>,
|
||||||
skipBrowserAuth: boolean = false,
|
skipBrowserAuth: boolean = false,
|
||||||
|
transportStrategy: TransportStrategy = 'http-first',
|
||||||
recursionReasons: Set<string> = new Set(),
|
recursionReasons: Set<string> = new Set(),
|
||||||
): Promise<Transport> {
|
): Promise<Transport> {
|
||||||
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
|
||||||
|
@ -106,21 +112,32 @@ export async function connectToRemoteServer(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Choose transport based on recursion history
|
// Choose transport based on user strategy and recursion history
|
||||||
let transport;
|
let transport;
|
||||||
|
let shouldAttemptFallback = false;
|
||||||
|
|
||||||
if (recursionReasons.has(REASON_FALLBACK_TO_SSE)) {
|
// If we've already tried falling back once, throw an error
|
||||||
log('Using SSEClientTransport due to previous protocol failure')
|
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
|
||||||
|
const errorMessage = `Already attempted transport fallback. Giving up.`;
|
||||||
|
log(errorMessage);
|
||||||
|
throw new Error(errorMessage);
|
||||||
|
}
|
||||||
|
|
||||||
|
log(`Using transport strategy: ${transportStrategy}`);
|
||||||
|
// Determine if we should attempt to fallback on error
|
||||||
|
shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first';
|
||||||
|
|
||||||
|
// Create transport instance based on the strategy
|
||||||
|
if (transportStrategy === 'sse-only' || transportStrategy === 'sse-first') {
|
||||||
transport = new SSEClientTransport(url, {
|
transport = new SSEClientTransport(url, {
|
||||||
authProvider,
|
authProvider,
|
||||||
requestInit: { headers },
|
requestInit: { headers },
|
||||||
eventSourceInit,
|
eventSourceInit,
|
||||||
})
|
});
|
||||||
} else {
|
} else { // http-only or http-first
|
||||||
log('Trying StreamableHTTPClientTransport first')
|
|
||||||
transport = new StreamableHTTPClientTransport(url, {
|
transport = new StreamableHTTPClientTransport(url, {
|
||||||
sessionId: crypto.randomUUID(),
|
sessionId: crypto.randomUUID(),
|
||||||
})
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
@ -128,24 +145,21 @@ export async function connectToRemoteServer(
|
||||||
log(`Connected to remote server using ${transport.constructor.name}`)
|
log(`Connected to remote server using ${transport.constructor.name}`)
|
||||||
return transport
|
return transport
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
// Check if it's a 405 Method Not Allowed error or similar protocol issue
|
// Check if it's a protocol error and we should attempt fallback
|
||||||
if (error instanceof Error &&
|
if (error instanceof Error &&
|
||||||
!recursionReasons.has(REASON_FALLBACK_TO_SSE) &&
|
shouldAttemptFallback &&
|
||||||
(error.message.includes('405') ||
|
(error.message.includes('405') ||
|
||||||
error.message.includes('Method Not Allowed') ||
|
error.message.includes('Method Not Allowed') ||
|
||||||
error.message.toLowerCase().includes('protocol'))) {
|
error.message.toLowerCase().includes('protocol'))) {
|
||||||
|
|
||||||
// This condition is already checked above, so we will never reach here with REASON_FALLBACK_TO_SSE
|
|
||||||
// But keeping it as a safeguard
|
|
||||||
|
|
||||||
log(`Received error: ${error.message}`)
|
log(`Received error: ${error.message}`)
|
||||||
log(`Recursively reconnecting for reason: ${REASON_FALLBACK_TO_SSE}`)
|
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
|
||||||
|
|
||||||
// Add to recursion reasons set
|
// Add to recursion reasons set
|
||||||
recursionReasons.add(REASON_FALLBACK_TO_SSE)
|
recursionReasons.add(REASON_TRANSPORT_FALLBACK)
|
||||||
|
|
||||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
||||||
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
|
||||||
if (skipBrowserAuth) {
|
if (skipBrowserAuth) {
|
||||||
log('Authentication required but skipping browser auth - using shared auth')
|
log('Authentication required but skipping browser auth - using shared auth')
|
||||||
|
@ -171,7 +185,7 @@ export async function connectToRemoteServer(
|
||||||
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
|
||||||
|
|
||||||
// Recursively call connectToRemoteServer with the updated recursion tracking
|
// Recursively call connectToRemoteServer with the updated recursion tracking
|
||||||
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, recursionReasons)
|
return connectToRemoteServer(serverUrl, authProvider, headers, waitForAuthCode, skipBrowserAuth, transportStrategy, recursionReasons)
|
||||||
} catch (authError) {
|
} catch (authError) {
|
||||||
log('Authorization error:', authError)
|
log('Authorization error:', authError)
|
||||||
throw authError
|
throw authError
|
||||||
|
@ -343,6 +357,19 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
|
||||||
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
|
||||||
const allowHttp = args.includes('--allow-http')
|
const allowHttp = args.includes('--allow-http')
|
||||||
|
|
||||||
|
// Parse transport strategy
|
||||||
|
let transportStrategy: TransportStrategy = 'http-first' // Default
|
||||||
|
const transportIndex = args.indexOf('--transport')
|
||||||
|
if (transportIndex !== -1 && transportIndex < args.length - 1) {
|
||||||
|
const strategy = args[transportIndex + 1]
|
||||||
|
if (strategy === 'sse-only' || strategy === 'http-only' || strategy === 'sse-first' || strategy === 'http-first') {
|
||||||
|
transportStrategy = strategy as TransportStrategy
|
||||||
|
log(`Using transport strategy: ${transportStrategy}`)
|
||||||
|
} else {
|
||||||
|
log(`Warning: Ignoring invalid transport strategy: ${strategy}. Valid values are: sse-only, http-only, sse-first, http-first`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (!serverUrl) {
|
if (!serverUrl) {
|
||||||
log(usage)
|
log(usage)
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
|
@ -385,7 +412,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return { serverUrl, callbackPort, headers }
|
return { serverUrl, callbackPort, headers, transportStrategy }
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue