Merge pull request #13 from geelen/no-multiball

Tried to stop multiple windows opening at once
This commit is contained in:
Glen Maddern 2025-03-31 22:35:44 +11:00 committed by GitHub
commit c6f98ff4b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 397 additions and 65 deletions

View file

@ -1,6 +1,6 @@
{ {
"name": "mcp-remote", "name": "mcp-remote",
"version": "0.0.13", "version": "0.0.15",
"description": "Remote proxy for Model Context Protocol, allowing local-only clients to connect to remote servers using oAuth", "description": "Remote proxy for Model Context Protocol, allowing local-only clients to connect to remote servers using oAuth",
"keywords": [ "keywords": [
"mcp", "mcp",
@ -10,7 +10,7 @@
"oauth" "oauth"
], ],
"author": "Glen Maddern <glen@cloudflare.com>", "author": "Glen Maddern <glen@cloudflare.com>",
"repository": "https://github.com/geelen/remote-mcp", "repository": "https://github.com/geelen/mcp-remote",
"type": "module", "type": "module",
"files": [ "files": [
"dist", "dist",

View file

@ -18,7 +18,8 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js' import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers, MCP_REMOTE_VERSION } from './lib/utils' import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, getServerUrlHash } from './lib/utils'
import { coordinateAuth } from './lib/coordination'
/** /**
* Main function to run the client * Main function to run the client
@ -27,6 +28,12 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider // Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({ const authProvider = new NodeOAuthClientProvider({
serverUrl, serverUrl,
@ -35,6 +42,14 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
clean, clean,
}) })
// If auth was completed by another instance, just log that we'll use the auth from disk
if (skipBrowserAuth) {
log('Authentication was completed by another instance - will use tokens from disk...')
// TODO: remove, the callback is happening before the tokens are exchanged
// so we're slightly too early
await new Promise((res) => setTimeout(res, 1_000))
}
// Create the client // Create the client
const client = new Client( const client = new Client(
{ {
@ -53,15 +68,15 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up message and error handlers // Set up message and error handlers
transport.onmessage = (message) => { transport.onmessage = (message) => {
console.log('Received message:', JSON.stringify(message, null, 2)) log('Received message:', JSON.stringify(message, null, 2))
} }
transport.onerror = (error) => { transport.onerror = (error) => {
console.error('Transport error:', error) log('Transport error:', error)
} }
transport.onclose = () => { transport.onclose = () => {
console.log('Connection closed.') log('Connection closed.')
process.exit(0) process.exit(0)
} }
return transport return transport
@ -69,16 +84,9 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
const transport = initTransport() const transport = initTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Set up cleanup handler // Set up cleanup handler
const cleanup = async () => { const cleanup = async () => {
console.log('\nClosing connection...') log('\nClosing connection...')
await client.close() await client.close()
server.close() server.close()
} }
@ -86,44 +94,44 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Try to connect // Try to connect
try { try {
console.log('Connecting to server...') log('Connecting to server...')
await client.connect(transport) await client.connect(transport)
console.log('Connected successfully!') log('Connected successfully!')
} catch (error) { } catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
console.log('Authentication required. Waiting for authorization...') log('Authentication required. Waiting for authorization...')
// Wait for the authorization code from the callback // Wait for the authorization code from the callback or another instance
const code = await waitForAuthCode() const code = await waitForAuthCode()
try { try {
console.log('Completing authorization...') log('Completing authorization...')
await transport.finishAuth(code) await transport.finishAuth(code)
// Reconnect after authorization with a new transport // Reconnect after authorization with a new transport
console.log('Connecting after authorization...') log('Connecting after authorization...')
await client.connect(initTransport()) await client.connect(initTransport())
console.log('Connected successfully!') log('Connected successfully!')
// Request tools list after auth // Request tools list after auth
console.log('Requesting tools list...') log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2)) log('Tools:', JSON.stringify(tools, null, 2))
// Request resources list after auth // Request resources list after auth
console.log('Requesting resource list...') log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2)) log('Resources:', JSON.stringify(resources, null, 2))
console.log('Listening for messages. Press Ctrl+C to exit.') log('Listening for messages. Press Ctrl+C to exit.')
} catch (authError) { } catch (authError) {
console.error('Authorization error:', authError) log('Authorization error:', authError)
server.close() server.close()
process.exit(1) process.exit(1)
} }
} else { } else {
console.error('Connection error:', error) log('Connection error:', error)
server.close() server.close()
process.exit(1) process.exit(1)
} }
@ -131,23 +139,23 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
try { try {
// Request tools list // Request tools list
console.log('Requesting tools list...') log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema) const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2)) log('Tools:', JSON.stringify(tools, null, 2))
} catch (e) { } catch (e) {
console.log('Error requesting tools list:', e) log('Error requesting tools list:', e)
} }
try { try {
// Request resources list // Request resources list
console.log('Requesting resource list...') log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema) const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2)) log('Resources:', JSON.stringify(resources, null, 2))
} catch (e) { } catch (e) {
console.log('Error requesting resources list:', e) log('Error requesting resources list:', e)
} }
console.log('Listening for messages. Press Ctrl+C to exit.') log('Listening for messages. Press Ctrl+C to exit.')
} }
// Parse command-line arguments and run the client // Parse command-line arguments and run the client

188
src/lib/coordination.ts Normal file
View file

@ -0,0 +1,188 @@
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData } from './mcp-auth-config'
import { EventEmitter } from 'events'
import { Server } from 'http'
import express from 'express'
import { AddressInfo } from 'net'
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
/**
* Checks if a process with the given PID is running
* @param pid The process ID to check
* @returns True if the process is running, false otherwise
*/
export async function isPidRunning(pid: number): Promise<boolean> {
try {
process.kill(pid, 0) // Doesn't kill the process, just checks if it exists
return true
} catch {
return false
}
}
/**
* Checks if a lockfile is valid (process running and endpoint accessible)
* @param lockData The lockfile data
* @returns True if the lockfile is valid, false otherwise
*/
export async function isLockValid(lockData: LockfileData): Promise<boolean> {
// Check if the lockfile is too old (over 30 minutes)
const MAX_LOCK_AGE = 30 * 60 * 1000 // 30 minutes
if (Date.now() - lockData.timestamp > MAX_LOCK_AGE) {
log('Lockfile is too old')
return false
}
// Check if the process is still running
if (!(await isPidRunning(lockData.pid))) {
log('Process from lockfile is not running')
return false
}
// Check if the endpoint is accessible
try {
const controller = new AbortController()
const timeout = setTimeout(() => controller.abort(), 1000)
const response = await fetch(`http://127.0.0.1:${lockData.port}/wait-for-auth?poll=false`, {
signal: controller.signal,
})
clearTimeout(timeout)
return response.status === 200 || response.status === 202
} catch (error) {
log(`Error connecting to auth server: ${(error as Error).message}`)
return false
}
}
/**
* Waits for authentication from another server instance
* @param port The port to connect to
* @returns True if authentication completed successfully, false otherwise
*/
export async function waitForAuthentication(port: number): Promise<boolean> {
log(`Waiting for authentication from the server on port ${port}...`)
try {
while (true) {
const url = `http://127.0.0.1:${port}/wait-for-auth`
log(`Querying: ${url}`)
const response = await fetch(url)
if (response.status === 200) {
// Auth completed, but we don't return the code anymore
log(`Authentication completed by other instance`)
return true
} else if (response.status === 202) {
// Continue polling
log(`Authentication still in progress`)
await new Promise(resolve => setTimeout(resolve, 1000))
} else {
log(`Unexpected response status: ${response.status}`)
return false
}
}
} catch (error) {
log(`Error waiting for authentication: ${(error as Error).message}`)
return false
}
}
/**
* Coordinates authentication between multiple instances of the client/proxy
* @param serverUrlHash The hash of the server URL
* @param callbackPort The port to use for the callback server
* @param events The event emitter to use for signaling
* @returns An object with the server, waitForAuthCode function, and a flag indicating if browser auth can be skipped
*/
export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
// Check for a lockfile
const lockData = await checkLockfile(serverUrlHash)
// If there's a valid lockfile, try to use the existing auth process
if (lockData && (await isLockValid(lockData))) {
log(`Another instance is handling authentication on port ${lockData.port}`)
try {
// Try to wait for the authentication to complete
const authCompleted = await waitForAuthentication(lockData.port)
if (authCompleted) {
log('Authentication completed by another instance')
// Setup a dummy server - the client will use tokens directly from disk
const dummyServer = express().listen(0) // Listen on any available port
// This shouldn't actually be called in normal operation, but provide it for API compatibility
const dummyWaitForAuthCode = () => {
log('WARNING: waitForAuthCode called in secondary instance - this is unexpected')
// Return a promise that never resolves - the client should use the tokens from disk instead
return new Promise<string>(() => {})
}
return {
server: dummyServer,
waitForAuthCode: dummyWaitForAuthCode,
skipBrowserAuth: true,
}
} else {
log('Taking over authentication process...')
}
} catch (error) {
log(`Error waiting for authentication: ${error}`)
}
// If we get here, the other process didn't complete auth successfully
await deleteLockfile(serverUrlHash)
} else if (lockData) {
// Invalid lockfile, delete its
log('Found invalid lockfile, deleting it')
await deleteLockfile(serverUrlHash)
}
// Create our own lockfile
const { server, waitForAuthCode, authCompletedPromise } = setupOAuthCallbackServerWithLongPoll({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Get the actual port the server is running on
const address = server.address() as AddressInfo
const actualPort = address.port
log(`Creating lockfile for server ${serverUrlHash} with process ${process.pid} on port ${actualPort}`)
await createLockfile(serverUrlHash, process.pid, actualPort)
// Make sure lockfile is deleted on process exit
const cleanupHandler = async () => {
try {
log(`Cleaning up lockfile for server ${serverUrlHash}`)
await deleteLockfile(serverUrlHash)
} catch (error) {
log(`Error cleaning up lockfile: ${error}`)
}
}
process.once('exit', () => {
try {
// Synchronous version for 'exit' event since we can't use async here
const configPath = getConfigFilePath(serverUrlHash, 'lock.json')
require('fs').unlinkSync(configPath)
} catch {}
})
// Also handle SIGINT separately
process.once('SIGINT', async () => {
await cleanupHandler()
})
return {
server,
waitForAuthCode,
skipBrowserAuth: false
}
}

View file

@ -1,4 +1,3 @@
import crypto from 'crypto'
import path from 'path' import path from 'path'
import os from 'os' import os from 'os'
import fs from 'fs/promises' import fs from 'fs/promises'
@ -27,7 +26,61 @@ import { log, MCP_REMOTE_VERSION } from './utils'
/** /**
* Known configuration file names that might need to be cleaned * Known configuration file names that might need to be cleaned
*/ */
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt'] export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt', 'lock.json']
/**
* Lockfile data structure
*/
export interface LockfileData {
pid: number
port: number
timestamp: number
}
/**
* Creates a lockfile for the given server
* @param serverUrlHash The hash of the server URL
* @param pid The process ID
* @param port The port the server is running on
*/
export async function createLockfile(serverUrlHash: string, pid: number, port: number): Promise<void> {
const lockData: LockfileData = {
pid,
port,
timestamp: Date.now(),
}
await writeJsonFile(serverUrlHash, 'lock.json', lockData)
}
/**
* Checks if a lockfile exists for the given server
* @param serverUrlHash The hash of the server URL
* @returns The lockfile data or null if it doesn't exist
*/
export async function checkLockfile(serverUrlHash: string): Promise<LockfileData | null> {
try {
const lockfile = await readJsonFile<LockfileData>(serverUrlHash, 'lock.json', {
async parseAsync(data: any) {
if (typeof data !== 'object' || data === null) return null
if (typeof data.pid !== 'number' || typeof data.port !== 'number' || typeof data.timestamp !== 'number') {
return null
}
return data as LockfileData
},
})
return lockfile || null
} catch {
return null
}
}
/**
* Deletes the lockfile for the given server
* @param serverUrlHash The hash of the server URL
*/
export async function deleteLockfile(serverUrlHash: string): Promise<void> {
await deleteConfigFile(serverUrlHash, 'lock.json')
}
/** /**
* Deletes all known configuration files for a specific server * Deletes all known configuration files for a specific server
@ -63,15 +116,6 @@ export async function ensureConfigDir(): Promise<void> {
} }
} }
/**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}
/** /**
* Gets the file path for a config file * Gets the file path for a config file
* @param serverUrlHash The hash of the server URL * @param serverUrlHash The hash of the server URL
@ -125,11 +169,15 @@ export async function readJsonFile<T>(
const filePath = getConfigFilePath(serverUrlHash, filename) const filePath = getConfigFilePath(serverUrlHash, filename)
const content = await fs.readFile(filePath, 'utf-8') const content = await fs.readFile(filePath, 'utf-8')
return await schema.parseAsync(JSON.parse(content)) const result = await schema.parseAsync(JSON.parse(content))
// console.log({ filename: result })
return result
} catch (error) { } catch (error) {
if ((error as NodeJS.ErrnoException).code === 'ENOENT') { if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
// console.log(`File ${filename} does not exist`)
return undefined return undefined
} }
log(`Error reading ${filename}:`, error)
return undefined return undefined
} }
} }

View file

@ -8,8 +8,8 @@ import {
OAuthTokensSchema, OAuthTokensSchema,
} from '@modelcontextprotocol/sdk/shared/auth.js' } from '@modelcontextprotocol/sdk/shared/auth.js'
import type { OAuthProviderOptions } from './types' import type { OAuthProviderOptions } from './types'
import { getServerUrlHash, readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config' import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
import { log } from './utils' import { getServerUrlHash, log } from './utils'
/** /**
* Implements the OAuthClientProvider interface for Node.js environments. * Implements the OAuthClientProvider interface for Node.js environments.
@ -59,6 +59,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The client information or undefined * @returns The client information or undefined
*/ */
async clientInformation(): Promise<OAuthClientInformation | undefined> { async clientInformation(): Promise<OAuthClientInformation | undefined> {
// log('Reading client info')
return readJsonFile<OAuthClientInformation>(this.serverUrlHash, 'client_info.json', OAuthClientInformationSchema, this.options.clean) return readJsonFile<OAuthClientInformation>(this.serverUrlHash, 'client_info.json', OAuthClientInformationSchema, this.options.clean)
} }
@ -67,6 +68,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param clientInformation The client information to save * @param clientInformation The client information to save
*/ */
async saveClientInformation(clientInformation: OAuthClientInformationFull): Promise<void> { async saveClientInformation(clientInformation: OAuthClientInformationFull): Promise<void> {
// log('Saving client info')
await writeJsonFile(this.serverUrlHash, 'client_info.json', clientInformation) await writeJsonFile(this.serverUrlHash, 'client_info.json', clientInformation)
} }
@ -75,6 +77,8 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The OAuth tokens or undefined * @returns The OAuth tokens or undefined
*/ */
async tokens(): Promise<OAuthTokens | undefined> { async tokens(): Promise<OAuthTokens | undefined> {
// log('Reading tokens')
// console.log(new Error().stack)
return readJsonFile<OAuthTokens>(this.serverUrlHash, 'tokens.json', OAuthTokensSchema, this.options.clean) return readJsonFile<OAuthTokens>(this.serverUrlHash, 'tokens.json', OAuthTokensSchema, this.options.clean)
} }
@ -83,6 +87,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param tokens The tokens to save * @param tokens The tokens to save
*/ */
async saveTokens(tokens: OAuthTokens): Promise<void> { async saveTokens(tokens: OAuthTokens): Promise<void> {
// log('Saving tokens')
await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens) await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens)
} }
@ -105,6 +110,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param codeVerifier The code verifier to save * @param codeVerifier The code verifier to save
*/ */
async saveCodeVerifier(codeVerifier: string): Promise<void> { async saveCodeVerifier(codeVerifier: string): Promise<void> {
// log('Saving code verifier')
await writeTextFile(this.serverUrlHash, 'code_verifier.txt', codeVerifier) await writeTextFile(this.serverUrlHash, 'code_verifier.txt', codeVerifier)
} }
@ -113,6 +119,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The code verifier * @returns The code verifier
*/ */
async codeVerifier(): Promise<string> { async codeVerifier(): Promise<string> {
// log('Reading code verifier')
return await readTextFile(this.serverUrlHash, 'code_verifier.txt', 'No code verifier saved for session', this.options.clean) return await readTextFile(this.serverUrlHash, 'code_verifier.txt', 'No code verifier saved for session', this.options.clean)
} }
} }

View file

@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { OAuthCallbackServerOptions } from './types' import { OAuthCallbackServerOptions } from './types'
import express from 'express' import express from 'express'
import net from 'net' import net from 'net'
import crypto from 'crypto'
// Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version
const pid = process.pid const pid = process.pid
export function log(str: string, ...rest: unknown[]) { export function log(str: string, ...rest: unknown[]) {
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* @param serverUrl The URL of the remote server * @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider * @param authProvider The OAuth client provider
* @param waitForAuthCode Function to wait for the auth code * @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport * @returns The connected SSE client transport
*/ */
export async function connectToRemoteServer( export async function connectToRemoteServer(
serverUrl: string, serverUrl: string,
authProvider: OAuthClientProvider, authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>, waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> { ): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`) log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl) const url = new URL(serverUrl)
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
return transport return transport
} catch (error) { } catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) { if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
log('Authentication required. Waiting for authorization...') if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
log('Authentication required. Waiting for authorization...')
}
// Wait for the authorization code from the callback // Wait for the authorization code from the callback
const code = await waitForAuthCode() const code = await waitForAuthCode()
@ -112,10 +122,57 @@ export async function connectToRemoteServer(
* @param options The server options * @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function * @returns An object with the server, authCode, and waitForAuthCode function
*/ */
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) { export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
let authCode: string | null = null let authCode: string | null = null
const app = express() const app = express()
// Create a promise to track when auth is completed
let authCompletedResolve: (code: string) => void
const authCompletedPromise = new Promise<string>((resolve) => {
authCompletedResolve = resolve
})
// Long-polling endpoint
app.get('/wait-for-auth', (req, res) => {
if (authCode) {
// Auth already completed - just return 200 without the actual code
// Secondary instances will read tokens from disk
log('Auth already completed, returning 200')
res.status(200).send('Authentication completed')
return
}
if (req.query.poll === 'false') {
log('Client requested no long poll, responding with 202')
res.status(202).send('Authentication in progress')
return
}
// Long poll - wait for up to 30 seconds
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
// If auth completes while we're waiting, send the response immediately
authCompletedPromise
.then(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth completed during long poll, responding with 200')
res.status(200).send('Authentication completed')
}
})
.catch(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth failed during long poll, responding with 500')
res.status(500).send('Authentication failed')
}
})
})
// OAuth callback endpoint
app.get(options.path, (req, res) => { app.get(options.path, (req, res) => {
const code = req.query.code as string | undefined const code = req.query.code as string | undefined
if (!code) { if (!code) {
@ -124,6 +181,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
} }
authCode = code authCode = code
log('Auth code received, resolving promise')
authCompletedResolve(code)
res.send('Authorization successful! You may close this window and return to the CLI.') res.send('Authorization successful! You may close this window and return to the CLI.')
// Notify main flow that auth code is available // Notify main flow that auth code is available
@ -134,10 +194,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`) log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
}) })
/**
* Waits for the OAuth authorization code
* @returns A promise that resolves with the authorization code
*/
const waitForAuthCode = (): Promise<string> => { const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => { return new Promise((resolve) => {
if (authCode) { if (authCode) {
@ -151,6 +207,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
}) })
} }
return { server, authCode, waitForAuthCode, authCompletedPromise }
}
/**
* Sets up an Express server to handle OAuth callbacks
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
return { server, authCode, waitForAuthCode } return { server, authCode, waitForAuthCode }
} }
@ -248,4 +314,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
process.stdin.resume() process.stdin.resume()
} }
export const MCP_REMOTE_VERSION = require('../../package.json').version /**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}

View file

@ -14,8 +14,9 @@
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers } from './lib/utils' import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider' import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { coordinateAuth } from './lib/coordination'
/** /**
* Main function to run the proxy * Main function to run the proxy
@ -24,6 +25,12 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow // Set up event emitter for auth flow
const events = new EventEmitter() const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider // Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({ const authProvider = new NodeOAuthClientProvider({
serverUrl, serverUrl,
@ -32,19 +39,20 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
clean, clean,
}) })
// If auth was completed by another instance, just log that we'll use the auth from disk
if (skipBrowserAuth) {
log('Authentication was completed by another instance - will use tokens from disk')
// TODO: remove, the callback is happening before the tokens are exchanged
// so we're slightly too early
await new Promise((res) => setTimeout(res, 1_000))
}
// Create the STDIO transport for local connections // Create the STDIO transport for local connections
const localTransport = new StdioServerTransport() const localTransport = new StdioServerTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
try { try {
// Connect to remote server with authentication // Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode) const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
// Set up bidirectional proxy between local and remote transports // Set up bidirectional proxy between local and remote transports
mcpProxy({ mcpProxy({