Merge pull request #13 from geelen/no-multiball

Tried to stop multiple windows opening at once
This commit is contained in:
Glen Maddern 2025-03-31 22:35:44 +11:00 committed by GitHub
commit c6f98ff4b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 397 additions and 65 deletions

View file

@ -1,6 +1,6 @@
{
"name": "mcp-remote",
"version": "0.0.13",
"version": "0.0.15",
"description": "Remote proxy for Model Context Protocol, allowing local-only clients to connect to remote servers using oAuth",
"keywords": [
"mcp",
@ -10,7 +10,7 @@
"oauth"
],
"author": "Glen Maddern <glen@cloudflare.com>",
"repository": "https://github.com/geelen/remote-mcp",
"repository": "https://github.com/geelen/mcp-remote",
"type": "module",
"files": [
"dist",

View file

@ -18,7 +18,8 @@ import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
import { ListResourcesResultSchema, ListToolsResultSchema } from '@modelcontextprotocol/sdk/types.js'
import { UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers, MCP_REMOTE_VERSION } from './lib/utils'
import { parseCommandLineArgs, setupSignalHandlers, log, MCP_REMOTE_VERSION, getServerUrlHash } from './lib/utils'
import { coordinateAuth } from './lib/coordination'
/**
* Main function to run the client
@ -27,6 +28,12 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow
const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
serverUrl,
@ -35,6 +42,14 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
clean,
})
// If auth was completed by another instance, just log that we'll use the auth from disk
if (skipBrowserAuth) {
log('Authentication was completed by another instance - will use tokens from disk...')
// TODO: remove, the callback is happening before the tokens are exchanged
// so we're slightly too early
await new Promise((res) => setTimeout(res, 1_000))
}
// Create the client
const client = new Client(
{
@ -53,15 +68,15 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Set up message and error handlers
transport.onmessage = (message) => {
console.log('Received message:', JSON.stringify(message, null, 2))
log('Received message:', JSON.stringify(message, null, 2))
}
transport.onerror = (error) => {
console.error('Transport error:', error)
log('Transport error:', error)
}
transport.onclose = () => {
console.log('Connection closed.')
log('Connection closed.')
process.exit(0)
}
return transport
@ -69,16 +84,9 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
const transport = initTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Set up cleanup handler
const cleanup = async () => {
console.log('\nClosing connection...')
log('\nClosing connection...')
await client.close()
server.close()
}
@ -86,44 +94,44 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
// Try to connect
try {
console.log('Connecting to server...')
log('Connecting to server...')
await client.connect(transport)
console.log('Connected successfully!')
log('Connected successfully!')
} catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
console.log('Authentication required. Waiting for authorization...')
log('Authentication required. Waiting for authorization...')
// Wait for the authorization code from the callback
// Wait for the authorization code from the callback or another instance
const code = await waitForAuthCode()
try {
console.log('Completing authorization...')
log('Completing authorization...')
await transport.finishAuth(code)
// Reconnect after authorization with a new transport
console.log('Connecting after authorization...')
log('Connecting after authorization...')
await client.connect(initTransport())
console.log('Connected successfully!')
log('Connected successfully!')
// Request tools list after auth
console.log('Requesting tools list...')
log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2))
log('Tools:', JSON.stringify(tools, null, 2))
// Request resources list after auth
console.log('Requesting resource list...')
log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2))
log('Resources:', JSON.stringify(resources, null, 2))
console.log('Listening for messages. Press Ctrl+C to exit.')
log('Listening for messages. Press Ctrl+C to exit.')
} catch (authError) {
console.error('Authorization error:', authError)
log('Authorization error:', authError)
server.close()
process.exit(1)
}
} else {
console.error('Connection error:', error)
log('Connection error:', error)
server.close()
process.exit(1)
}
@ -131,23 +139,23 @@ async function runClient(serverUrl: string, callbackPort: number, clean: boolean
try {
// Request tools list
console.log('Requesting tools list...')
log('Requesting tools list...')
const tools = await client.request({ method: 'tools/list' }, ListToolsResultSchema)
console.log('Tools:', JSON.stringify(tools, null, 2))
log('Tools:', JSON.stringify(tools, null, 2))
} catch (e) {
console.log('Error requesting tools list:', e)
log('Error requesting tools list:', e)
}
try {
// Request resources list
console.log('Requesting resource list...')
log('Requesting resource list...')
const resources = await client.request({ method: 'resources/list' }, ListResourcesResultSchema)
console.log('Resources:', JSON.stringify(resources, null, 2))
log('Resources:', JSON.stringify(resources, null, 2))
} catch (e) {
console.log('Error requesting resources list:', e)
log('Error requesting resources list:', e)
}
console.log('Listening for messages. Press Ctrl+C to exit.')
log('Listening for messages. Press Ctrl+C to exit.')
}
// Parse command-line arguments and run the client

188
src/lib/coordination.ts Normal file
View file

@ -0,0 +1,188 @@
import { checkLockfile, createLockfile, deleteLockfile, getConfigFilePath, LockfileData } from './mcp-auth-config'
import { EventEmitter } from 'events'
import { Server } from 'http'
import express from 'express'
import { AddressInfo } from 'net'
import { log, setupOAuthCallbackServerWithLongPoll } from './utils'
/**
* Checks if a process with the given PID is running
* @param pid The process ID to check
* @returns True if the process is running, false otherwise
*/
export async function isPidRunning(pid: number): Promise<boolean> {
try {
process.kill(pid, 0) // Doesn't kill the process, just checks if it exists
return true
} catch {
return false
}
}
/**
* Checks if a lockfile is valid (process running and endpoint accessible)
* @param lockData The lockfile data
* @returns True if the lockfile is valid, false otherwise
*/
export async function isLockValid(lockData: LockfileData): Promise<boolean> {
// Check if the lockfile is too old (over 30 minutes)
const MAX_LOCK_AGE = 30 * 60 * 1000 // 30 minutes
if (Date.now() - lockData.timestamp > MAX_LOCK_AGE) {
log('Lockfile is too old')
return false
}
// Check if the process is still running
if (!(await isPidRunning(lockData.pid))) {
log('Process from lockfile is not running')
return false
}
// Check if the endpoint is accessible
try {
const controller = new AbortController()
const timeout = setTimeout(() => controller.abort(), 1000)
const response = await fetch(`http://127.0.0.1:${lockData.port}/wait-for-auth?poll=false`, {
signal: controller.signal,
})
clearTimeout(timeout)
return response.status === 200 || response.status === 202
} catch (error) {
log(`Error connecting to auth server: ${(error as Error).message}`)
return false
}
}
/**
* Waits for authentication from another server instance
* @param port The port to connect to
* @returns True if authentication completed successfully, false otherwise
*/
export async function waitForAuthentication(port: number): Promise<boolean> {
log(`Waiting for authentication from the server on port ${port}...`)
try {
while (true) {
const url = `http://127.0.0.1:${port}/wait-for-auth`
log(`Querying: ${url}`)
const response = await fetch(url)
if (response.status === 200) {
// Auth completed, but we don't return the code anymore
log(`Authentication completed by other instance`)
return true
} else if (response.status === 202) {
// Continue polling
log(`Authentication still in progress`)
await new Promise(resolve => setTimeout(resolve, 1000))
} else {
log(`Unexpected response status: ${response.status}`)
return false
}
}
} catch (error) {
log(`Error waiting for authentication: ${(error as Error).message}`)
return false
}
}
/**
* Coordinates authentication between multiple instances of the client/proxy
* @param serverUrlHash The hash of the server URL
* @param callbackPort The port to use for the callback server
* @param events The event emitter to use for signaling
* @returns An object with the server, waitForAuthCode function, and a flag indicating if browser auth can be skipped
*/
export async function coordinateAuth(
serverUrlHash: string,
callbackPort: number,
events: EventEmitter,
): Promise<{ server: Server; waitForAuthCode: () => Promise<string>; skipBrowserAuth: boolean }> {
// Check for a lockfile
const lockData = await checkLockfile(serverUrlHash)
// If there's a valid lockfile, try to use the existing auth process
if (lockData && (await isLockValid(lockData))) {
log(`Another instance is handling authentication on port ${lockData.port}`)
try {
// Try to wait for the authentication to complete
const authCompleted = await waitForAuthentication(lockData.port)
if (authCompleted) {
log('Authentication completed by another instance')
// Setup a dummy server - the client will use tokens directly from disk
const dummyServer = express().listen(0) // Listen on any available port
// This shouldn't actually be called in normal operation, but provide it for API compatibility
const dummyWaitForAuthCode = () => {
log('WARNING: waitForAuthCode called in secondary instance - this is unexpected')
// Return a promise that never resolves - the client should use the tokens from disk instead
return new Promise<string>(() => {})
}
return {
server: dummyServer,
waitForAuthCode: dummyWaitForAuthCode,
skipBrowserAuth: true,
}
} else {
log('Taking over authentication process...')
}
} catch (error) {
log(`Error waiting for authentication: ${error}`)
}
// If we get here, the other process didn't complete auth successfully
await deleteLockfile(serverUrlHash)
} else if (lockData) {
// Invalid lockfile, delete its
log('Found invalid lockfile, deleting it')
await deleteLockfile(serverUrlHash)
}
// Create our own lockfile
const { server, waitForAuthCode, authCompletedPromise } = setupOAuthCallbackServerWithLongPoll({
port: callbackPort,
path: '/oauth/callback',
events,
})
// Get the actual port the server is running on
const address = server.address() as AddressInfo
const actualPort = address.port
log(`Creating lockfile for server ${serverUrlHash} with process ${process.pid} on port ${actualPort}`)
await createLockfile(serverUrlHash, process.pid, actualPort)
// Make sure lockfile is deleted on process exit
const cleanupHandler = async () => {
try {
log(`Cleaning up lockfile for server ${serverUrlHash}`)
await deleteLockfile(serverUrlHash)
} catch (error) {
log(`Error cleaning up lockfile: ${error}`)
}
}
process.once('exit', () => {
try {
// Synchronous version for 'exit' event since we can't use async here
const configPath = getConfigFilePath(serverUrlHash, 'lock.json')
require('fs').unlinkSync(configPath)
} catch {}
})
// Also handle SIGINT separately
process.once('SIGINT', async () => {
await cleanupHandler()
})
return {
server,
waitForAuthCode,
skipBrowserAuth: false
}
}

View file

@ -1,4 +1,3 @@
import crypto from 'crypto'
import path from 'path'
import os from 'os'
import fs from 'fs/promises'
@ -27,7 +26,61 @@ import { log, MCP_REMOTE_VERSION } from './utils'
/**
* Known configuration file names that might need to be cleaned
*/
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt']
export const knownConfigFiles = ['client_info.json', 'tokens.json', 'code_verifier.txt', 'lock.json']
/**
* Lockfile data structure
*/
export interface LockfileData {
pid: number
port: number
timestamp: number
}
/**
* Creates a lockfile for the given server
* @param serverUrlHash The hash of the server URL
* @param pid The process ID
* @param port The port the server is running on
*/
export async function createLockfile(serverUrlHash: string, pid: number, port: number): Promise<void> {
const lockData: LockfileData = {
pid,
port,
timestamp: Date.now(),
}
await writeJsonFile(serverUrlHash, 'lock.json', lockData)
}
/**
* Checks if a lockfile exists for the given server
* @param serverUrlHash The hash of the server URL
* @returns The lockfile data or null if it doesn't exist
*/
export async function checkLockfile(serverUrlHash: string): Promise<LockfileData | null> {
try {
const lockfile = await readJsonFile<LockfileData>(serverUrlHash, 'lock.json', {
async parseAsync(data: any) {
if (typeof data !== 'object' || data === null) return null
if (typeof data.pid !== 'number' || typeof data.port !== 'number' || typeof data.timestamp !== 'number') {
return null
}
return data as LockfileData
},
})
return lockfile || null
} catch {
return null
}
}
/**
* Deletes the lockfile for the given server
* @param serverUrlHash The hash of the server URL
*/
export async function deleteLockfile(serverUrlHash: string): Promise<void> {
await deleteConfigFile(serverUrlHash, 'lock.json')
}
/**
* Deletes all known configuration files for a specific server
@ -63,15 +116,6 @@ export async function ensureConfigDir(): Promise<void> {
}
}
/**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}
/**
* Gets the file path for a config file
* @param serverUrlHash The hash of the server URL
@ -125,11 +169,15 @@ export async function readJsonFile<T>(
const filePath = getConfigFilePath(serverUrlHash, filename)
const content = await fs.readFile(filePath, 'utf-8')
return await schema.parseAsync(JSON.parse(content))
const result = await schema.parseAsync(JSON.parse(content))
// console.log({ filename: result })
return result
} catch (error) {
if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
// console.log(`File ${filename} does not exist`)
return undefined
}
log(`Error reading ${filename}:`, error)
return undefined
}
}

View file

@ -8,8 +8,8 @@ import {
OAuthTokensSchema,
} from '@modelcontextprotocol/sdk/shared/auth.js'
import type { OAuthProviderOptions } from './types'
import { getServerUrlHash, readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
import { log } from './utils'
import { readJsonFile, writeJsonFile, readTextFile, writeTextFile, cleanServerConfig } from './mcp-auth-config'
import { getServerUrlHash, log } from './utils'
/**
* Implements the OAuthClientProvider interface for Node.js environments.
@ -59,6 +59,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The client information or undefined
*/
async clientInformation(): Promise<OAuthClientInformation | undefined> {
// log('Reading client info')
return readJsonFile<OAuthClientInformation>(this.serverUrlHash, 'client_info.json', OAuthClientInformationSchema, this.options.clean)
}
@ -67,6 +68,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param clientInformation The client information to save
*/
async saveClientInformation(clientInformation: OAuthClientInformationFull): Promise<void> {
// log('Saving client info')
await writeJsonFile(this.serverUrlHash, 'client_info.json', clientInformation)
}
@ -75,6 +77,8 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The OAuth tokens or undefined
*/
async tokens(): Promise<OAuthTokens | undefined> {
// log('Reading tokens')
// console.log(new Error().stack)
return readJsonFile<OAuthTokens>(this.serverUrlHash, 'tokens.json', OAuthTokensSchema, this.options.clean)
}
@ -83,6 +87,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param tokens The tokens to save
*/
async saveTokens(tokens: OAuthTokens): Promise<void> {
// log('Saving tokens')
await writeJsonFile(this.serverUrlHash, 'tokens.json', tokens)
}
@ -105,6 +110,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @param codeVerifier The code verifier to save
*/
async saveCodeVerifier(codeVerifier: string): Promise<void> {
// log('Saving code verifier')
await writeTextFile(this.serverUrlHash, 'code_verifier.txt', codeVerifier)
}
@ -113,6 +119,7 @@ export class NodeOAuthClientProvider implements OAuthClientProvider {
* @returns The code verifier
*/
async codeVerifier(): Promise<string> {
// log('Reading code verifier')
return await readTextFile(this.serverUrlHash, 'code_verifier.txt', 'No code verifier saved for session', this.options.clean)
}
}

View file

@ -4,6 +4,10 @@ import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'
import { OAuthCallbackServerOptions } from './types'
import express from 'express'
import net from 'net'
import crypto from 'crypto'
// Package version from package.json
export const MCP_REMOTE_VERSION = require('../../package.json').version
const pid = process.pid
export function log(str: string, ...rest: unknown[]) {
@ -65,12 +69,14 @@ export function mcpProxy({ transportToClient, transportToServer }: { transportTo
* @param serverUrl The URL of the remote server
* @param authProvider The OAuth client provider
* @param waitForAuthCode Function to wait for the auth code
* @param skipBrowserAuth Whether to skip browser auth and use shared auth
* @returns The connected SSE client transport
*/
export async function connectToRemoteServer(
serverUrl: string,
authProvider: OAuthClientProvider,
waitForAuthCode: () => Promise<string>,
skipBrowserAuth: boolean = false,
): Promise<SSEClientTransport> {
log(`[${pid}] Connecting to remote server: ${serverUrl}`)
const url = new URL(serverUrl)
@ -82,7 +88,11 @@ export async function connectToRemoteServer(
return transport
} catch (error) {
if (error instanceof UnauthorizedError || (error instanceof Error && error.message.includes('Unauthorized'))) {
if (skipBrowserAuth) {
log('Authentication required but skipping browser auth - using shared auth')
} else {
log('Authentication required. Waiting for authorization...')
}
// Wait for the authorization code from the callback
const code = await waitForAuthCode()
@ -112,10 +122,57 @@ export async function connectToRemoteServer(
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
export function setupOAuthCallbackServerWithLongPoll(options: OAuthCallbackServerOptions) {
let authCode: string | null = null
const app = express()
// Create a promise to track when auth is completed
let authCompletedResolve: (code: string) => void
const authCompletedPromise = new Promise<string>((resolve) => {
authCompletedResolve = resolve
})
// Long-polling endpoint
app.get('/wait-for-auth', (req, res) => {
if (authCode) {
// Auth already completed - just return 200 without the actual code
// Secondary instances will read tokens from disk
log('Auth already completed, returning 200')
res.status(200).send('Authentication completed')
return
}
if (req.query.poll === 'false') {
log('Client requested no long poll, responding with 202')
res.status(202).send('Authentication in progress')
return
}
// Long poll - wait for up to 30 seconds
const longPollTimeout = setTimeout(() => {
log('Long poll timeout reached, responding with 202')
res.status(202).send('Authentication in progress')
}, 30000)
// If auth completes while we're waiting, send the response immediately
authCompletedPromise
.then(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth completed during long poll, responding with 200')
res.status(200).send('Authentication completed')
}
})
.catch(() => {
clearTimeout(longPollTimeout)
if (!res.headersSent) {
log('Auth failed during long poll, responding with 500')
res.status(500).send('Authentication failed')
}
})
})
// OAuth callback endpoint
app.get(options.path, (req, res) => {
const code = req.query.code as string | undefined
if (!code) {
@ -124,6 +181,9 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
}
authCode = code
log('Auth code received, resolving promise')
authCompletedResolve(code)
res.send('Authorization successful! You may close this window and return to the CLI.')
// Notify main flow that auth code is available
@ -134,10 +194,6 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
log(`OAuth callback server running at http://127.0.0.1:${options.port}`)
})
/**
* Waits for the OAuth authorization code
* @returns A promise that resolves with the authorization code
*/
const waitForAuthCode = (): Promise<string> => {
return new Promise((resolve) => {
if (authCode) {
@ -151,6 +207,16 @@ export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
})
}
return { server, authCode, waitForAuthCode, authCompletedPromise }
}
/**
* Sets up an Express server to handle OAuth callbacks
* @param options The server options
* @returns An object with the server, authCode, and waitForAuthCode function
*/
export function setupOAuthCallbackServer(options: OAuthCallbackServerOptions) {
const { server, authCode, waitForAuthCode } = setupOAuthCallbackServerWithLongPoll(options)
return { server, authCode, waitForAuthCode }
}
@ -248,4 +314,11 @@ export function setupSignalHandlers(cleanup: () => Promise<void>) {
process.stdin.resume()
}
export const MCP_REMOTE_VERSION = require('../../package.json').version
/**
* Generates a hash for the server URL to use in filenames
* @param serverUrl The server URL to hash
* @returns The hashed server URL
*/
export function getServerUrlHash(serverUrl: string): string {
return crypto.createHash('md5').update(serverUrl).digest('hex')
}

View file

@ -14,8 +14,9 @@
import { EventEmitter } from 'events'
import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js'
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupOAuthCallbackServer, setupSignalHandlers } from './lib/utils'
import { connectToRemoteServer, log, mcpProxy, parseCommandLineArgs, setupSignalHandlers, getServerUrlHash } from './lib/utils'
import { NodeOAuthClientProvider } from './lib/node-oauth-client-provider'
import { coordinateAuth } from './lib/coordination'
/**
* Main function to run the proxy
@ -24,6 +25,12 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
// Set up event emitter for auth flow
const events = new EventEmitter()
// Get the server URL hash for lockfile operations
const serverUrlHash = getServerUrlHash(serverUrl)
// Coordinate authentication with other instances
const { server, waitForAuthCode, skipBrowserAuth } = await coordinateAuth(serverUrlHash, callbackPort, events)
// Create the OAuth client provider
const authProvider = new NodeOAuthClientProvider({
serverUrl,
@ -32,19 +39,20 @@ async function runProxy(serverUrl: string, callbackPort: number, clean: boolean
clean,
})
// If auth was completed by another instance, just log that we'll use the auth from disk
if (skipBrowserAuth) {
log('Authentication was completed by another instance - will use tokens from disk')
// TODO: remove, the callback is happening before the tokens are exchanged
// so we're slightly too early
await new Promise((res) => setTimeout(res, 1_000))
}
// Create the STDIO transport for local connections
const localTransport = new StdioServerTransport()
// Set up an HTTP server to handle OAuth callback
const { server, waitForAuthCode } = setupOAuthCallbackServer({
port: callbackPort,
path: '/oauth/callback',
events,
})
try {
// Connect to remote server with authentication
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode)
const remoteTransport = await connectToRemoteServer(serverUrl, authProvider, waitForAuthCode, skipBrowserAuth)
// Set up bidirectional proxy between local and remote transports
mcpProxy({