Added Streamable HTTP support

This adds a new CLI argument, --transport, with the following values: http-first (the default), http-only, sse-first, and sse-only. Any of the -first tags attempts to connect to the URL as either an HTTP or SSE server and falls back to the other.
This commit is contained in:
Glen Maddern 2025-04-16 16:59:36 +10:00 committed by Glen Maddern
parent 504aa26761
commit 04e3d255b1
6 changed files with 373 additions and 231 deletions

View file

@ -1,6 +1,15 @@
import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
import { Client } from '@modelcontextprotocol/sdk/client/index.js'
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'
import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
// Connection constants
export const REASON_AUTH_NEEDED = 'authentication-needed'
export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport'
// Transport strategy types
export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first'
import { OAuthCallbackServerOptions } from './types'
import express from 'express'
import net from 'net'
@ -65,21 +74,33 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
}
/**
* Creates and connects to a remote SSE server with OAuth authentication
* Type for the auth initialization function
*/
export type AuthInitializer = () => Promise<{
waitForAuthCode: () => Promise<string>
skipBrowserAuth: boolean
}>
/**
* Creates and connects to a remote server with OAuth authentication
* @param client The client to connect with
* @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider
* @param headers Additional headers to send with the request
* @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport
* @param authInitializer Function to initialize authentication when needed
* @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first')
* @param recursionReasons Set of reasons for recursive calls (internal use)
* @returns The connected transport
*/
export async function connectToRemoteServer(
client: Client | null,
serverUrl: string,
authProvider: OAuthClientProvider,
headers: Record<string, string>,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> {
authInitializer: AuthInitializer,
transportStrategy: TransportStrategy = 'http-first',
recursionReasons: Set<string> = new Set(),
): Promise<Transport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
@ -93,25 +114,88 @@ export async function connectToRemoteServer(
...(init?.headers as Record<string, string> | undefined),
...headers,
...(tokens?.access_token ? { Authorization: `Bearer ${tokens.access_token}` } : {}),
Accept: "text/event-stream",
Accept: 'text/event-stream',
} as Record<string, string>,
})
);
}),
)
},
};
}
const transport = new SSEClientTransport(url, {
authProvider,
requestInit: { headers },
eventSourceInit,
})
log(`Using transport strategy: ${transportStrategy}`)
// Determine if we should attempt to fallback on error
// Choose transport based on user strategy and recursion history
const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first'
// Create transport instance based on the strategy
const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first'
const transport = sseTransport
? new SSEClientTransport(url, {
authProvider,
requestInit: { headers },
eventSourceInit,
})
: new StreamableHTTPClientTransport(url, {
authProvider,
requestInit: { headers },
})
try {
await transport.start()
log('Connected to remote server')
if (client) {
await client.connect(transport)
} else {
await transport.start()
if (!sseTransport) {
// Extremely hacky, but we didn't actually send a request when calling transport.start() above, so we don't
// know if we're even talking to an HTTP server. But if we forced that now we'd get an error later saying that
// the client is already connected. So let's just create a one-off client to make a single request and figure
// out if we're actually talking to an HTTP server or not.
const testTransport = new StreamableHTTPClientTransport(url, { authProvider, requestInit: { headers } })
const testClient = new Client({ name: 'mcp-remote-fallback-test', version: '0.0.0' }, { capabilities: {} })
await testClient.connect(testTransport)
}
}
log(`Connected to remote server using ${transport.constructor.name}`)
return transport
} catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
// Check if it's a protocol error and we should attempt fallback
if (
error instanceof Error &&
shouldAttemptFallback &&
(sseTransport
? error.message.includes('405') || error.message.includes('Method Not Allowed')
: error.message.includes('404') || error.message.includes('Not Found'))
) {
log(`Received error: ${error.message}`)
// If we've already tried falling back once, throw an error
if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) {
const errorMessage = `Already attempted transport fallback. Giving up.`
log(errorMessage)
throw new Error(errorMessage)
}
log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`)
// Add to recursion reasons set
recursionReasons.add(REASON_TRANSPORT_FALLBACK)
// Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(
client,
serverUrl,
authProvider,
headers,
authInitializer,
sseTransport ? 'http-only' : 'sse-only',
recursionReasons,
)
} else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Initializing auth...')
// Initialize authentication on-demand
const { waitForAuthCode, skipBrowserAuth } = await authInitializer()
if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
@ -125,11 +209,18 @@ export async function connectToRemoteServer(
log('Completing authorization...')
await transport.finishAuth(code)
// Create a new transport after auth
const newTransport = new SSEClientTransport(url, { authProvider, requestInit: { headers } })
await newTransport.start()
log('Connected to remote server after authentication')
return newTransport
if (recursionReasons.has(REASON_AUTH_NEEDED)) {
const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.`
log(errorMessage)
throw new Error(errorMessage)
}
// Track this reason for recursion
recursionReasons.add(REASON_AUTH_NEEDED)
log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`)
// Recursively call connectToRemoteServer with the updated recursion tracking
return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons)
} catch (authError) {
log('Authorization error:', authError)
throw authError
@ -301,6 +392,19 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
const specifiedPort = args[1] ? parseInt(args[1]) : undefined
const allowHttp = args.includes('--allow-http')
// Parse transport strategy
let transportStrategy: TransportStrategy = 'http-first' // Default
const transportIndex = args.indexOf('--transport')
if (transportIndex !== -1 && transportIndex < args.length - 1) {
const strategy = args[transportIndex + 1]
if (strategy === 'sse-only' || strategy === 'http-only' || strategy === 'sse-first' || strategy === 'http-first') {
transportStrategy = strategy as TransportStrategy
log(`Using transport strategy: ${transportStrategy}`)
} else {
log(`Warning: Ignoring invalid transport strategy: ${strategy}. Valid values are: sse-only, http-only, sse-first, http-first`)
}
}
if (!serverUrl) {
log(usage)
process.exit(1)
@ -343,7 +447,7 @@ export async function parseCommandLineArgs(args: string[], defaultPort: number,
})
}
return { serverUrl, callbackPort, headers }
return { serverUrl, callbackPort, headers, transportStrategy }
}
/**