import { OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js' import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js' import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' // Connection constants export const REASON_AUTH_NEEDED = 'authentication-needed' export const REASON_TRANSPORT_FALLBACK = 'falling-back-to-alternate-transport' // Transport strategy types export type TransportStrategy = 'sse-only' | 'http-only' | 'sse-first' | 'http-first' import { OAuthCallbackServerOptions } from './types' import express from 'express' import net from 'net' import crypto from 'crypto' // Package version from package.json export const MCP_REMOTE_VERSION = require('../../package.json').version const pid = process.pid export function log(str: string, ...rest: unknown[]) { // Using stderr so that it doesn't interfere with stdout console.error(`[${pid}] ${str}`, ...rest) } /** * Creates a bidirectional proxy between two transports * @param params The transport connections to proxy between */ export function mcpProxy({ transportToClient, transportToServer }: { transportToClient: Transport; transportToServer: Transport }) { let transportToClientClosed = false let transportToServerClosed = false transportToClient.onmessage = (message) => { // @ts-expect-error TODO log('[Local→Remote]', message.method || message.id) transportToServer.send(message).catch(onServerError) } transportToServer.onmessage = (message) => { // @ts-expect-error TODO: fix this type log('[Remote→Local]', message.method || message.id) transportToClient.send(message).catch(onClientError) } transportToClient.onclose = () => { if (transportToServerClosed) { return } transportToClientClosed = true transportToServer.close().catch(onServerError) } transportToServer.onclose = () => { if (transportToClientClosed) { return } transportToServerClosed = true transportToClient.close().catch(onClientError) } transportToClient.onerror = onClientError transportToServer.onerror = onServerError function onClientError(error: Error) { log('Error from local client:', error) } function onServerError(error: Error) { log('Error from remote server:', error) } } /** * Type for the auth initialization function */ export type AuthInitializer = () => Promise<{ waitForAuthCode: () => Promise skipBrowserAuth: boolean }> /** * Creates and connects to a remote server with OAuth authentication * @param client The client to connect with * @param serverUrl The URL of the remote server * @param authProvider The OAuth client provider * @param headers Additional headers to send with the request * @param authInitializer Function to initialize authentication when needed * @param transportStrategy Strategy for selecting transport type ('sse-only', 'http-only', 'sse-first', 'http-first') * @param recursionReasons Set of reasons for recursive calls (internal use) * @returns The connected transport */ export async function connectToRemoteServer( client: Client | null, serverUrl: string, authProvider: OAuthClientProvider, headers: Record, authInitializer: AuthInitializer, transportStrategy: TransportStrategy = 'sse-first', recursionReasons: Set = new Set(), ): Promise { log(`[${pid}] Connecting to remote server: ${serverUrl}`) const url = new URL(serverUrl) // Create transport with eventSourceInit to pass Authorization header if present const eventSourceInit = { fetch: (url: string | URL, init?: RequestInit) => { return Promise.resolve(authProvider?.tokens?.()).then((tokens) => fetch(url, { ...init, headers: { ...(init?.headers as Record | undefined), ...headers, ...(tokens?.access_token ? { Authorization: `Bearer ${tokens.access_token}` } : {}), Accept: 'text/event-stream', } as Record, }), ) }, } log(`Using transport strategy: ${transportStrategy}`) // Determine if we should attempt to fallback on error // Choose transport based on user strategy and recursion history const shouldAttemptFallback = transportStrategy === 'http-first' || transportStrategy === 'sse-first' // Create transport instance based on the strategy const sseTransport = transportStrategy === 'sse-only' || transportStrategy === 'sse-first' const transport = sseTransport ? new SSEClientTransport(url, { authProvider, requestInit: { headers }, eventSourceInit, }) : new StreamableHTTPClientTransport(url, { authProvider, requestInit: { headers }, }) try { if (client) { await client.connect(transport) } else { await transport.start() } log(`Connected to remote server using ${transport.constructor.name}`) return transport } catch (error) { // Check if it's a protocol error and we should attempt fallback if ( error instanceof Error && shouldAttemptFallback && (sseTransport ? error.message.includes('405') || error.message.includes('Method Not Allowed') : error.message.includes('404') || error.message.includes('Not Found')) ) { log(`Received error: ${error.message}`) // If we've already tried falling back once, throw an error if (recursionReasons.has(REASON_TRANSPORT_FALLBACK)) { const errorMessage = `Already attempted transport fallback. Giving up.` log(errorMessage) throw new Error(errorMessage) } log(`Recursively reconnecting for reason: ${REASON_TRANSPORT_FALLBACK}`) // Add to recursion reasons set recursionReasons.add(REASON_TRANSPORT_FALLBACK) // Recursively call connectToRemoteServer with the updated recursion tracking return connectToRemoteServer( client, serverUrl, authProvider, headers, authInitializer, sseTransport ? 'http-only' : 'sse-only', recursionReasons, ) } else if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { log('Authentication required. Initializing auth...') // Initialize authentication on-demand const { waitForAuthCode, skipBrowserAuth } = await authInitializer() if (skipBrowserAuth) { log('Authentication required but skipping browser auth - using shared auth') } else { log('Authentication required. Waiting for authorization...') } // Wait for the authorization code from the callback const code = await waitForAuthCode() try { log('Completing authorization...') await transport.finishAuth(code) if (recursionReasons.has(REASON_AUTH_NEEDED)) { const errorMessage = `Already attempted reconnection for reason: ${REASON_AUTH_NEEDED}. Giving up.` log(errorMessage) throw new Error(errorMessage) } // Track this reason for recursion recursionReasons.add(REASON_AUTH_NEEDED) log(`Recursively reconnecting for reason: ${REASON_AUTH_NEEDED}`) // Recursively call connectToRemoteServer with the updated recursion tracking return connectToRemoteServer(client, serverUrl, authProvider, headers, authInitializer, transportStrategy, recursionReasons) } catch (authError) { log('Authorization error:', authError) throw authError } } else { log('Connection error:', error) throw error } } } /** * Sets up an Express server to handle OAuth callbacks * @param options The server options * @returns An object with the server, authCode, and waitForAuthCode function */ export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) { let authCode: string | null = null const app = express() // Create a promise to track when auth is completed let authCompletedResolve: (code: string) => void const authCompletedPromise = new Promise((resolve) => { authCompletedResolve = resolve }) // Long-polling endpoint app.get('/wait-for-auth', (req, res) => { if (authCode) { // Auth already completed - just return 200 without the actual code // Secondary instances will read tokens from disk log('Auth already completed, returning 200') res.status(200).send('Authentication completed') return } if (req.query.poll === 'false') { log('Client requested no long poll, responding with 202') res.status(202).send('Authentication in progress') return } // Long poll - wait for up to 30 seconds const longPollTimeout = setTimeout(() => { log('Long poll timeout reached, responding with 202') res.status(202).send('Authentication in progress') }, 30000) // If auth completes while we're waiting, send the response immediately authCompletedPromise .then(() => { clearTimeout(longPollTimeout) if (!res.headersSent) { log('Auth completed during long poll, responding with 200') res.status(200).send('Authentication completed') } }) .catch(() => { clearTimeout(longPollTimeout) if (!res.headersSent) { log('Auth failed during long poll, responding with 500') res.status(500).send('Authentication failed') } }) }) // OAuth callback endpoint app.get(options.path, (req, res) => { const code = req.query.code as string | undefined if (!code) { res.status(400).send('Error: No authorization code received') return } authCode = code log('Auth code received, resolving promise') authCompletedResolve(code) res.send('Authorization successful! You may close this window and return to the CLI.') // Notify main flow that auth code is available options.events.emit('auth-code-received', code) }) const server = app.listen(options.port, () => { log(`OAuth callback server running at http://127.0.0.1:${options.port}`) }) const waitForAuthCode = (): Promise => { return new Promise((resolve) => { if (authCode) { resolve(authCode) return } options.events.once('auth-code-received', (code) => { resolve(code) }) }) } return { server, authCode, waitForAuthCode, authCompletedPromise } } /** * Sets up an Express server to handle OAuth callbacks * @param options The server options * @returns An object with the server, authCode, and waitForAuthCode function */ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) { const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options) return { server, authCode, waitForAuthCode } } /** * Finds an available port on the local machine * @param preferredPort Optional preferred port to try first * @returns A promise that resolves to an available port number */ export async function findAvailablePort(preferredPort?: number): Promise { return new Promise((resolve, reject) => { const server = net.createServer() server.on('error', (err: NodeJS.ErrnoException) => { if (err.code === 'EADDRINUSE') { // If preferred port is in use, get a random port server.listen(0) } else { reject(err) } }) server.on('listening', () => { const { port } = server.address() as net.AddressInfo server.close(() => { resolve(port) }) }) // Try preferred port first, or get a random port server.listen(preferredPort || 0) }) } /** * Parses command line arguments for MCP clients and proxies * @param args Command line arguments * @param defaultPort Default port for the callback server if specified port is unavailable * @param usage Usage message to show on error * @returns A promise that resolves to an object with parsed serverUrl, callbackPort and headers */ export async function parseCommandLineArgs(args: string[], defaultPort: number, usage: string) { // Process headers const headers: Record = {} args.forEach((arg, i) => { if (arg === '--header' && i < args.length - 1) { const value = args[i + 1] const match = value.match(/^([A-Za-z0-9_-]+):(.*)$/) if (match) { headers[match[1]] = match[2] } else { log(`Warning: ignoring invalid header argument: ${value}`) } args.splice(i, 2) } }) const serverUrl = args[0] const specifiedPort = args[1] ? parseInt(args[1]) : undefined const allowHttp = args.includes('--allow-http') // Parse transport strategy let transportStrategy: TransportStrategy = 'sse-first' // Default const transportIndex = args.indexOf('--transport') if (transportIndex !== -1 && transportIndex < args.length - 1) { const strategy = args[transportIndex + 1] if (strategy === 'sse-only' || strategy === 'http-only' || strategy === 'sse-first' || strategy === 'http-first') { transportStrategy = strategy as TransportStrategy log(`Using transport strategy: ${transportStrategy}`) } else { log(`Warning: Ignoring invalid transport strategy: ${strategy}. Valid values are: sse-only, http-only, sse-first, http-first`) } } if (!serverUrl) { log(usage) process.exit(1) } const url = new URL(serverUrl) const isLocalhost = (url.hostname === 'localhost' || url.hostname === '127.0.0.1') && url.protocol === 'http:' if (!(url.protocol == 'https:' || isLocalhost || allowHttp)) { log('Error: Non-HTTPS URLs are only allowed for localhost or when --allow-http flag is provided') log(usage) process.exit(1) } // Use the specified port, or find an available one const callbackPort = specifiedPort || (await findAvailablePort(defaultPort)) if (specifiedPort) { log(`Using specified callback port: ${callbackPort}`) } else { log(`Using automatically selected callback port: ${callbackPort}`) } if (Object.keys(headers).length > 0) { log(`Using custom headers: ${JSON.stringify(headers)}`) } // Replace environment variables in headers // example `Authorization: Bearer ${TOKEN}` will read process.env.TOKEN for (const [key, value] of Object.entries(headers)) { headers[key] = value.replace(/\$\{([^}]+)}/g, (match, envVarName) => { const envVarValue = process.env[envVarName] if (envVarValue !== undefined) { log(`Replacing ${match} with environment value in header '${key}'`) return envVarValue } else { log(`Warning: Environment variable '${envVarName}' not found for header '${key}'.`) return '' } }) } return { serverUrl, callbackPort, headers, transportStrategy } } /** * Sets up signal handlers for graceful shutdown * @param cleanup Cleanup function to run on shutdown */ export function setupSignalHandlers(cleanup: () => Promise) { process.on('SIGINT', async () => { log('\nShutting down...') await cleanup() process.exit(0) }) // Keep the process alive process.stdin.resume() } /** * Generates a hash for the server URL to use in filenames * @param serverUrl The server URL to hash * @returns The hashed server URL */ export function getServerUrlHash(serverUrl: string): string { return crypto.createHash('md5').update(serverUrl).digest('hex') }